{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PartialTypeSignatures #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeInType #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

module Torch.Typed.Parameter
  ( module Torch.Typed.Parameter,
    Torch.NN.Randomizable (..),
  )
where

import Control.Monad.State.Strict
import Data.Kind (Type)
import GHC.Generics
import GHC.TypeLits
import GHC.TypeLits.Extra
import qualified Torch.Autograd (IndependentTensor (..), makeIndependent)
import Torch.DType (DType)
import Torch.Device (DeviceType)
import Torch.HList
import qualified Torch.NN (Parameter, Randomizable (..), sample)
import qualified Torch.Tensor (toType, _toDevice)
import Torch.Typed.Auxiliary
import Torch.Typed.Factories
import Torch.Typed.Functional
import Torch.Typed.Tensor

-- | A parameter whose device, dtype and shape are tracked at the type level.
--
-- This is a newtype over the untyped 'Torch.Autograd.IndependentTensor'.
-- The three type indices are phantom: 'UnsafeMkParameter' performs no check
-- that the wrapped tensor actually lives on @device@, has element type
-- @dtype@, or has shape @shape@. Keeping that invariant is the caller's job,
-- hence the @Unsafe@ prefix.
newtype
  Parameter
    (device :: (DeviceType, Nat))
    (dtype :: DType)
    (shape :: [Nat])
  = UnsafeMkParameter Torch.Autograd.IndependentTensor
  deriving (Show)

-- | Forget the type-level device, dtype and shape of a 'Parameter' and
-- return the wrapped untyped 'Torch.NN.Parameter'.
untypeParam :: Parameter device dtype shape -> Torch.NN.Parameter
untypeParam (UnsafeMkParameter param) = param

-- | View a 'Parameter' as an ordinary typed 'Tensor'.
--
-- The wrapped tensor goes through 'Torch.Autograd.toDependent' and is then
-- re-tagged with the same device, dtype and shape indices.
toDependent ::
  forall shape dtype device.
  Parameter device dtype shape ->
  Tensor device dtype shape
toDependent (UnsafeMkParameter t) = UnsafeMkTensor (Torch.Autograd.toDependent t)

-- | Tag value that selects 'toDependent' in the 'Apply'' machinery
-- (see its instance for 'Parameter').
data ToDependent
  = ToDependent

-- | Applying 'ToDependent' to a 'Parameter' is exactly 'toDependent'.
-- The tag argument carries no data and is ignored.
instance Apply' ToDependent (Parameter device dtype shape) (Tensor device dtype shape) where
  apply' _ = toDependent

-- | Turn a typed 'Tensor' into a 'Parameter'.
--
-- The tensor is erased with 'toDynamic', passed to
-- 'Torch.Autograd.makeIndependent' (hence the 'IO' result), and the outcome
-- is re-tagged with the same device, dtype and shape indices.
makeIndependent ::
  forall shape dtype device.
  Tensor device dtype shape ->
  IO (Parameter device dtype shape)
makeIndependent t = UnsafeMkParameter <$> Torch.Autograd.makeIndependent (toDynamic t)

-- | Tag value that selects 'makeIndependent' in the 'Apply'' machinery
-- (see its instance for 'Tensor').
data MakeIndependent
  = MakeIndependent

-- | Applying 'MakeIndependent' to a 'Tensor' is exactly 'makeIndependent'.
-- The tag argument carries no data and is ignored.
instance
  Apply'
    MakeIndependent
    (Tensor device dtype shape)
    (IO (Parameter device dtype shape))
  where
  apply' _ = makeIndependent

-- | Move a 'Parameter' to the device named by the type argument @device'@.
--
-- Steps: take the dependent view of the wrapped tensor, move it with
-- 'Torch.Tensor._toDevice' to @'deviceVal' \@device'@, and re-wrap it
-- directly with the 'Torch.Autograd.IndependentTensor' constructor.
--
-- NOTE(review): the re-wrap bypasses 'Torch.Autograd.makeIndependent' and
-- runs outside 'IO'. Whether the moved tensor keeps its gradient-tracking
-- status depends on '_toDevice'; confirm before relying on it for training.
parameterToDevice ::
  forall device' device dtype shape.
  KnownDevice device' =>
  Parameter device dtype shape ->
  Parameter device' dtype shape
parameterToDevice (UnsafeMkParameter t) =
  UnsafeMkParameter
    . Torch.Autograd.IndependentTensor
    . Torch.Tensor._toDevice (deviceVal @device')
    . Torch.Autograd.toDependent
    $ t

-- | Convert a 'Parameter' to the element type named by the type argument
-- @dtype'@.
--
-- Steps: take the dependent view of the wrapped tensor, convert it with
-- 'Torch.Tensor.toType' to @'dtypeVal' \@dtype'@, and re-wrap it directly
-- with the 'Torch.Autograd.IndependentTensor' constructor.
--
-- NOTE(review): as with 'parameterToDevice', the re-wrap bypasses
-- 'Torch.Autograd.makeIndependent'. Confirm the gradient-tracking semantics
-- of 'toType' if that matters to the caller.
parameterToDType ::
  forall dtype' dtype device shape.
  KnownDType dtype' =>
  Parameter device dtype shape ->
  Parameter device dtype' shape
parameterToDType (UnsafeMkParameter t) =
  UnsafeMkParameter
    . Torch.Autograd.IndependentTensor
    . Torch.Tensor.toType (dtypeVal @dtype')
    . Torch.Autograd.toDependent
    $ t

-- | Values that contain 'Parameter's.
--
-- Instances can collect their parameters into a heterogeneous list and
-- rebuild themselves from such a list. Both methods have 'Generic'-based
-- defaults (via 'GParameterized'), and 'Parameters' defaults to
-- @'GParameters' ('Rep' f)@. So for a 'Generic' type whose fields are all
-- 'Parameterized', an instance needs no method definitions.
--
-- 'flattenParameters' and 'replaceParameters' share the same 'Parameters'
-- list type, so the list produced by one has exactly the shape the other
-- consumes.
class Parameterized (f :: Type) where
  -- | Type-level list of the parameter types contained in @f@.
  type Parameters f :: [Type]
  type Parameters f = GParameters (Rep f)

  -- | Collect every parameter contained in the value.
  flattenParameters :: f -> HList (Parameters f)
  default flattenParameters ::
    (Generic f, GParameterized (Rep f), Parameters f ~ GParameters (Rep f)) =>
    f ->
    HList (Parameters f)
  flattenParameters f = gFlattenParameters (from f)

  -- | Rebuild the value, taking its parameters from the given list.
  -- Non-parameter content is taken from the first argument.
  replaceParameters :: f -> HList (Parameters f) -> f
  default replaceParameters ::
    (Generic f, GParameterized (Rep f), Parameters f ~ GParameters (Rep f)) =>
    f ->
    HList (Parameters f) ->
    f
  replaceParameters f as = to (gReplaceParameters (from f) as)

-- | Counterpart of 'Parameterized' over "GHC.Generics" representation
-- functors. Used to implement the default methods of 'Parameterized'.
class GParameterized (f :: Type -> Type) where
  -- | Parameter types found in the representation. Products list the
  -- left component's parameters before the right one's.
  type GParameters f :: [Type]

  -- | Collect the parameters stored in a generic representation.
  gFlattenParameters :: f a -> HList (GParameters f)

  -- | Replace the parameters stored in a generic representation.
  gReplaceParameters :: f a -> HList (GParameters f) -> f a

-- | Products: the parameters of @l@ come first, followed by those of @r@.
instance
  ( GParameterized l,
    GParameterized r,
    HAppendFD (GParameters l) (GParameters r) (GParameters l ++ GParameters r)
  ) =>
  GParameterized (l :*: r)
  where
  type GParameters (l :*: r) = GParameters l ++ GParameters r
  gFlattenParameters (l :*: r) =
    let as :: HList (GParameters l)
        as = gFlattenParameters l
        bs :: HList (GParameters r)
        bs = gFlattenParameters r
     in as `happendFD` bs
  gReplaceParameters (l :*: r) cs =
    -- Split the combined list at the boundary between l's and r's
    -- parameters, then rebuild each side from its own slice.
    let (as, bs) = hunappendFD cs
        l' = gReplaceParameters l as
        r' = gReplaceParameters r bs
     in l' :*: r'

-- | Fields: defer to the field type's own 'Parameterized' instance.
instance
  Parameterized f =>
  GParameterized (K1 i f)
  where
  type GParameters (K1 i f) = Parameters f
  gFlattenParameters = flattenParameters . unK1
  gReplaceParameters (K1 f) = K1 . replaceParameters f

-- | Metadata wrappers (datatype, constructor and selector information) are
-- transparent: they contribute exactly the parameters of what they wrap.
instance GParameterized f => GParameterized (M1 i t f) where
  type GParameters (M1 i t f) = GParameters f
  gFlattenParameters = gFlattenParameters . unM1
  gReplaceParameters (M1 f) = M1 . gReplaceParameters f

-- | Constructors without fields hold no parameters.
instance GParameterized U1 where
  type GParameters U1 = '[]
  gFlattenParameters _ = HNil
  gReplaceParameters = const

-- | A plain 'Tensor' is not a 'Parameter': it contributes no parameters and
-- is returned unchanged by 'replaceParameters'.
instance Parameterized (Tensor device dtype shape) where
  type Parameters (Tensor device dtype shape) = '[]
  flattenParameters _ = HNil
  replaceParameters = const

-- | A 'Parameter' flattens to a one-element list containing itself.
-- Replacing it just takes that single element; the old value is discarded.
instance Parameterized (Parameter device dtype shape) where
  type Parameters (Parameter device dtype shape) = '[Parameter device dtype shape]
  flattenParameters = (:. HNil)
  replaceParameters _ (parameter :. HNil) = parameter

-- | 'Int' fields (e.g. hyper-parameters stored in a model record) hold no
-- parameters and are returned unchanged by 'replaceParameters'.
instance Parameterized Int where
  type Parameters Int = '[]
  flattenParameters _ = HNil
  replaceParameters = const

-- | 'Float' fields hold no parameters and are returned unchanged by
-- 'replaceParameters'.
instance Parameterized Float where
  type Parameters Float = '[]
  flattenParameters _ = HNil
  replaceParameters = const

-- | 'Double' fields hold no parameters and are returned unchanged by
-- 'replaceParameters'.
instance Parameterized Double where
  type Parameters Double = '[]
  flattenParameters _ = HNil
  replaceParameters = const

-- | The empty 'HList' holds no parameters. This is the base case for the
-- non-empty 'HList' instance.
instance Parameterized (HList '[]) where
  type Parameters (HList '[]) = '[]
  flattenParameters _ = HNil
  replaceParameters = const

-- | A non-empty 'HList' contributes the parameters of its head followed by
-- those of its tail.
instance
  ( Parameterized f,
    Parameterized (HList fs),
    HAppendFD (Parameters f) (Parameters (HList fs)) (Parameters f ++ Parameters (HList fs))
  ) =>
  Parameterized (HList (f ': fs))
  where
  type Parameters (HList (f ': fs)) = Parameters f ++ Parameters (HList fs)
  flattenParameters (f :. fs) = flattenParameters f `happendFD` flattenParameters fs
  replaceParameters (f :. fs) cs =
    -- Split the combined list at the boundary between the head's and the
    -- tail's parameters, then rebuild each part from its own slice.
    let (as, bs) = hunappendFD cs
        f' :: f
        f' = replaceParameters f as
        fs' :: HList fs
        fs' = replaceParameters fs bs
     in f' :. fs'

-- | Sampling from an empty spec list yields an empty 'HList'.
-- No randomness is involved.
instance Torch.NN.Randomizable (HList ('[] :: [Type])) (HList ('[] :: [Type])) where
  sample = pure

-- | Sample each element from its corresponding spec and collect the results
-- into an 'HList'. The head is sampled before the tail, because 'IO''s
-- 'Applicative' instance runs effects left to right.
instance
  ( Torch.NN.Randomizable xSpec x,
    Torch.NN.Randomizable (HList xsSpec) (HList xs)
  ) =>
  Torch.NN.Randomizable (HList (xSpec ': xsSpec)) (HList (x ': xs))
  where
  sample (xSpec :. xsSpec) = (:.) <$> Torch.NN.sample xSpec <*> Torch.NN.sample xsSpec