{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module Torch.GraduallyTyped.NN.Training where
import Control.Monad.IO.Class (MonadIO (..))
import qualified Pipes as P
import qualified Pipes.Prelude as P
import Torch.GraduallyTyped.NN.Class (HasForward (..), HasStateDict (..), ModelSpec)
import Torch.GraduallyTyped.Optim (Optimizer, stepWithGenerator)
import Torch.GraduallyTyped.Prelude (Catch)
import Torch.GraduallyTyped.Prelude.List (SList (..))
import Torch.GraduallyTyped.Random (Generator, SGetGeneratorDevice)
import Torch.GraduallyTyped.RequiresGradient (Gradient (..), RequiresGradient (..), SGradient (..), SRequiresGradient (..))
import Torch.GraduallyTyped.Shape.Type (SShape (..), Shape (..))
import Torch.GraduallyTyped.Tensor.MathOperations.Pointwise (divScalar)
import Torch.GraduallyTyped.Tensor.Type (SGetGradient, SGetShape, Tensor, sCheckedGradient, sCheckedShape, withoutGradient)
import Torch.GraduallyTyped.Unify (type (<+>))
-- | Run one pass of gradient-based training over a stream of examples.
--
-- Each example is paired with its iteration index. For every example, the
-- optimizer takes one step ('stepWithGenerator'): the model is run forward,
-- the loss is checked to be a scalar ('Shape '[]) that requires a gradient,
-- and the optimizer updates the model parameters. Losses are accumulated
-- (detached from the graph via 'withoutGradient') and finally divided by the
-- last iteration index to yield a mean loss.
--
-- Returns @Left g@ with the untouched input generator when the example
-- stream is empty, and otherwise @Right (meanLoss, g')@ with the generator
-- threaded through all forward passes.
train ::
  forall m model input generatorDevice lossGradient lossLayout lossDataType lossDevice lossShape generatorOutputDevice.
  ( MonadIO m,
    HasStateDict model,
    HasForward model input generatorDevice (Tensor lossGradient lossLayout lossDataType lossDevice lossShape) generatorOutputDevice,
    HasForward model input generatorOutputDevice (Tensor lossGradient lossLayout lossDataType lossDevice lossShape) generatorOutputDevice,
    SGetGeneratorDevice generatorDevice,
    SGetGeneratorDevice generatorOutputDevice,
    SGetGradient lossGradient,
    SGetShape lossShape,
    Catch (lossShape <+> 'Shape '[]),
    Catch (lossGradient <+> 'Gradient 'WithGradient)
  ) =>
  -- | optimizer holding the current model state
  Optimizer model ->
  -- | specification used to rebuild the model from the optimizer's state dict
  ModelSpec model ->
  -- | stream of training examples
  P.ListT m input ->
  -- | random generator
  Generator generatorDevice ->
  -- | 'Left' the unused generator if the stream was empty,
  -- 'Right' the mean loss and the final generator otherwise
  m
    ( Either
        (Generator generatorDevice)
        (Tensor ('Gradient 'WithoutGradient) lossLayout lossDataType lossDevice ('Shape '[]), Generator generatorOutputDevice)
    )
train optim modelSpec examples g = do
  -- Pair every example with its 0-based iteration index.
  let producer = P.zip (P.enumerate examples) (P.each [0 :: Int ..])
  x <- P.next producer
  case x of
    -- Empty stream: nothing was trained, hand the generator back unchanged.
    Left _ -> pure . Left $ g
    Right ((input, iter), producer') -> do
      let -- Fold step for every remaining example: one optimizer step,
          -- then add the detached scalar loss to the running sum.
          step ((loss, _), g') (input', iter') = liftIO $ do
            let -- Forward pass whose result is verified to be a scalar
                -- loss that carries a gradient, as 'stepWithGenerator'
                -- requires for backpropagation.
                forward' model g'' = do
                  (loss', g''') <- forward model input' g''
                  loss'' <- sCheckedShape (SShape SNil) =<< sCheckedGradient (SGradient SWithGradient) loss'
                  pure (loss'', g''')
            (loss', g'') <- stepWithGenerator optim modelSpec forward' g'
            loss'' <- withoutGradient loss'
            pure ((loss + loss'', iter'), g'')
          -- Initial accumulator: take the first optimizer step on the first
          -- example, using the caller-supplied generator.
          init' = liftIO $ do
            let forward' model g' = do
                  (loss, g'') <- forward model input g'
                  loss' <- sCheckedShape (SShape SNil) =<< sCheckedGradient (SGradient SWithGradient) loss
                  pure (loss', g'')
            (loss, g') <- stepWithGenerator optim modelSpec forward' g
            loss' <- withoutGradient loss
            pure ((loss', iter), g')
          -- Finalizer: divide the accumulated loss by the last iteration
          -- index to produce the mean, and return the final generator.
          done ((loss, iter'), g'') = pure . Right $ (loss `divScalar` iter', g'')
      P.foldM step init' done producer'
-- | Run one evaluation pass over a stream of examples without updating the
-- model.
--
-- Mirrors 'train', but performs plain forward passes on a fixed model (no
-- optimizer, no parameter updates): each loss is checked to be a scalar
-- ('Shape '[]) that does /not/ require a gradient, losses are summed, and the
-- sum is divided by the last iteration index to yield a mean loss.
--
-- Returns @Left g@ with the untouched input generator when the example
-- stream is empty, and otherwise @Right (meanLoss, g')@.
eval ::
  ( MonadIO m,
    HasStateDict model,
    HasForward model input generatorDevice (Tensor lossGradient lossLayout lossDataType lossDevice lossShape) generatorOutputDevice,
    HasForward model input generatorOutputDevice (Tensor lossGradient lossLayout lossDataType lossDevice lossShape) generatorOutputDevice,
    SGetGradient lossGradient,
    SGetShape lossShape,
    Catch (lossShape <+> 'Shape '[]),
    Catch (lossGradient <+> 'Gradient 'WithoutGradient)
  ) =>
  -- | model to evaluate
  model ->
  -- | stream of evaluation examples
  P.ListT m input ->
  -- | random generator
  Generator generatorDevice ->
  -- | 'Left' the unused generator if the stream was empty,
  -- 'Right' the mean loss and the final generator otherwise
  m
    ( Either
        (Generator generatorDevice)
        (Tensor ('Gradient 'WithoutGradient) lossLayout lossDataType lossDevice ('Shape '[]), Generator generatorOutputDevice)
    )
eval model examples g = do
  -- Pair every example with its 0-based iteration index.
  let producer = P.zip (P.enumerate examples) (P.each [0 :: Int ..])
  x <- P.next producer
  case x of
    -- Empty stream: nothing was evaluated, hand the generator back unchanged.
    Left _ -> pure . Left $ g
    Right ((input, iter), producer') -> do
      let -- Fold step for every remaining example: forward pass, verify the
          -- loss is a gradient-free scalar, add it to the running sum.
          step ((loss, _), g') (input', iter') = liftIO $ do
            (loss', g'') <- forward model input' g'
            loss'' <- sCheckedShape (SShape SNil) =<< sCheckedGradient (SGradient SWithoutGradient) loss'
            pure ((loss + loss'', iter'), g'')
          -- Initial accumulator: forward pass on the first example, using the
          -- caller-supplied generator.
          init' = liftIO $ do
            (loss, g') <- forward model input g
            loss' <- sCheckedShape (SShape SNil) =<< sCheckedGradient (SGradient SWithoutGradient) loss
            pure ((loss', iter), g')
          -- Finalizer: divide the accumulated loss by the last iteration
          -- index to produce the mean, and return the final generator.
          done ((loss, iter'), g'') = pure . Right $ (loss `divScalar` iter', g'')
      P.foldM step init' done producer'