{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE NoStarIsType #-}
module Torch.Typed.Optim where
import Control.Monad.State
import Data.Kind
import System.Mem (performGC)
import qualified Torch.DType as D
import qualified Torch.Device as D
import Torch.HList
import qualified Torch.Internal.Cast as ATen
import qualified Torch.Internal.Class as ATen
import Torch.Internal.GC (mallocTrim)
import qualified Torch.Tensor as D
import Torch.Typed.Autograd
import Torch.Typed.Auxiliary
import Torch.Typed.Factories
import Torch.Typed.Functional
import Torch.Typed.Parameter
import Torch.Typed.Tensor
import Prelude hiding (div, sqrt)
-- | Learning rate handed to an optimizer 'step': a rank-0 (scalar) tensor on
-- the same device and with the same dtype as the tensors being updated.
type LearningRate device dtype = Tensor device dtype '[]

-- | Rank-0 (scalar) loss tensor; 'runStep' differentiates it with 'grad'
-- to obtain the parameter gradients.
type Loss device dtype = Tensor device dtype '[]
-- | 'Apply'' tag mapping a model 'Parameter' to an all-zeros 'Tensor' with the
-- same device, dtype and shape. 'mkGDM' and 'mkAdam' use it (through 'hmap'')
-- to create one zero-initialised state tensor per parameter.
data ZerosLike = ZerosLike

instance
  ( parameter ~ Parameter device dtype shape,
    momentum ~ Tensor device dtype shape,
    TensorOptions shape dtype device
  ) =>
  Apply' ZerosLike parameter momentum
  where
  -- The parameter value itself is never inspected: only its type fixes the
  -- device, dtype and shape of the zeros tensor returned.
  apply' ZerosLike _ = zeros
-- | An optimizer whose (possibly trivial) state has type @optim@.
--
-- One 'step' consumes the learning rate, the gradients and the current
-- parameter tensors, and yields the updated tensors together with the
-- updated optimizer state.
class Optimizer optim gradients tensors dtype device where
  step ::
    -- | learning rate
    LearningRate device dtype ->
    -- | gradients, one per parameter tensor
    HList gradients ->
    -- | current parameter tensors
    HList tensors ->
    -- | current optimizer state
    optim ->
    -- | updated parameter tensors and optimizer state
    (HList tensors, optim)
-- | Perform one optimisation step of @model@ driven by a scalar @loss@.
--
-- The gradients of @loss@ with respect to the model's flattened parameters
-- are obtained with 'grad'; everything else (GC housekeeping, running the
-- optimizer 'step', writing the new parameters back into the model) is
-- delegated to 'runStep''.
runStep ::
  forall model optim parameters gradients tensors dtype device.
  ( Parameterized model,
    parameters ~ Parameters model,
    HasGrad (HList parameters) (HList gradients),
    tensors ~ gradients,
    HMap' ToDependent parameters tensors,
    ATen.Castable (HList gradients) [D.ATenTensor],
    Optimizer optim gradients tensors dtype device,
    HMapM' IO MakeIndependent tensors parameters
  ) =>
  model ->
  optim ->
  Loss device dtype ->
  LearningRate device dtype ->
  IO (model, optim)
runStep model optim loss learningRate =
  runStep' @model @optim @parameters @gradients @tensors @dtype @device
    model
    optim
    learningRate
    gradients
  where
    -- Bound lazily: the gradient computation is not forced until
    -- 'runStep'' consumes it in the optimizer step, i.e. after its
    -- 'performGC' / 'mallocTrim' housekeeping has run.
    gradients :: HList gradients
    gradients = grad loss (flattenParameters model)
-- | Perform one optimisation step of @model@ from precomputed @gradients@.
--
-- First runs 'performGC' and @'mallocTrim' 0@ (presumably to release native
-- memory held by tensors that are no longer referenced — confirm). Then it
-- reads the model's parameters as plain tensors ('ToDependent'), applies the
-- optimizer 'step', and converts the resulting tensors back into fresh
-- parameters ('MakeIndependent'), which are written into the returned model.
runStep' ::
  forall model optim parameters gradients tensors dtype device.
  ( Parameterized model,
    parameters ~ Parameters model,
    tensors ~ gradients,
    HMap' ToDependent parameters tensors,
    Optimizer optim gradients tensors dtype device,
    HMapM' IO MakeIndependent tensors parameters
  ) =>
  model ->
  optim ->
  LearningRate device dtype ->
  HList gradients ->
  IO (model, optim)
runStep' model optim learningRate gradients = do
  performGC
  mallocTrim 0
  let tensors :: HList tensors
      tensors = hmap' ToDependent (flattenParameters model)
      (tensors', optim') = step learningRate gradients tensors optim
  parameters' <- hmapM' MakeIndependent tensors'
  pure (replaceParameters model parameters', optim')
-- | State of the plain gradient descent optimizer ('gd'). Gradient descent
-- keeps no state between steps, so this is a unit-like tag.
data GD = GD

-- | Construct the (stateless) gradient descent optimizer.
mkGD :: GD
mkGD = GD
-- | 'Apply'' tag for a single gradient-descent update of one parameter,
-- carrying the learning rate.
newtype GDStep device dtype = GDStep (LearningRate device dtype)

instance
  ( parameter ~ Tensor device dtype shape,
    gradient ~ Tensor device dtype shape,
    -- lets the scalar learning rate broadcast against the gradient's shape
    shape ~ Broadcast '[] shape,
    BasicArithmeticDTypeIsValid device dtype,
    KnownDevice device
  ) =>
  Apply' (GDStep device dtype) (parameter, gradient) parameter
  where
  -- parameter' = parameter - learningRate * gradient
  apply' (GDStep learningRate) (parameter, gradient) =
    parameter - mul learningRate gradient
-- | One step of plain gradient descent: every parameter is replaced by
-- @parameter - learningRate * gradient@ (see 'GDStep'). The 'GD' state is
-- returned unchanged.
gd ::
  forall gradients tensors dtype device.
  HZipWith (GDStep device dtype) tensors gradients tensors =>
  -- | learning rate
  LearningRate device dtype ->
  -- | gradients, one per parameter tensor
  HList gradients ->
  -- | current parameter tensors
  HList tensors ->
  -- | (stateless) optimizer state
  GD ->
  -- | updated parameter tensors and the unchanged state
  (HList tensors, GD)
gd learningRate gradients parameters state =
  (hzipWith (GDStep learningRate) parameters gradients, state)
-- | 'GD' is an 'Optimizer' for any parameter list whose tensors can be zipped
-- with their gradients through 'GDStep'; a step is exactly 'gd'.
instance
  ( HZipWith (GDStep device dtype) tensors gradients tensors
  ) =>
  Optimizer GD gradients tensors dtype device
  where
  step = gd
-- | 'GD' holds no tensors, so it flattens to the empty list and replacing
-- its (empty) parameters leaves it unchanged.
instance Parameterized GD where
  type Parameters GD = '[]
  flattenParameters _ = HNil
  replaceParameters = const
-- | State of the gradient-descent-with-momentum optimizer ('gdm').
data GDM (momenta :: [Type]) = GDM
  { -- | Momentum decay factor; each step computes
    -- @momentum' = beta * momentum + gradient@ (see 'GDMStep').
    beta :: Float,
    -- | One momentum tensor per model parameter.
    momenta :: HList momenta
  }
-- | Construct a 'GDM' optimizer with decay factor @beta@ and one
-- zero-initialised momentum tensor per parameter (via 'ZerosLike').
mkGDM ::
  forall parameters momenta.
  (HMap' ZerosLike parameters momenta) =>
  -- | momentum decay factor
  Float ->
  -- | model parameters; only used to shape the momenta
  HList parameters ->
  GDM momenta
mkGDM beta parameters = GDM beta (hmap' ZerosLike parameters)
-- | 'Apply'' tag for one momentum update of a single parameter, carrying the
-- decay factor @beta@ and the learning rate.
data GDMStep device dtype = GDMStep Float (LearningRate device dtype)

instance
  ( parameter ~ Tensor device dtype shape,
    gradient ~ Tensor device dtype shape,
    momentum ~ Tensor device dtype shape,
    -- lets the scalar learning rate broadcast against the momentum's shape
    shape ~ Broadcast '[] shape,
    KnownDevice device,
    BasicArithmeticDTypeIsValid device dtype
  ) =>
  Apply' (GDMStep device dtype) (parameter, gradient, momentum) (parameter, momentum)
  where
  -- momentum'  = beta * momentum + gradient
  -- parameter' = parameter - learningRate * momentum'
  apply' (GDMStep beta learningRate) (parameter, gradient, momentum) =
    let momentum' = mulScalar beta momentum + gradient
        parameter' = parameter - mul learningRate momentum'
     in (parameter', momentum')
-- | One step of gradient descent with momentum. Each parameter and its
-- momentum are updated by 'GDMStep'; the new parameters and the new
-- momenta are then split apart with 'AFst' / 'ASnd'.
gdm ::
  forall gradients tensors momenta gdmStep dtype device.
  ( HZipWith3 (GDMStep device dtype) tensors gradients momenta gdmStep,
    HMap' AFst gdmStep tensors,
    HMap' ASnd gdmStep momenta
  ) =>
  -- | learning rate
  LearningRate device dtype ->
  -- | model parameter gradient tensors
  HList gradients ->
  -- | model parameter tensors
  HList tensors ->
  -- | beta and model parameter momentum tensors
  GDM momenta ->
  -- | updated parameters and momenta
  (HList tensors, GDM momenta)
gdm learningRate gradients parameters (GDM decay velocities) =
  (hmap' AFst updates, GDM decay (hmap' ASnd updates))
  where
    -- one (parameter', momentum') pair per parameter
    updates :: HList gdmStep
    updates = hzipWith3 (GDMStep decay learningRate) parameters gradients velocities
-- | 'GDM' is an 'Optimizer' whenever the momentum update 'GDMStep' can be
-- zipped over parameters, gradients and momenta; a step is exactly 'gdm'.
instance
  ( HZipWith3 (GDMStep device dtype) tensors gradients momenta gdmStep,
    HMap' AFst gdmStep tensors,
    HMap' ASnd gdmStep momenta
  ) =>
  Optimizer (GDM momenta) gradients tensors dtype device
  where
  step = gdm
-- | The tensors of a 'GDM' optimizer are its momenta; @beta@ is not a
-- parameter and is preserved by 'replaceParameters'.
instance Parameterized (GDM momenta) where
  type Parameters (GDM momenta) = momenta
  flattenParameters = momenta
  replaceParameters state newMomenta = state {momenta = newMomenta}
-- | Step counter used by 'Adam': a rank-0 @Int64@ tensor on the CPU.
type AdamIter = Tensor '( 'D.CPU, 0) 'D.Int64 '[]

-- | State of the Adam optimizer ('adam').
data Adam (momenta :: [Type]) = Adam
  { -- | Steps taken so far; 'adam' increments it by one per step and uses
    -- @iter + 1@ in the bias correction ('AdamBiasAdjustment').
    iter :: AdamIter,
    -- | Decay factor of the first-moment estimates ('momenta1').
    beta1 :: Float,
    -- | Decay factor of the second-moment estimates ('momenta2').
    beta2 :: Float,
    -- | Exponential moving averages of the gradients, one per parameter.
    momenta1 :: HList momenta,
    -- | Exponential moving averages of the squared gradients, one per parameter.
    momenta2 :: HList momenta
  }
-- | Construct an 'Adam' optimizer starting at step counter @iter@, with both
-- moment estimates zero-initialised from the parameters (via 'ZerosLike').
mkAdam ::
  forall parameters momenta.
  (HMap' ZerosLike parameters momenta) =>
  -- | initial step counter
  AdamIter ->
  -- | first-moment decay factor
  Float ->
  -- | second-moment decay factor
  Float ->
  -- | model parameters; only used to shape the momenta
  HList parameters ->
  Adam momenta
mkAdam iter beta1 beta2 parameters =
  -- The two moment lists are built by separate 'hmap'' calls, so each moment
  -- gets its own zeros tensor.
  Adam
    iter
    beta1
    beta2
    (hmap' ZerosLike parameters)
    (hmap' ZerosLike parameters)
-- | 'Apply'' tag for Adam's first-moment update, carrying @beta1@:
-- @momentum1' = beta1 * momentum1 + (1 - beta1) * gradient@.
newtype AdamMomentum1Update = AdamMomentum1Update Float

instance
  ( gradient ~ Tensor device dtype shape,
    momentum1 ~ Tensor device dtype shape,
    KnownDevice device
  ) =>
  Apply' AdamMomentum1Update (momentum1, gradient) momentum1
  where
  apply' (AdamMomentum1Update beta1) (momentum1, gradient) =
    mulScalar beta1 momentum1 + mulScalar (1 - beta1) gradient
-- | 'Apply'' tag for Adam's second-moment update, carrying @beta2@:
-- @momentum2' = beta2 * momentum2 + (1 - beta2) * gradient * gradient@
-- (element-wise square).
newtype AdamMomentum2Update = AdamMomentum2Update Float

instance
  ( gradient ~ Tensor device dtype shape,
    momentum2 ~ Tensor device dtype shape,
    -- required by the element-wise 'mul' of the gradient with itself
    shape ~ Broadcast shape shape,
    KnownDevice device,
    BasicArithmeticDTypeIsValid device dtype
  ) =>
  Apply' AdamMomentum2Update (momentum2, gradient) momentum2
  where
  apply' (AdamMomentum2Update beta2) (momentum2, gradient) =
    mulScalar beta2 momentum2 + mulScalar (1 - beta2) (mul gradient gradient)
-- | 'Apply'' tag for Adam's bias correction of a moment estimate, carrying the
-- current step counter and the moment's decay factor.
data AdamBiasAdjustment = AdamBiasAdjustment AdamIter Float

instance
  ( momentum ~ Tensor device dtype shape,
    KnownDevice device,
    KnownDType dtype,
    -- presumably needed so that broadcasting @shape@ against the scalar
    -- denominator reduces back to @shape@ — confirm
    shape ~ Reverse (Reverse shape),
    BasicArithmeticDTypeIsValid device dtype
  ) =>
  Apply' AdamBiasAdjustment momentum momentum
  where
  -- momentum / (1 - pow (iter + 1) beta). Assuming 'pow' takes the exponent
  -- first, this is the Adam bias correction m / (1 - beta^t) with t = iter + 1
  -- — confirm against Torch.Typed.Functional.pow.
  apply' (AdamBiasAdjustment iter beta) momentum =
    let -- the CPU Int64 counter is converted to the moment's dtype and device
        -- so it can take part in tensor arithmetic with @beta'@
        iter' = toDevice @device @'( 'D.CPU, 0) . toDType @dtype @'D.Int64 $ iter + 1
        beta' = full @'[] @dtype @device beta
     in momentum `div` (1 - pow iter' beta')
-- | 'Apply'' tag for Adam's parameter update, carrying the numerical-stability
-- constant @eps@ and the learning rate.
data AdamParameterUpdate device dtype = AdamParameterUpdate Float (LearningRate device dtype)

instance
  ( parameter ~ Tensor device dtype shape,
    momentum ~ Tensor device dtype shape,
    -- lets the scalar learning rate broadcast against the momentum's shape
    shape ~ Broadcast '[] shape,
    KnownDevice device,
    BasicArithmeticDTypeIsValid device dtype,
    StandardFloatingPointDTypeValidation device dtype
  ) =>
  Apply'
    (AdamParameterUpdate device dtype)
    (parameter, momentum, momentum)
    parameter
  where
  -- parameter' = parameter - learningRate * m1 / (eps + sqrt m2),
  -- where m1 and m2 are the bias-adjusted first and second moments.
  apply'
    (AdamParameterUpdate eps learningRate)
    (parameter, biasAdjustedMomentum1, biasAdjustedMomentum2) =
      let denominator = addScalar eps (sqrt biasAdjustedMomentum2)
       in parameter - (mul learningRate biasAdjustedMomentum1 / denominator)
-- | One step of the Adam optimizer.
--
-- 1. Update both moment estimates from the gradients
--    ('AdamMomentum1Update', 'AdamMomentum2Update').
-- 2. Bias-correct them with the current step counter ('AdamBiasAdjustment').
-- 3. Update every parameter ('AdamParameterUpdate').
--
-- The returned state has the step counter incremented by one and holds the
-- uncorrected moment estimates.
--
-- The @adamStep@ type variable is not used; it is kept so that existing
-- type applications to 'adam' keep their positions.
adam ::
  forall gradients tensors momenta adamStep dtype device.
  ( HZipWith AdamMomentum1Update momenta gradients momenta,
    HZipWith AdamMomentum2Update momenta gradients momenta,
    HMap' AdamBiasAdjustment momenta momenta,
    HZipWith3 (AdamParameterUpdate device dtype) tensors momenta momenta tensors
  ) =>
  -- | learning rate
  LearningRate device dtype ->
  -- | model parameter gradient tensors
  HList gradients ->
  -- | model parameter tensors
  HList tensors ->
  -- | adam state: step counter, beta1, beta2, first and second moments
  Adam momenta ->
  -- | updated parameters and updated adam state
  (HList tensors, Adam momenta)
adam learningRate gradients parameters Adam {..} =
  (parameters', Adam (iter + 1) beta1 beta2 momenta1' momenta2')
  where
    momenta1' :: HList momenta
    momenta1' = hzipWith (AdamMomentum1Update beta1) momenta1 gradients

    momenta2' :: HList momenta
    momenta2' = hzipWith (AdamMomentum2Update beta2) momenta2 gradients

    biasAdjustedMomenta1 :: HList momenta
    biasAdjustedMomenta1 = hmap' (AdamBiasAdjustment iter beta1) momenta1'

    biasAdjustedMomenta2 :: HList momenta
    biasAdjustedMomenta2 = hmap' (AdamBiasAdjustment iter beta2) momenta2'

    -- NOTE(review): 1e-37 is far smaller than the 1e-8 suggested in the Adam
    -- paper, and in half precision it rounds to 0, which removes the
    -- division-by-zero guard. Consider making it configurable.
    epsilon :: Float
    epsilon = 1e-37

    parameters' :: HList tensors
    parameters' =
      hzipWith3
        (AdamParameterUpdate epsilon learningRate)
        parameters
        biasAdjustedMomenta1
        biasAdjustedMomenta2
-- | 'Adam' is an 'Optimizer' whenever its moment updates, bias corrections
-- and parameter update can be mapped/zipped over the parameter lists; a step
-- is exactly 'adam'.
instance
  ( HZipWith AdamMomentum1Update momenta gradients momenta,
    HZipWith AdamMomentum2Update momenta gradients momenta,
    HMap' AdamBiasAdjustment momenta momenta,
    HZipWith3 (AdamParameterUpdate device dtype) tensors momenta momenta tensors
  ) =>
  Optimizer (Adam momenta) gradients tensors dtype device
  where
  step = adam
-- | The tensors of an 'Adam' optimizer are its step counter followed by the
-- first moments and then the second moments. @beta1@ and @beta2@ are not
-- parameters and are preserved by 'replaceParameters'.
instance
  HAppendFD momenta momenta (momenta ++ momenta) =>
  Parameterized (Adam momenta)
  where
  type Parameters (Adam momenta) = AdamIter ': (momenta ++ momenta)
  flattenParameters Adam {..} = iter :. (momenta1 `happendFD` momenta2)
  -- The single ':.' pattern is total: the type-level list is non-empty.
  replaceParameters state (newIter :. flattened) =
    let (newMomenta1, newMomenta2) = hunappendFD flattened
     in state {iter = newIter, momenta1 = newMomenta1, momenta2 = newMomenta2}