{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}

module Torch.GraduallyTyped.NN.Training where

import Control.Monad.IO.Class (MonadIO (..))
import qualified Pipes as P
import qualified Pipes.Prelude as P
import Torch.GraduallyTyped.NN.Class (HasForward (..), HasStateDict (..), ModelSpec)
import Torch.GraduallyTyped.Optim (Optimizer, stepWithGenerator)
import Torch.GraduallyTyped.Prelude (Catch)
import Torch.GraduallyTyped.Prelude.List (SList (..))
import Torch.GraduallyTyped.Random (Generator, SGetGeneratorDevice)
import Torch.GraduallyTyped.RequiresGradient (Gradient (..), RequiresGradient (..), SGradient (..), SRequiresGradient (..))
import Torch.GraduallyTyped.Shape.Type (SShape (..), Shape (..))
import Torch.GraduallyTyped.Tensor.MathOperations.Pointwise (divScalar)
import Torch.GraduallyTyped.Tensor.Type (SGetGradient, SGetShape, Tensor, sCheckedGradient, sCheckedShape, withoutGradient)
import Torch.GraduallyTyped.Unify (type (<+>))

-- | Train the model for one epoch.
--
-- Folds over the example stream, performing one optimizer step per example
-- and accumulating the detached losses. The model's loss must be a scalar
-- (shape @'Shape' '[]@) with gradient tracking enabled; both properties are
-- checked at runtime via 'sCheckedShape' and 'sCheckedGradient'.
train ::
  forall m model input generatorDevice lossGradient lossLayout lossDataType lossDevice lossShape generatorOutputDevice.
  ( MonadIO m,
    HasStateDict model,
    HasForward model input generatorDevice (Tensor lossGradient lossLayout lossDataType lossDevice lossShape) generatorOutputDevice,
    HasForward model input generatorOutputDevice (Tensor lossGradient lossLayout lossDataType lossDevice lossShape) generatorOutputDevice,
    SGetGeneratorDevice generatorDevice,
    SGetGeneratorDevice generatorOutputDevice,
    SGetGradient lossGradient,
    SGetShape lossShape,
    Catch (lossShape <+> 'Shape '[]),
    Catch (lossGradient <+> 'Gradient 'WithGradient)
  ) =>
  -- | optimizer for the model
  Optimizer model ->
  -- | model specification
  ModelSpec model ->
  -- | stream of training examples
  P.ListT m input ->
  -- | random generator
  Generator generatorDevice ->
  -- | returned is either the original generator or the average training loss and a new generator
  m
    ( Either
        (Generator generatorDevice)
        (Tensor ('Gradient 'WithoutGradient) lossLayout lossDataType lossDevice ('Shape '[]), Generator generatorOutputDevice)
    )
train optim modelSpec examples g = do
  -- Pair every example with its zero-based position in the stream.
  let producer = P.zip (P.enumerate examples) (P.each [0 :: Int ..])
  x <- P.next producer
  case x of
    -- Empty stream: nothing to train on, hand back the untouched generator.
    Left _ -> pure . Left $ g
    Right ((input, iter), producer') -> do
      let -- One optimizer step for a subsequent example; adds its detached
          -- loss to the running sum and records the example's index.
          step ((loss, _), g') (input', iter') = liftIO $ do
            let forward' model g'' = do
                  (loss', g''') <- forward model input' g''
                  -- Enforce that the loss is a gradient-carrying scalar.
                  loss'' <- sCheckedShape (SShape SNil) =<< sCheckedGradient (SGradient SWithGradient) loss'
                  pure (loss'', g''')
            (loss', g'') <- stepWithGenerator optim modelSpec forward' g'
            -- Detach before accumulating so the running sum does not retain
            -- the autograd graph of every step.
            loss'' <- withoutGradient loss'
            pure ((loss + loss'', iter'), g'')
          -- Optimizer step for the first example; seeds the fold state.
          init' = liftIO $ do
            let forward' model g' = do
                  (loss, g'') <- forward model input g'
                  loss' <- sCheckedShape (SShape SNil) =<< sCheckedGradient (SGradient SWithGradient) loss
                  pure (loss', g'')
            (loss, g') <- stepWithGenerator optim modelSpec forward' g
            loss' <- withoutGradient loss
            pure ((loss', iter), g')
          -- Turn the final fold state into the average loss.
          -- NOTE(review): iter' is the zero-based index of the LAST example,
          -- so a stream of n examples divides the summed loss by n - 1 (and a
          -- single-example stream divides by 0) — confirm this off-by-one is
          -- the intended averaging.
          done ((loss, iter'), g'') = pure . Right $ (loss `divScalar` iter', g'')
      P.foldM step init' done producer'

-- | Evaluate the model on the given examples.
--
-- Folds over the example stream, running the model's forward pass on each
-- example and accumulating the losses. Unlike 'train', no optimizer step is
-- taken; the loss is checked to be a scalar (shape @'Shape' '[]@) WITHOUT
-- gradient tracking via 'sCheckedShape' and 'sCheckedGradient'.
eval ::
  ( MonadIO m,
    HasStateDict model,
    HasForward model input generatorDevice (Tensor lossGradient lossLayout lossDataType lossDevice lossShape) generatorOutputDevice,
    HasForward model input generatorOutputDevice (Tensor lossGradient lossLayout lossDataType lossDevice lossShape) generatorOutputDevice,
    SGetGradient lossGradient,
    SGetShape lossShape,
    Catch (lossShape <+> 'Shape '[]),
    Catch (lossGradient <+> 'Gradient 'WithoutGradient)
  ) =>
  -- | model
  model ->
  -- | stream of examples
  P.ListT m input ->
  -- | random generator
  Generator generatorDevice ->
  -- | returned is either the original generator or the average evaluation loss and a new generator
  m
    ( Either
        (Generator generatorDevice)
        (Tensor ('Gradient 'WithoutGradient) lossLayout lossDataType lossDevice ('Shape '[]), Generator generatorOutputDevice)
    )
eval model examples g = do
  -- Pair every example with its zero-based position in the stream.
  let producer = P.zip (P.enumerate examples) (P.each [0 :: Int ..])
  x <- P.next producer
  case x of
    -- Empty stream: nothing to evaluate, hand back the untouched generator.
    Left _ -> pure . Left $ g
    Right ((input, iter), producer') -> do
      let -- Forward pass for a subsequent example; adds its loss to the
          -- running sum and records the example's index.
          step ((loss, _), g') (input', iter') = liftIO $ do
            (loss', g'') <- forward model input' g'
            -- Enforce that the loss is a gradient-free scalar.
            loss'' <- sCheckedShape (SShape SNil) =<< sCheckedGradient (SGradient SWithoutGradient) loss'
            pure ((loss + loss'', iter'), g'')
          -- Forward pass for the first example; seeds the fold state.
          init' = liftIO $ do
            (loss, g') <- forward model input g
            loss' <- sCheckedShape (SShape SNil) =<< sCheckedGradient (SGradient SWithoutGradient) loss
            pure ((loss', iter), g')
          -- Turn the final fold state into the average loss.
          -- NOTE(review): iter' is the zero-based index of the LAST example,
          -- so a stream of n examples divides the summed loss by n - 1 (and a
          -- single-example stream divides by 0) — confirm this off-by-one is
          -- the intended averaging.
          done ((loss, iter'), g'') = pure . Right $ (loss `divScalar` iter', g'')
      P.foldM step init' done producer'