{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE NoStarIsType #-}

module Torch.Typed.Optim where

import Control.Monad.State
import Data.Kind
import System.Mem (performGC)
import qualified Torch.DType as D
import qualified Torch.Device as D
import Torch.HList
import qualified Torch.Internal.Cast as ATen
import qualified Torch.Internal.Class as ATen
import Torch.Internal.GC (mallocTrim)
import qualified Torch.Tensor as D
import Torch.Typed.Autograd
import Torch.Typed.Auxiliary
import Torch.Typed.Factories
import Torch.Typed.Functional
import Torch.Typed.Parameter
import Torch.Typed.Tensor
import Prelude hiding (div, sqrt)

-- | Learning rate: a scalar (zero-dimensional, shape @'[]@) tensor with the
-- given device and dtype.
type LearningRate device dtype = Tensor device dtype '[]

-- | Loss value: a scalar (zero-dimensional, shape @'[]@) tensor with the
-- given device and dtype.
type Loss device dtype = Tensor device dtype '[]

-- | Function symbol for 'hmap'' that maps every parameter to a tensor of
-- zeros with the same device, dtype and shape. 'mkGDM' and 'mkAdam' use it
-- to initialise their momenta.
data ZerosLike = ZerosLike

instance
  ( parameter ~ Parameter device dtype shape,
    momentum ~ Tensor device dtype shape,
    TensorOptions shape dtype device
  ) =>
  Apply' ZerosLike parameter momentum
  where
  -- Neither the function symbol nor the parameter value is inspected: only
  -- the parameter's type (device, dtype, shape) decides which 'zeros' tensor
  -- is produced.
  apply' _ = const zeros

-- | An optimizer whose state has type @optim@.
--
-- A single 'step' consumes the learning rate, the gradients and the current
-- tensors, and yields the updated tensors together with the new optimizer
-- state.
class Optimizer optim gradients tensors dtype device where
  step ::
    -- | learning rate
    LearningRate device dtype ->
    -- | gradients, one per tensor
    HList gradients ->
    -- | tensors to update
    HList tensors ->
    -- | current optimizer state
    optim ->
    -- | updated tensors and updated optimizer state
    (HList tensors, optim)

-- | Performs one optimisation step driven by a scalar loss.
--
-- The gradients of @loss@ with respect to the model's parameters are
-- obtained with 'grad'. They are then handed to 'runStep'', which does the
-- housekeeping ('performGC', 'mallocTrim'), the optimizer 'step' and the
-- parameter replacement. Delegating keeps both entry points on one update
-- path instead of two hand-maintained copies.
--
-- The gradient thunk is only forced inside 'runStep'', after the GC/trim
-- calls, which matches the evaluation order of a fully inlined version.
runStep ::
  forall model optim parameters gradients tensors dtype device.
  ( Parameterized model,
    parameters ~ Parameters model,
    HasGrad (HList parameters) (HList gradients),
    tensors ~ gradients,
    HMap' ToDependent parameters tensors,
    ATen.Castable (HList gradients) [D.ATenTensor],
    Optimizer optim gradients tensors dtype device,
    HMapM' IO MakeIndependent tensors parameters
  ) =>
  model ->
  optim ->
  Loss device dtype ->
  LearningRate device dtype ->
  IO (model, optim)
runStep model optim loss learningRate =
  -- Explicit type applications pin every variable of runStep' (including
  -- the ones that only occur in its constraints).
  runStep' @model @optim @parameters @gradients @tensors @dtype @device
    model
    optim
    learningRate
    gradients
  where
    gradients :: HList gradients
    gradients = grad loss (flattenParameters model)

-- | Performs one optimisation step from precomputed gradients.
--
-- First a garbage collection is forced and @'mallocTrim' 0@ is called.
-- NOTE(review): presumably this releases memory held by tensors from
-- earlier steps; confirm against "Torch.Internal.GC".
-- Then the model's parameters are viewed as plain tensors ('ToDependent') and
-- updated by the optimizer's 'step'. The results are turned back into
-- independent parameters ('MakeIndependent'), and those replace the model's
-- parameters.
runStep' ::
  forall model optim parameters gradients tensors dtype device.
  ( Parameterized model,
    parameters ~ Parameters model,
    tensors ~ gradients,
    HMap' ToDependent parameters tensors,
    Optimizer optim gradients tensors dtype device,
    HMapM' IO MakeIndependent tensors parameters
  ) =>
  model ->
  optim ->
  LearningRate device dtype ->
  HList gradients ->
  IO (model, optim)
runStep' model optim learningRate gradients = do
  performGC
  mallocTrim 0
  let dependent :: HList tensors
      dependent = hmap' ToDependent (flattenParameters model)
      (updated, optim') = step learningRate gradients dependent optim
  model' <- replaceParameters model <$> hmapM' MakeIndependent updated
  pure (model', optim')

--
-- Gradient Descent (GD)
--

-- | Dummy state representation for GD Optimizer: plain gradient descent
-- keeps nothing between steps.
data GD = GD

-- | The (stateless) gradient-descent optimizer.
mkGD :: GD
mkGD = GD

-- | Function symbol for 'hzipWith' that carries the learning rate of a plain
-- gradient-descent update.
newtype GDStep device dtype = GDStep (LearningRate device dtype)

-- | @parameter - learningRate * gradient@
instance
  ( parameter ~ Tensor device dtype shape,
    gradient ~ Tensor device dtype shape,
    shape ~ Broadcast '[] shape,
    BasicArithmeticDTypeIsValid device dtype,
    KnownDevice device
  ) =>
  Apply' (GDStep device dtype) (parameter, gradient) parameter
  where
  apply' (GDStep lr) (p, g) = p - delta
    where
      -- the scalar learning rate broadcasts over the gradient's shape
      delta = mul lr g

-- | Gradient descent step with a dummy state variable.
--
-- Each tensor is zipped with its gradient and updated via 'GDStep'. The
-- 'GD' state is handed back untouched.
gd ::
  forall gradients tensors dtype device.
  HZipWith (GDStep device dtype) tensors gradients tensors =>
  LearningRate device dtype ->
  HList gradients ->
  HList tensors ->
  GD ->
  (HList tensors, GD)
gd learningRate gradients parameters state =
  (hzipWith (GDStep learningRate) parameters gradients, state)

-- | Plain gradient descent: 'step' is 'gd'.
instance
  ( HZipWith (GDStep device dtype) tensors gradients tensors
  ) =>
  Optimizer GD gradients tensors dtype device
  where
  step learningRate gradients tensors state =
    gd learningRate gradients tensors state

-- | 'GD' holds no tensors, so its parameter list is empty and replacing
-- parameters leaves it unchanged.
instance Parameterized GD where
  type Parameters GD = '[]
  flattenParameters = const HNil
  replaceParameters gdState _ = gdState

--
-- Gradient Descent with Momentum (GDM)
--

-- | State representation for GDM Optimizer
data GDM (momenta :: [Type]) = GDM
  { -- | moment forgetting factor
    beta :: Float,
    -- | momenta, one per model parameter (see 'mkGDM')
    momenta :: HList momenta
  }

-- | Initial GDM state: the given forgetting factor and one zero momentum
-- tensor per parameter (via 'ZerosLike').
mkGDM ::
  forall parameters momenta.
  (HMap' ZerosLike parameters momenta) =>
  Float ->
  HList parameters ->
  GDM momenta
mkGDM decay parameters =
  GDM {beta = decay, momenta = hmap' ZerosLike parameters}

-- | Function symbol for 'hzipWith3' that carries the momentum forgetting
-- factor and the learning rate of a gradient-descent-with-momentum update.
data GDMStep device dtype = GDMStep Float (LearningRate device dtype)

-- | For each (parameter, gradient, momentum) triple:
--
-- > momentum'  = beta * momentum + gradient
-- > parameter' = parameter - learningRate * momentum'
instance
  ( parameter ~ Tensor device dtype shape,
    gradient ~ Tensor device dtype shape,
    momentum ~ Tensor device dtype shape,
    shape ~ Broadcast '[] shape,
    KnownDevice device,
    BasicArithmeticDTypeIsValid device dtype
  ) =>
  Apply' (GDMStep device dtype) (parameter, gradient, momentum) (parameter, momentum)
  where
  apply' (GDMStep decay lr) (p, g, m) = (p - mul lr m', m')
    where
      m' = mulScalar decay m + g

-- | gradient descent with momentum step
--
-- Parameters, gradients and momenta are zipped element-wise with 'GDMStep'.
-- The resulting (parameter, momentum) pairs are then split with 'AFst' and
-- 'ASnd'.
gdm ::
  forall gradients tensors momenta gdmStep dtype device.
  ( HZipWith3 (GDMStep device dtype) tensors gradients momenta gdmStep,
    HMap' AFst gdmStep tensors,
    HMap' ASnd gdmStep momenta
  ) =>
  -- | learning rate
  LearningRate device dtype ->
  -- | model parameter gradient tensors
  HList gradients ->
  -- | model parameter tensors
  HList tensors ->
  -- | beta and model parameter momentum tensors
  GDM momenta ->
  -- | returns updated parameters and momenta
  (HList tensors, GDM momenta)
gdm learningRate gradients parameters (GDM decay ms) =
  (hmap' AFst zipped, GDM decay (hmap' ASnd zipped))
  where
    zipped :: HList gdmStep
    zipped = hzipWith3 (GDMStep decay learningRate) parameters gradients ms

-- | Gradient descent with momentum: 'step' is 'gdm'.
instance
  ( HZipWith3 (GDMStep device dtype) tensors gradients momenta gdmStep,
    HMap' AFst gdmStep tensors,
    HMap' ASnd gdmStep momenta
  ) =>
  Optimizer (GDM momenta) gradients tensors dtype device
  where
  step learningRate gradients tensors state =
    gdm learningRate gradients tensors state

-- | The tensors of a 'GDM' state are exactly its momenta; 'beta' is kept as
-- is on replacement.
instance Parameterized (GDM momenta) where
  type Parameters (GDM momenta) = momenta
  flattenParameters = momenta
  replaceParameters state ms = state {momenta = ms}

--
-- Adam
-- https://arxiv.org/pdf/1412.6980.pdf
--

-- | Iteration counter of the Adam optimizer: a scalar 'D.Int64' tensor on
-- the CPU.
type AdamIter = Tensor '( 'D.CPU, 0) 'D.Int64 '[]

-- | State representation for Adam Optimizer
data Adam (momenta :: [Type]) = Adam
  { -- | iteration
    iter :: AdamIter,
    -- | 1st moment forgetting factor
    beta1 :: Float,
    -- | 2nd moment forgetting factor
    beta2 :: Float,
    -- | 1st momenta
    momenta1 :: HList momenta,
    -- | 2nd momenta
    momenta2 :: HList momenta
  }

-- | Initial Adam state: the given iteration counter and forgetting factors,
-- with both momenta set to zero tensors shaped like the parameters.
mkAdam ::
  forall parameters momenta.
  (HMap' ZerosLike parameters momenta) =>
  AdamIter ->
  Float ->
  Float ->
  HList parameters ->
  Adam momenta
mkAdam iter0 b1 b2 parameters =
  Adam
    { iter = iter0,
      beta1 = b1,
      beta2 = b2,
      -- each momentum list comes from its own 'hmap'' call
      momenta1 = hmap' ZerosLike parameters,
      momenta2 = hmap' ZerosLike parameters
    }

-- | Function symbol carrying @beta1@ for the first-moment update.
newtype AdamMomentum1Update = AdamMomentum1Update Float

-- | decaying average of the first momenta:
-- @beta1 * momentum1 + (1 - beta1) * gradient@
instance
  ( gradient ~ Tensor device dtype shape,
    momentum1 ~ Tensor device dtype shape,
    KnownDevice device
  ) =>
  Apply' AdamMomentum1Update (momentum1, gradient) momentum1
  where
  apply' (AdamMomentum1Update decay) (m, g) = kept + fresh
    where
      kept = mulScalar decay m
      fresh = mulScalar (1 - decay) g

-- | Function symbol carrying @beta2@ for the second-moment update.
newtype AdamMomentum2Update = AdamMomentum2Update Float

-- | decaying average of the second momenta:
-- @beta2 * momentum2 + (1 - beta2) * gradient * gradient@
instance
  ( gradient ~ Tensor device dtype shape,
    momentum2 ~ Tensor device dtype shape,
    shape ~ Broadcast shape shape,
    KnownDevice device,
    BasicArithmeticDTypeIsValid device dtype
  ) =>
  Apply' AdamMomentum2Update (momentum2, gradient) momentum2
  where
  apply' (AdamMomentum2Update decay) (v, g) = kept + fresh
    where
      kept = mulScalar decay v
      -- element-wise square of the gradient
      fresh = mulScalar (1 - decay) (mul g g)

-- | Function symbol carrying the (pre-increment) iteration counter and a
-- forgetting factor for bias correction.
data AdamBiasAdjustment = AdamBiasAdjustment AdamIter Float

-- | bias adjustment: @momentum / (1 - pow (iter + 1) beta)@
--
-- NOTE(review): this is meant to be the Adam correction @1 - beta ^ t@ with
-- a 1-based @t@. That reading assumes 'pow' takes the exponent as its first
-- argument; confirm against "Torch.Typed.Functional".
instance
  ( momentum ~ Tensor device dtype shape,
    KnownDevice device,
    KnownDType dtype,
    shape ~ Reverse (Reverse shape),
    BasicArithmeticDTypeIsValid device dtype
  ) =>
  Apply' AdamBiasAdjustment momentum momentum
  where
  apply' (AdamBiasAdjustment n decay) m = m `div` correction
    where
      -- 1-based counter, converted from Int64 on the CPU to the momentum's
      -- dtype and device
      t :: Tensor device dtype '[]
      t = toDevice @device @'( 'D.CPU, 0) (toDType @dtype @'D.Int64 (n + 1))
      decayT :: Tensor device dtype '[]
      decayT = full @'[] @dtype @device decay
      correction = 1 - pow t decayT

-- | Function symbol carrying epsilon and the learning rate of the Adam
-- parameter update.
data AdamParameterUpdate device dtype = AdamParameterUpdate Float (LearningRate device dtype)

-- | parameter update:
-- @parameter - learningRate * m / (sqrt v + eps)@, where @m@ and @v@ are the
-- bias-adjusted first and second momenta.
instance
  ( parameter ~ Tensor device dtype shape,
    momentum ~ Tensor device dtype shape,
    shape ~ Broadcast '[] shape,
    KnownDevice device,
    BasicArithmeticDTypeIsValid device dtype,
    StandardFloatingPointDTypeValidation device dtype
  ) =>
  Apply'
    (AdamParameterUpdate device dtype)
    (parameter, momentum, momentum)
    parameter
  where
  apply' (AdamParameterUpdate eps lr) (p, mHat, vHat) = p - scaled / denominator
    where
      scaled = mul lr mHat
      -- eps keeps the denominator away from zero
      denominator = addScalar eps (sqrt vHat)

-- | Adam step
--
-- For every parameter @p@ with gradient @g@:
--
-- > m  <- beta1 * m + (1 - beta1) * g       -- 'AdamMomentum1Update'
-- > v  <- beta2 * v + (1 - beta2) * g * g   -- 'AdamMomentum2Update'
-- > m^ =  bias-adjusted m                   -- 'AdamBiasAdjustment'
-- > v^ =  bias-adjusted v                   -- 'AdamBiasAdjustment'
-- > p  <- p - learningRate * m^ / (sqrt v^ + epsilon)
--
-- The returned state carries the updated momenta and @iter + 1@.
adam ::
  forall gradients tensors momenta adamStep dtype device.
  ( HZipWith AdamMomentum1Update momenta gradients momenta,
    HZipWith AdamMomentum2Update momenta gradients momenta,
    HMap' AdamBiasAdjustment momenta momenta,
    HZipWith3 (AdamParameterUpdate device dtype) tensors momenta momenta tensors
  ) =>
  -- | learning rate
  LearningRate device dtype ->
  -- | model parameter gradient tensors
  HList gradients ->
  -- | model parameter tensors
  HList tensors ->
  -- | adam parameters - beta1, beta2, momenta1, momenta2, iteration
  Adam momenta ->
  -- | returns new parameters + updated adam parameters
  (HList tensors, Adam momenta)
adam learningRate gradients parameters (Adam t b1 b2 m1 m2) =
  (updatedParameters, Adam (t + 1) b1 b2 m1' m2')
  where
    -- NOTE(review): far smaller than the 1e-8 suggested in the Adam paper;
    -- kept as-is to preserve the existing numerical behaviour.
    epsilon :: Float
    epsilon = 1e-37

    m1', m2', mHat, vHat :: HList momenta
    m1' = hzipWith (AdamMomentum1Update b1) m1 gradients
    m2' = hzipWith (AdamMomentum2Update b2) m2 gradients
    -- the counter passed here is the pre-increment one; the adjustment
    -- itself adds 1
    mHat = hmap' (AdamBiasAdjustment t b1) m1'
    vHat = hmap' (AdamBiasAdjustment t b2) m2'

    updatedParameters :: HList tensors
    updatedParameters =
      hzipWith3 (AdamParameterUpdate epsilon learningRate) parameters mHat vHat

-- | Adam: 'step' is 'adam'.
instance
  ( HZipWith AdamMomentum1Update momenta gradients momenta,
    HZipWith AdamMomentum2Update momenta gradients momenta,
    HMap' AdamBiasAdjustment momenta momenta,
    HZipWith3 (AdamParameterUpdate device dtype) tensors momenta momenta tensors
  ) =>
  Optimizer (Adam momenta) gradients tensors dtype device
  where
  step learningRate gradients tensors state =
    adam learningRate gradients tensors state

-- | The tensors of an 'Adam' state are the iteration counter followed by the
-- first momenta and then the second momenta. The forgetting factors are kept
-- as is on replacement.
instance
  HAppendFD momenta momenta (momenta ++ momenta) =>
  Parameterized (Adam momenta)
  where
  type Parameters (Adam momenta) = AdamIter ': (momenta ++ momenta)
  flattenParameters (Adam t _ _ m1 m2) = t :. happendFD m1 m2
  replaceParameters state (t :. stacked) =
    state {iter = t, momenta1 = m1, momenta2 = m2}
    where
      -- split the concatenated list back into first and second momenta
      (m1, m2) = hunappendFD stacked