{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PartialTypeSignatures #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -fplugin TypeLevel.Rewrite
                -fplugin-opt=TypeLevel.Rewrite:Torch.GraduallyTyped.Unify.UnifyIdempotenceL2 #-}

module Torch.GraduallyTyped.NN.Linear where

import Control.Monad.Indexed (IxPointed (ireturn), (>>>=))
import Control.Monad.Indexed.State (IxStateT (..))
import Control.Monad.Indexed.Trans (IxMonadTrans (ilift))
import Data.Functor.Indexed ((<<$>>), (<<*>>))
import Data.Kind (Type)
import GHC.Generics (Generic)
import GHC.TypeLits (Nat, Symbol)
import Torch.GraduallyTyped.DType (DType (..), DataType (..), SDataType (..))
import Torch.GraduallyTyped.Device (Device (..), DeviceType (..), SDevice (..))
import Torch.GraduallyTyped.Internal.TensorOptions (tensorDims)
import Torch.GraduallyTyped.Layout (Layout (..), LayoutType (..), SLayout (..), SLayoutType (..))
import Torch.GraduallyTyped.NN.Class (HasForward (..), HasInitialize (..), HasStateDict (..), ModelSpec, NamedModel (..))
import Torch.GraduallyTyped.NN.Functional.Linear (LinearWithBiasF, LinearWithoutBiasF, linearWithBias, linearWithoutBias)
import Torch.GraduallyTyped.NN.Initialization (FanMode (..), ForNonLinearity (..), calculateFan, getter, sKaimingUniform)
import Torch.GraduallyTyped.NN.Type (HasBias (..), SHasBias (..))
import Torch.GraduallyTyped.Prelude (pattern (:|:))
import Torch.GraduallyTyped.Prelude.List (SList (..))
import Torch.GraduallyTyped.Random (SGetGeneratorDevice)
import Torch.GraduallyTyped.RequiresGradient (Gradient, RequiresGradient (..), SGradient)
import Torch.GraduallyTyped.Shape.Type (Dim (..), Name (..), SDim (..), SShape (..), Shape (..), Size (..))
import Torch.GraduallyTyped.Tensor.Creation (sRandn)
import Torch.GraduallyTyped.Tensor.MathOperations.Pointwise (mulScalar, subScalar)
import Torch.GraduallyTyped.Tensor.Type (Tensor, TensorSpec (..))
import Torch.GraduallyTyped.Unify (type (<+>), type (<|>))

-- | Generic linear model with weight and optional bias.
--
-- Both fields are left fully polymorphic so that the same record can carry
-- tensors, their specifications (see the 'ModelSpec' instance for 'GLinear'),
-- or 'NamedModel'-wrapped variants of either. A layer without bias uses @()@
-- for the @bias@ parameter (see 'LinearBiasF').
data
  GLinear
    (weight :: Type)
    (bias :: Type)
  where
  GLinear ::
    forall weight bias.
    { -- | Linear weight
      linearWeight :: weight,
      -- | Linear bias
      linearBias :: bias
    } ->
    GLinear weight bias
  deriving stock (Eq, Ord, Show, Generic)

-- | The specification of a 'GLinear' is itself a 'GLinear' whose fields hold
-- the specifications of the weight and of the bias, respectively.
type instance
  ModelSpec (GLinear weight bias) =
    GLinear (ModelSpec weight) (ModelSpec bias)

-- | Computes the concrete 'GLinear' type for the given bias choice, gradient,
-- device, data type, and input/output dimensions.
--
-- The weight type comes from 'LinearWeightF' and the bias type from
-- 'LinearBiasF'; each is wrapped in 'NamedModel'.
type family
  GLinearF
    (hasBias :: HasBias)
    (gradient :: Gradient RequiresGradient)
    (device :: Device (DeviceType Nat))
    (dataType :: DataType DType)
    (inputDim :: Dim (Name Symbol) (Size Nat))
    (outputDim :: Dim (Name Symbol) (Size Nat)) ::
    Type
  where
  GLinearF hasBias gradient device dataType inputDim outputDim =
    GLinear
      ( NamedModel
          (LinearWeightF gradient device dataType inputDim outputDim)
      )
      ( NamedModel
          (LinearBiasF hasBias gradient device dataType outputDim)
      )

-- | Type of the weight of a linear layer: a dense tensor of shape
-- @[outputDim, inputDim]@ with the given gradient, device, and data type.
type family
  LinearWeightF
    (gradient :: Gradient RequiresGradient)
    (device :: Device (DeviceType Nat))
    (dataType :: DataType DType)
    (inputDim :: Dim (Name Symbol) (Size Nat))
    (outputDim :: Dim (Name Symbol) (Size Nat)) ::
    Type
  where
  LinearWeightF gradient device dataType inputDim outputDim =
    Tensor
      gradient
      ('Layout 'Dense)
      device
      dataType
      ('Shape '[outputDim, inputDim])

-- | Type of the bias of a linear layer.
--
-- For ''WithoutBias' this is the unit type; for ''WithBias' it is a dense
-- tensor of shape @[outputDim]@ with the given gradient, device, and data type.
type family
  LinearBiasF
    (hasBias :: HasBias)
    (gradient :: Gradient RequiresGradient)
    (device :: Device (DeviceType Nat))
    (dataType :: DataType DType)
    (outputDim :: Dim (Name Symbol) (Size Nat)) ::
    Type
  where
  LinearBiasF 'WithoutBias _gradient _device _dataType _outputDim =
    ()
  LinearBiasF 'WithBias gradient device dataType outputDim =
    Tensor
      gradient
      ('Layout 'Dense)
      device
      dataType
      ('Shape '[outputDim])

-- | Construct the model specification of a linear layer.
--
-- The weight specification is a dense tensor specification of shape
-- @[outputDim, inputDim]@ and is stored under the name @"weight"@. The bias
-- specification is stored under the name @"bias"@: for 'SWithBias' it is a
-- dense tensor specification of shape @[outputDim]@, for 'SWithoutBias' it is
-- @()@.
linearSpec ::
  forall hasBias gradient device dataType inputDim outputDim.
  SHasBias hasBias ->
  SGradient gradient ->
  SDevice device ->
  SDataType dataType ->
  SDim inputDim ->
  SDim outputDim ->
  ModelSpec (GLinearF hasBias gradient device dataType inputDim outputDim)
linearSpec hasBias gradient device dataType inputDim outputDim =
  let weightSpec = TensorSpec gradient (SLayout SDense) device dataType (SShape $ outputDim :|: inputDim :|: SNil)
      -- The result type of 'biasSpec' depends on the matched singleton.
      biasSpec SWithBias = TensorSpec gradient (SLayout SDense) device dataType (SShape $ outputDim :|: SNil)
      biasSpec SWithoutBias = ()
   in GLinear (NamedModel "weight" weightSpec) (NamedModel "bias" $ biasSpec hasBias)

-- | Initialization of a linear layer that has a weight but no bias.
--
-- The weight is sampled with 'sKaimingUniform' using 'FanIn' and
-- @'ForLeakyRelu' (sqrt 5)@. The bias slot is @()@ and is initialized by
-- delegating to the 'HasInitialize' instance for its specification. The
-- random generator is threaded from the weight to the bias via 'IxStateT'.
--
-- TODO: Add 'ForNonLinearity' as parameter.
instance
  ( output
      ~ GLinear
          (Tensor gradient ('Layout 'Dense) (device <+> generatorDevice) dataType ('Shape '[outputDim, inputDim]))
          (),
    generatorOutputDevice ~ (device <+> generatorDevice),
    SGetGeneratorDevice generatorDevice
  ) =>
  HasInitialize
    (GLinear (Tensor gradient ('Layout 'Dense) device dataType ('Shape '[outputDim, inputDim])) ())
    generatorDevice
    output
    generatorOutputDevice
  where
  initialize GLinear {..} =
    let weight =
          IxStateT $
            sKaimingUniform
              linearWeight
              FanIn
              (ForLeakyRelu . Prelude.sqrt $ 5)
        bias = IxStateT . initialize $ linearBias
     in runIxStateT $ GLinear <<$>> weight <<*>> bias

-- | Initialization of a linear layer that has both a weight and a bias.
--
-- The weight is sampled with 'sKaimingUniform' using 'FanIn' and
-- @'ForLeakyRelu' (sqrt 5)@. The bias is obtained by sampling a tensor with
-- 'sRandn' and mapping it through @x * (2 * bound) - bound@, where
-- @bound = 1 / sqrt fan@ and @fan@ is the component that @'getter' 'FanIn'@
-- selects from @'calculateFan'@ applied to the weight dimensions. The random
-- generator is threaded from the weight to the bias via 'IxStateT'.
instance
  ( output
      ~ GLinear
          (Tensor gradient ('Layout 'Dense) (device <+> generatorDevice) dataType ('Shape '[outputDim, inputDim]))
          (Tensor gradient ('Layout 'Dense) (device <+> generatorDevice) dataType ('Shape '[outputDim])),
    generatorOutputDevice ~ (device <+> generatorDevice),
    SGetGeneratorDevice generatorDevice,
    SGetGeneratorDevice generatorOutputDevice
  ) =>
  HasInitialize
    (GLinear (Tensor gradient ('Layout 'Dense) device dataType ('Shape '[outputDim, inputDim])) (Tensor gradient ('Layout 'Dense) device dataType ('Shape '[outputDim])))
    generatorDevice
    output
    generatorOutputDevice
  where
  initialize GLinear {..} =
    let weight =
          IxStateT $
            sKaimingUniform
              linearWeight
              FanIn
              (ForLeakyRelu . Prelude.sqrt $ 5)
        -- Runtime dimensions of the weight, used to derive the bias bound.
        dims = tensorDims . tsShape $ linearWeight
        bound :: Float =
          1
            / ( Prelude.sqrt . fromIntegral
                  . getter FanIn
                  . calculateFan
                  $ dims
              )
        -- NOTE(review): the affine map below sends [0, 1) onto
        -- [-bound, bound), which suggests a uniform sample was intended;
        -- 'sRandn' presumably draws from a normal distribution instead --
        -- confirm which sampler should be used here.
        bias =
          IxStateT (sRandn linearBias)
            >>>= ilift
              . ( \bias' -> do
                    x <- bias' `mulScalar` (bound * 2)
                    x `subScalar` bound
                )
     in runIxStateT $ GLinear <<$>> weight <<*>> bias

-- | Initialization of a linear layer whose weight and bias are wrapped in
-- 'NamedModel'.
--
-- The names are stripped from the specification, the unwrapped 'GLinear' is
-- initialized through its own 'HasInitialize' instance, and the resulting
-- weight and bias are wrapped again with the names taken from the
-- specification.
instance
  HasInitialize
    (GLinear weight bias)
    generatorDevice
    (GLinear weight bias)
    generatorDevice =>
  HasInitialize
    (GLinear (NamedModel weight) (NamedModel bias))
    generatorDevice
    (GLinear (NamedModel weight) (NamedModel bias))
    generatorDevice
  where
  initialize GLinear {..} =
    let NamedModel weightName weightSpec = linearWeight
        NamedModel biasName biasSpec = linearBias
        linear = IxStateT . initialize $ GLinear weightSpec biasSpec
     in runIxStateT $
          linear
            >>>= ireturn
              . ( \(GLinear weight bias) ->
                    GLinear (NamedModel weightName weight) (NamedModel biasName bias)
                )

-- | State-dictionary (de)serialization of a linear layer.
--
-- Both directions delegate to the 'HasStateDict' instances of the weight and
-- the bias, in that order, passing the same key @k@ to each of them.
instance
  ( HasStateDict weight,
    HasStateDict bias
  ) =>
  HasStateDict (GLinear weight bias)
  where
  fromStateDict GLinear {..} k =
    GLinear
      <$> fromStateDict linearWeight k
      <*> fromStateDict linearBias k
  toStateDict k GLinear {..} = do
    toStateDict k linearWeight
    toStateDict k linearBias

-- | Forward pass of a linear layer without bias.
--
-- Applies 'linearWithoutBias' to the weight and the input tensor. The
-- generator is passed through unchanged.
instance
  ( output
      ~ Tensor
          (gradient <|> gradient')
          ('Layout 'Dense <+> layout')
          (device <+> device')
          (dataType <+> dataType')
          (LinearWithoutBiasF ('Shape '[outputDim, inputDim]) shape')
  ) =>
  HasForward
    ( GLinear
        (Tensor gradient ('Layout 'Dense) device dataType ('Shape '[outputDim, inputDim]))
        ()
    )
    (Tensor gradient' layout' device' dataType' shape')
    generatorDevice
    output
    generatorDevice
  where
  forward GLinear {..} input = pure . (linearWithoutBias linearWeight input,)

-- | Forward pass of a linear layer with bias.
--
-- Applies 'linearWithBias' to the weight, the bias, and the input tensor. The
-- generator is passed through unchanged.
instance
  ( output
      ~ Tensor
          (gradient <|> gradient')
          ('Layout 'Dense <+> layout')
          (device <+> device')
          (dataType <+> dataType')
          (LinearWithBiasF ('Shape '[outputDim, inputDim]) ('Shape '[outputDim]) shape')
  ) =>
  HasForward
    ( GLinear
        (Tensor gradient ('Layout 'Dense) device dataType ('Shape '[outputDim, inputDim]))
        (Tensor gradient ('Layout 'Dense) device dataType ('Shape '[outputDim]))
    )
    (Tensor gradient' layout' device' dataType' shape')
    generatorDevice
    output
    generatorDevice
  where
  forward GLinear {..} input = pure . (linearWithBias linearWeight linearBias input,)

-- | Forward pass of a linear layer whose weight and bias are wrapped in
-- 'NamedModel'.
--
-- The names are discarded and the call is delegated to the 'HasForward'
-- instance of the unwrapped 'GLinear'.
instance
  HasForward
    (GLinear weight bias)
    input
    generatorDevice
    output
    generatorDevice =>
  HasForward
    (GLinear (NamedModel weight) (NamedModel bias))
    input
    generatorDevice
    output
    generatorDevice
  where
  forward GLinear {..} input =
    let NamedModel _ weight = linearWeight
        NamedModel _ bias = linearBias
     in forward (GLinear weight bias) input