{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PartialTypeSignatures #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -fplugin TypeLevel.Rewrite
-fplugin-opt=TypeLevel.Rewrite:Torch.GraduallyTyped.Unify.UnifyIdempotenceL2 #-}
module Torch.GraduallyTyped.NN.Linear where
import Control.Monad.Indexed (IxPointed (ireturn), (>>>=))
import Control.Monad.Indexed.State (IxStateT (..))
import Control.Monad.Indexed.Trans (IxMonadTrans (ilift))
import Data.Functor.Indexed ((<<$>>), (<<*>>))
import Data.Kind (Type)
import GHC.Generics (Generic)
import GHC.TypeLits (Nat, Symbol)
import Torch.GraduallyTyped.DType (DType (..), DataType (..), SDataType (..))
import Torch.GraduallyTyped.Device (Device (..), DeviceType (..), SDevice (..))
import Torch.GraduallyTyped.Internal.TensorOptions (tensorDims)
import Torch.GraduallyTyped.Layout (Layout (..), LayoutType (..), SLayout (..), SLayoutType (..))
import Torch.GraduallyTyped.NN.Class (HasForward (..), HasInitialize (..), HasStateDict (..), ModelSpec, NamedModel (..))
import Torch.GraduallyTyped.NN.Functional.Linear (LinearWithBiasF, LinearWithoutBiasF, linearWithBias, linearWithoutBias)
import Torch.GraduallyTyped.NN.Initialization (FanMode (..), ForNonLinearity (..), calculateFan, getter, sKaimingUniform)
import Torch.GraduallyTyped.NN.Type (HasBias (..), SHasBias (..))
import Torch.GraduallyTyped.Prelude (pattern (:|:))
import Torch.GraduallyTyped.Prelude.List (SList (..))
import Torch.GraduallyTyped.Random (SGetGeneratorDevice)
import Torch.GraduallyTyped.RequiresGradient (Gradient, RequiresGradient (..), SGradient)
import Torch.GraduallyTyped.Shape.Type (Dim (..), Name (..), SDim (..), SShape (..), Shape (..), Size (..))
import Torch.GraduallyTyped.Tensor.Creation (sRandn)
import Torch.GraduallyTyped.Tensor.MathOperations.Pointwise (mulScalar, subScalar)
import Torch.GraduallyTyped.Tensor.Type (Tensor, TensorSpec (..))
import Torch.GraduallyTyped.Unify (type (<+>), type (<|>))
-- | Generic linear (fully-connected) layer.
--
-- - @weight@ is the type of the weight.
-- - @bias@ is the type of the bias; it is @()@ for a layer without bias.
data
  GLinear
    (weight :: Type)
    (bias :: Type)
  where
  GLinear ::
    forall weight bias.
    { -- | Linear weight.
      linearWeight :: weight,
      -- | Linear bias.
      linearBias :: bias
    } ->
    GLinear weight bias
  deriving stock (Eq, Ord, Show, Generic)
-- | A 'GLinear' is specified by a 'GLinear' of the specifications of its
-- weight and its bias.
type instance
  ModelSpec (GLinear weight bias) =
    GLinear (ModelSpec weight) (ModelSpec bias)
-- | Type of a linear layer whose weight and bias are each wrapped in a
-- 'NamedModel'.
--
-- The weight type is computed by 'LinearWeightF' and the bias type by
-- 'LinearBiasF', so a layer without bias carries a @NamedModel ()@ bias.
type family
  GLinearF
    (hasBias :: HasBias)
    (gradient :: Gradient RequiresGradient)
    (device :: Device (DeviceType Nat))
    (dataType :: DataType DType)
    (inputDim :: Dim (Name Symbol) (Size Nat))
    (outputDim :: Dim (Name Symbol) (Size Nat)) ::
    Type
  where
  GLinearF withBias grad dev dtype inDim outDim =
    GLinear
      (NamedModel (LinearWeightF grad dev dtype inDim outDim))
      (NamedModel (LinearBiasF withBias grad dev dtype outDim))
-- | Type of a linear layer's weight: a dense tensor whose shape is
-- @[outputDim, inputDim]@ (output dimension first).
type family
  LinearWeightF
    (gradient :: Gradient RequiresGradient)
    (device :: Device (DeviceType Nat))
    (dataType :: DataType DType)
    (inputDim :: Dim (Name Symbol) (Size Nat))
    (outputDim :: Dim (Name Symbol) (Size Nat)) ::
    Type
  where
  LinearWeightF grad dev dtype inDim outDim =
    Tensor
      grad
      ('Layout 'Dense)
      dev
      dtype
      ('Shape '[outDim, inDim])
-- | Type of a linear layer's bias.
--
-- Without bias this is @()@; with bias it is a dense tensor of shape
-- @[outputDim]@.
type family
  LinearBiasF
    (hasBias :: HasBias)
    (gradient :: Gradient RequiresGradient)
    (device :: Device (DeviceType Nat))
    (dataType :: DataType DType)
    (outputDim :: Dim (Name Symbol) (Size Nat)) ::
    Type
  where
  LinearBiasF 'WithoutBias _ _ _ _ =
    ()
  LinearBiasF 'WithBias grad dev dtype outDim =
    Tensor
      grad
      ('Layout 'Dense)
      dev
      dtype
      ('Shape '[outDim])
-- | Builds the model specification of a linear layer.
--
-- The weight specification is a dense tensor of shape
-- @[outputDim, inputDim]@ wrapped in a 'NamedModel' called @"weight"@.
-- The bias specification is wrapped in a 'NamedModel' called @"bias"@.
-- It is a dense tensor of shape @[outputDim]@ for 'SWithBias' and @()@ for
-- 'SWithoutBias'.
linearSpec ::
  forall hasBias gradient device dataType inputDim outputDim.
  -- | whether or not the layer has a bias
  SHasBias hasBias ->
  -- | gradient setting of the parameters
  SGradient gradient ->
  -- | device of the parameters
  SDevice device ->
  -- | data type of the parameters
  SDataType dataType ->
  -- | input dimension
  SDim inputDim ->
  -- | output dimension
  SDim outputDim ->
  -- | linear layer specification
  ModelSpec (GLinearF hasBias gradient device dataType inputDim outputDim)
linearSpec hasBias gradient device dataType inputDim outputDim =
  let weightSpec = TensorSpec gradient (SLayout SDense) device dataType (SShape $ outputDim :|: inputDim :|: SNil)
      -- The signature is needed so that matching on the singleton refines
      -- 'hasBias' and selects the right 'LinearBiasF' equation per branch.
      biasSpec :: SHasBias hasBias -> ModelSpec (LinearBiasF hasBias gradient device dataType outputDim)
      biasSpec SWithBias = TensorSpec gradient (SLayout SDense) device dataType (SShape $ outputDim :|: SNil)
      biasSpec SWithoutBias = ()
   in GLinear (NamedModel "weight" weightSpec) (NamedModel "bias" $ biasSpec hasBias)
-- | Initializes a 'GLinear' layer without a bias.
--
-- The weight is drawn with 'sKaimingUniform' in 'FanIn' mode with
-- @ForLeakyRelu (sqrt 5)@. The @()@ bias is initialized through its own
-- 'HasInitialize' instance. The generator is threaded through the weight
-- initialization first and the bias initialization second.
instance
  ( output
      ~ GLinear
          (Tensor gradient ('Layout 'Dense) (device <+> generatorDevice) dataType ('Shape '[outputDim, inputDim]))
          (),
    generatorOutputDevice ~ (device <+> generatorDevice),
    SGetGeneratorDevice generatorDevice
  ) =>
  HasInitialize
    (GLinear (Tensor gradient ('Layout 'Dense) device dataType ('Shape '[outputDim, inputDim])) ())
    generatorDevice
    output
    generatorOutputDevice
  where
  initialize GLinear {..} =
    let weight = IxStateT $ sKaimingUniform linearWeight FanIn (ForLeakyRelu . Prelude.sqrt $ 5)
        bias = IxStateT . initialize $ linearBias
     in runIxStateT $ GLinear <<$>> weight <<*>> bias
-- | Initializes a 'GLinear' layer with a bias.
--
-- The weight is drawn with 'sKaimingUniform' in 'FanIn' mode with
-- @ForLeakyRelu (sqrt 5)@.
--
-- The bias is sampled with 'sRandn', multiplied by @2 * bound@, and shifted
-- by @-bound@. Here @bound = 1 / sqrt fanIn@, where @fanIn@ is obtained from
-- the weight's shape via 'calculateFan'. When @fanIn@ is not positive the
-- bound is 0, so the bias comes out as zeros instead of non-finite values.
--
-- The generator is threaded through the weight initialization first and the
-- bias initialization second.
instance
  ( output
      ~ GLinear
          (Tensor gradient ('Layout 'Dense) (device <+> generatorDevice) dataType ('Shape '[outputDim, inputDim]))
          (Tensor gradient ('Layout 'Dense) (device <+> generatorDevice) dataType ('Shape '[outputDim])),
    generatorOutputDevice ~ (device <+> generatorDevice),
    SGetGeneratorDevice generatorDevice,
    SGetGeneratorDevice generatorOutputDevice
  ) =>
  HasInitialize
    ( GLinear
        (Tensor gradient ('Layout 'Dense) device dataType ('Shape '[outputDim, inputDim]))
        (Tensor gradient ('Layout 'Dense) device dataType ('Shape '[outputDim]))
    )
    generatorDevice
    output
    generatorOutputDevice
  where
  initialize GLinear {..} =
    let weight = IxStateT $ sKaimingUniform linearWeight FanIn (ForLeakyRelu . Prelude.sqrt $ 5)
        fanIn = getter FanIn . calculateFan . tensorDims . tsShape $ linearWeight
        -- A zero fan-in would make @1 / sqrt fanIn@ infinite and fill the
        -- bias with non-finite values; use a zero bound instead.
        bound :: Float
        bound
          | fanIn > 0 = 1 / Prelude.sqrt (fromIntegral fanIn)
          | otherwise = 0
        -- NOTE(review): the map @x * (2 * bound) - bound@ sends U(0, 1) onto
        -- U(-bound, bound), but it is applied to samples from 'sRandn', which
        -- presumably draws from a standard normal. That would yield
        -- N(-bound, (2 * bound)^2) instead. Confirm whether a uniform sampler
        -- was intended here.
        bias =
          IxStateT (sRandn linearBias) >>>= \bias' ->
            ilift $ do
              x <- bias' `mulScalar` (bound * 2)
              x `subScalar` bound
     in runIxStateT $ GLinear <<$>> weight <<*>> bias
-- | Initializes a 'GLinear' whose weight and bias are wrapped in 'NamedModel'.
--
-- The names are stripped off the specification and the underlying
-- @GLinear weight bias@ is initialized. The original names are then put back
-- around the initialized weight and bias.
instance
  HasInitialize
    (GLinear weight bias)
    generatorDevice
    (GLinear weight bias)
    generatorDevice =>
  HasInitialize
    (GLinear (NamedModel weight) (NamedModel bias))
    generatorDevice
    (GLinear (NamedModel weight) (NamedModel bias))
    generatorDevice
  where
  initialize GLinear {..} =
    let NamedModel weightName weightSpec = linearWeight
        NamedModel biasName biasSpec = linearBias
        linear = IxStateT . initialize $ GLinear weightSpec biasSpec
     in runIxStateT $
          linear
            >>>= ireturn
              . ( \(GLinear weight bias) ->
                    GLinear (NamedModel weightName weight) (NamedModel biasName bias)
                )
-- | State-dictionary (de)serialization of a 'GLinear'.
--
-- The weight and the bias are both read and written under the same key
-- prefix @k@, weight first and bias second. When the fields are 'NamedModel's
-- (as produced by 'linearSpec'), their instances presumably add the
-- @weight@/@bias@ part of the key. Confirm this against 'NamedModel''s
-- 'HasStateDict' instance.
instance
  ( HasStateDict weight,
    HasStateDict bias
  ) =>
  HasStateDict (GLinear weight bias)
  where
  fromStateDict GLinear {..} k =
    GLinear
      <$> fromStateDict linearWeight k
      <*> fromStateDict linearBias k
  toStateDict k GLinear {..} = do
    toStateDict k linearWeight
    toStateDict k linearBias
-- | Forward pass of a 'GLinear' layer without a bias.
--
-- This applies 'linearWithoutBias' to the weight and the input. It performs
-- no sampling, so the generator is returned unchanged.
instance
  ( output
      ~ Tensor
          (gradient <|> gradient')
          ('Layout 'Dense <+> layout')
          (device <+> device')
          (dataType <+> dataType')
          (LinearWithoutBiasF ('Shape '[outputDim, inputDim]) shape')
  ) =>
  HasForward
    ( GLinear
        (Tensor gradient ('Layout 'Dense) device dataType ('Shape '[outputDim, inputDim]))
        ()
    )
    (Tensor gradient' layout' device' dataType' shape')
    generatorDevice
    output
    generatorDevice
  where
  forward GLinear {..} input = pure . (linearWithoutBias linearWeight input,)
-- | Forward pass of a 'GLinear' layer with a bias.
--
-- This applies 'linearWithBias' to the weight, the bias, and the input. It
-- performs no sampling, so the generator is returned unchanged.
--
-- NOTE(review): the imported 'linearWithBias' appears to derive its result's
-- gradient from the bias and input gradients only, not the weight's. Here the
-- weight and bias share @gradient@, so the declared @gradient <|> gradient'@
-- still follows (using the unification idempotence the module's type-checker
-- plugin is configured for).
instance
  ( output
      ~ Tensor
          (gradient <|> gradient')
          ('Layout 'Dense <+> layout')
          (device <+> device')
          (dataType <+> dataType')
          (LinearWithBiasF ('Shape '[outputDim, inputDim]) ('Shape '[outputDim]) shape')
  ) =>
  HasForward
    ( GLinear
        (Tensor gradient ('Layout 'Dense) device dataType ('Shape '[outputDim, inputDim]))
        (Tensor gradient ('Layout 'Dense) device dataType ('Shape '[outputDim]))
    )
    (Tensor gradient' layout' device' dataType' shape')
    generatorDevice
    output
    generatorDevice
  where
  forward GLinear {..} input = pure . (linearWithBias linearWeight linearBias input,)
-- | Forward pass of a 'GLinear' whose weight and bias are wrapped in
-- 'NamedModel'.
--
-- The names are discarded, and the call is delegated to the
-- @GLinear weight bias@ instance.
instance
  HasForward
    (GLinear weight bias)
    input
    generatorDevice
    output
    generatorDevice =>
  HasForward
    (GLinear (NamedModel weight) (NamedModel bias))
    input
    generatorDevice
    output
    generatorDevice
  where
  forward GLinear {..} input =
    let NamedModel _ weight = linearWeight
        NamedModel _ bias = linearBias
     in forward (GLinear weight bias) input