{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PartialTypeSignatures #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeInType #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

-- | Type-indexed trainable parameters, plus a generic mechanism for
-- collecting ('flattenParameters') and replacing ('replaceParameters') the
-- parameters held inside a structure.
module Torch.Typed.Parameter
  ( module Torch.Typed.Parameter,
    Torch.NN.Randomizable (..),
  )
where

import Control.Monad.State.Strict
import Data.Kind (Type)
import GHC.Generics
import GHC.TypeLits
import GHC.TypeLits.Extra
import qualified Torch.Autograd (IndependentTensor (..), makeIndependent)
import Torch.DType (DType)
import Torch.Device (DeviceType)
import Torch.HList
import qualified Torch.NN (Parameter, Randomizable (..), sample)
import qualified Torch.Tensor (toType, _toDevice)
import Torch.Typed.Auxiliary
import Torch.Typed.Factories
import Torch.Typed.Functional
import Torch.Typed.Tensor

-- | A trainable tensor whose device, dtype and shape are tracked at the
-- type level.
--
-- The wrapped 'Torch.Autograd.IndependentTensor' is untyped. Nothing in this
-- newtype checks that it agrees with the @device@, @dtype@ and @shape@
-- indices, hence the name 'UnsafeMkParameter'. 'makeIndependent' builds a
-- 'Parameter' from an already-typed 'Tensor', so the indices are carried over
-- from that tensor.
newtype
  Parameter
    (device :: (DeviceType, Nat))
    (dtype :: DType)
    (shape :: [Nat])
  = UnsafeMkParameter Torch.Autograd.IndependentTensor
  deriving (Show)

-- | Drop the type-level indices and return the underlying untyped
-- 'Torch.NN.Parameter'.
untypeParam :: Parameter device dtype shape -> Torch.NN.Parameter
untypeParam (UnsafeMkParameter param) = param

-- | View a parameter as a typed 'Tensor' with the same indices. The result
-- comes from 'Torch.Autograd.toDependent' applied to the wrapped
-- 'Torch.Autograd.IndependentTensor'.
toDependent ::
  forall shape dtype device.
  Parameter device dtype shape ->
  Tensor device dtype shape
toDependent (UnsafeMkParameter t) = UnsafeMkTensor $ Torch.Autograd.toDependent t

-- | Tag for mapping 'toDependent' over an 'HList' via 'Apply''.
data ToDependent = ToDependent

instance
  Apply'
    ToDependent
    (Parameter device dtype shape)
    (Tensor device dtype shape)
  where
  apply' _ = toDependent

-- | Turn a typed 'Tensor' into a 'Parameter'. Its untyped form ('toDynamic')
-- goes through 'Torch.Autograd.makeIndependent', and the indices are kept.
makeIndependent ::
  forall shape dtype device.
  Tensor device dtype shape ->
  IO (Parameter device dtype shape)
makeIndependent t = UnsafeMkParameter <$> Torch.Autograd.makeIndependent (toDynamic t)

-- | Tag for mapping 'makeIndependent' over an 'HList' via 'Apply''.
data MakeIndependent = MakeIndependent

instance
  Apply'
    MakeIndependent
    (Tensor device dtype shape)
    (IO (Parameter device dtype shape))
  where
  apply' _ = makeIndependent

-- | Move a parameter to the device @device'@ (chosen with a type
-- application).
--
-- The steps are:
--
-- 1. Unwrap to the dependent tensor.
-- 2. Call 'Torch.Tensor._toDevice' with @'deviceVal' \@device'@.
-- 3. Re-wrap the result with the 'Torch.Autograd.IndependentTensor'
--    constructor.
--
-- NOTE(review): the moved tensor is wrapped with the raw constructor rather
-- than passed through 'Torch.Autograd.makeIndependent'. Confirm that this is
-- the intended autograd behaviour.
parameterToDevice ::
  forall device' device dtype shape.
  KnownDevice device' =>
  Parameter device dtype shape ->
  Parameter device' dtype shape
parameterToDevice (UnsafeMkParameter t) =
  UnsafeMkParameter
    . Torch.Autograd.IndependentTensor
    . Torch.Tensor._toDevice (deviceVal @device')
    . Torch.Autograd.toDependent
    $ t

-- | Convert a parameter to the dtype @dtype'@ (chosen with a type
-- application).
--
-- This mirrors 'parameterToDevice' but uses 'Torch.Tensor.toType' with
-- @'dtypeVal' \@dtype'@. The same review note about re-wrapping with the raw
-- 'Torch.Autograd.IndependentTensor' constructor applies.
parameterToDType ::
  forall dtype' dtype device shape.
  KnownDType dtype' =>
  Parameter device dtype shape ->
  Parameter device dtype' shape
parameterToDType (UnsafeMkParameter t) =
  UnsafeMkParameter
    . Torch.Autograd.IndependentTensor
    . Torch.Tensor.toType (dtypeVal @dtype')
    . Torch.Autograd.toDependent
    $ t

-- | Structures that contain 'Parameter's.
--
-- 'flattenParameters' collects the parameters into an 'HList' whose element
-- types are listed by 'Parameters'. 'replaceParameters' rebuilds the
-- structure from such a list. When the type has a 'Generic' instance, the
-- default methods (and the default 'Parameters') are derived through
-- 'GHC.Generics' via 'GParameterized'.
class Parameterized (f :: Type) where
  type Parameters f :: [Type]
  type Parameters f = GParameters (Rep f)

  flattenParameters :: f -> HList (Parameters f)
  default flattenParameters ::
    (Generic f, GParameterized (Rep f), Parameters f ~ GParameters (Rep f)) =>
    f ->
    HList (Parameters f)
  flattenParameters value = gFlattenParameters (from value)

  replaceParameters :: f -> HList (Parameters f) -> f
  default replaceParameters ::
    (Generic f, GParameterized (Rep f), Parameters f ~ GParameters (Rep f)) =>
    f ->
    HList (Parameters f) ->
    f
  replaceParameters value params = to (gReplaceParameters (from value) params)

-- | Generic-representation counterpart of 'Parameterized'.
--
-- This module defines instances for ':*:', 'K1', 'M1' and 'U1'. It defines
-- none for ':+:', so the generic defaults cover only single-constructor types
-- unless such an instance is provided elsewhere.
class GParameterized (f :: Type -> Type) where
  type GParameters f :: [Type]
  gFlattenParameters :: forall a. f a -> HList (GParameters f)
  gReplaceParameters :: forall a. f a -> HList (GParameters f) -> f a

-- | A product contributes the left component's parameters followed by the
-- right component's. 'hunappendFD' splits the list back at the same boundary.
instance
  ( GParameterized l,
    GParameterized r,
    HAppendFD (GParameters l) (GParameters r) (GParameters l ++ GParameters r)
  ) =>
  GParameterized (l :*: r)
  where
  type GParameters (l :*: r) = (GParameters l) ++ (GParameters r)
  gFlattenParameters (left :*: right) =
    gFlattenParameters left `happendFD` gFlattenParameters right
  gReplaceParameters (left :*: right) params =
    let (leftParams, rightParams) = hunappendFD params
        left' = gReplaceParameters left leftParams
        right' = gReplaceParameters right rightParams
     in left' :*: right'

-- | A field delegates to the field type's own 'Parameterized' instance.
instance Parameterized f => GParameterized (K1 i f) where
  type GParameters (K1 i f) = Parameters f
  gFlattenParameters = flattenParameters . unK1
  gReplaceParameters (K1 field) = K1 . replaceParameters field

-- | Metadata wrappers are transparent.
instance GParameterized f => GParameterized (M1 i t f) where
  type GParameters (M1 i t f) = GParameters f
  gFlattenParameters = gFlattenParameters . unM1
  gReplaceParameters (M1 inner) = M1 . gReplaceParameters inner

-- | Constructors without fields hold no parameters.
instance GParameterized U1 where
  type GParameters U1 = '[]
  gFlattenParameters _ = HNil
  gReplaceParameters = const

-- | A plain 'Tensor' is not trainable. It contributes no parameters, and
-- 'replaceParameters' returns it unchanged.
instance Parameterized (Tensor device dtype shape) where
  type Parameters (Tensor device dtype shape) = '[]
  flattenParameters _ = HNil
  replaceParameters = const

-- | A 'Parameter' is exactly one parameter: itself.
instance Parameterized (Parameter device dtype shape) where
  type Parameters (Parameter device dtype shape) = '[Parameter device dtype shape]
  flattenParameters = (:. HNil)
  replaceParameters _ (parameter :. HNil) = parameter

-- | Scalars contribute no parameters and are returned unchanged by
-- 'replaceParameters'.
instance Parameterized Int where
  type Parameters Int = '[]
  flattenParameters _ = HNil
  replaceParameters = const

instance Parameterized Float where
  type Parameters Float = '[]
  flattenParameters _ = HNil
  replaceParameters = const

instance Parameterized Double where
  type Parameters Double = '[]
  flattenParameters _ = HNil
  replaceParameters = const

-- | The empty 'HList' holds no parameters.
instance Parameterized (HList '[]) where
  type Parameters (HList '[]) = '[]
  flattenParameters _ = HNil
  replaceParameters = const

-- | A non-empty 'HList' contributes its head's parameters followed by those of
-- its tail. 'hunappendFD' splits the list back at the same boundary.
instance
  ( Parameterized f,
    Parameterized (HList fs),
    HAppendFD (Parameters f) (Parameters (HList fs)) (Parameters f ++ Parameters (HList fs))
  ) =>
  Parameterized (HList (f ': fs))
  where
  type Parameters (HList (f ': fs)) = Parameters f ++ Parameters (HList fs)
  flattenParameters (x :. rest) =
    flattenParameters x `happendFD` flattenParameters rest
  replaceParameters (x :. rest) params =
    let (headParams, tailParams) = hunappendFD params
        x' = replaceParameters x headParams
        rest' = replaceParameters rest tailParams
     in x' :. rest'

-- | Sampling an empty list of specs yields the empty list.
instance Torch.NN.Randomizable (HList ('[] :: [Type])) (HList ('[] :: [Type])) where
  sample = pure

-- | Sample each spec in an 'HList', head first and then the tail, and collect
-- the results in the same order.
instance
  ( Torch.NN.Randomizable xSpec x,
    Torch.NN.Randomizable (HList xsSpec) (HList xs)
  ) =>
  Torch.NN.Randomizable (HList (xSpec ': xsSpec)) (HList (x ': xs))
  where
  -- The applicative chain runs the head's IO action before the tail's, which
  -- is the same order as the equivalent do-block.
  sample (xSpec :. xsSpec) =
    (:.) <$> Torch.NN.sample xSpec <*> Torch.NN.sample xsSpec