{-# LANGUAGE RecordWildCards #-}

module Torch.Optim where

import Control.Monad.State
import System.Mem (performGC)
import Torch.Autograd
import Torch.Functional
import Torch.Internal.GC (mallocTrim)
import Torch.NN
import Torch.Tensor
import Torch.TensorFactories
import Prelude hiding (sqrt)

-- | Learning rate represented as a (typically scalar) tensor.
type LearningRate = Tensor

-- | Scalar loss tensor produced by a forward pass.
type Loss = Tensor

-- | Gradients of a model's parameters, in the same order as
-- 'flattenParameters' yields the parameters they correspond to.
newtype Gradients = Gradients [Tensor] deriving (Show)

-- | Wrapper for optimizer-specific state.
newtype OptimizerState option = OptimizerState option

-- | Compute the gradients of a loss with respect to the given parameters,
-- wrapping the result in the 'Gradients' newtype.
grad' :: Loss -> [Parameter] -> Gradients
grad' t p = Gradients (grad t p)

-- | An optimizer maps current parameters to updated parameters given a
-- learning rate and gradients, threading optimizer-specific state (e.g.
-- momentum, moment estimates) through successive steps.
class Optimizer optimizer where
  -- | Pure update: new parameter values and new optimizer state.
  step :: LearningRate -> Gradients -> [Tensor] -> optimizer -> ([Tensor], optimizer)

  -- | run a single iteration of an optimizer, returning new parameters and updated optimizer state
  runStep :: (Parameterized model) => model -> optimizer -> Loss -> LearningRate -> IO (model, optimizer)
  runStep paramState optState lossValue =
    -- Partially applied: 'runStep'' still expects the learning rate.
    runStep' paramState optState (grad' lossValue $ flattenParameters paramState)

  -- | run a single iteration of an optimizer, returning new parameters and updated optimizer state
  runStep' :: (Parameterized model) => model -> optimizer -> Gradients -> LearningRate -> IO (model, optimizer)
  runStep' paramState optState gradients lr = do
    -- Reclaim memory between iterations; training loops otherwise
    -- accumulate dead tensors faster than GC runs on its own.
    performGC
    mallocTrim 0
    let (flatParameters', optState') = step lr gradients depParameters optState
    -- Re-wrap the updated tensors as independent (trainable) parameters.
    newFlatParam <- mapM makeIndependent flatParameters'
    pure (replaceParameters paramState newFlatParam, optState')
    where
      flatParameters = flattenParameters paramState
      depParameters = fmap toDependent flatParameters

--
-- Gradient Descent
--

-- | Stateless gradient descent; no state is carried between steps.
data GD = GD deriving (Show)

-- | Stateless gradient descent step
gd :: LearningRate -> Gradients -> [Tensor] -> [Tensor]
gd lr (Gradients gradients) parameters = zipWith step parameters gradients
  where
    -- p' = p - lr * dp
    step p dp = p - (lr * dp)

-- | Gradient descent step with a dummy state variable
gd' :: LearningRate -> Gradients -> [Tensor] -> GD -> ([Tensor], GD)
gd' lr gradients depParameters dummy = (gd lr gradients depParameters, dummy)

instance Optimizer GD where
  step = gd'

-- | Stochastic gradient descent over a parameter list: takes the
-- parameters and a matching list of gradients, returns updated tensors.
sgd :: LearningRate -> [Parameter] -> [Tensor] -> [Tensor]
sgd lr parameters = zipWith step depParameters
  where
    -- p' = p - lr * dp
    step p dp = p - (lr * dp)
    depParameters = map toDependent parameters

--
-- Gradient Descent with Momentum
--

-- | State for gradient descent with momentum: the decay factor and one
-- momentum tensor per parameter.
data GDM = GDM {beta :: Float, momentum :: [Tensor]} deriving (Show)

-- gradient descent with momentum step
gdm ::
  -- | learning rate
  LearningRate ->
  -- | model parameter gradients
  Gradients ->
  -- | model parameters
  [Tensor] ->
  -- | beta & momentum
  GDM ->
  -- | returns new parameters + updated momentum
  ([Tensor], GDM)
gdm lr (Gradients gradients) parameters (GDM beta momentum) =
  (fmap fst runStep, GDM beta (fmap snd runStep))
  where
    -- z' = beta * z + dp ; p' = p - lr * z'
    step p dp z = let z' = mulScalar beta z + dp in (p - lr * z', z')
    runStep = zipWith3 step parameters gradients momentum

instance Optimizer GDM where
  step = gdm

--
-- Adam
--

-- | State representation for Adam Optimizer
data Adam = Adam
  { beta1 :: Float, -- 1st moment forgetting factor
    beta2 :: Float, -- 2nd moment forgetting factor
    m1 :: [Tensor], -- 1st moment
    m2 :: [Tensor], -- 2nd moment
    iter :: Int -- iteration
  }
  deriving (Show)

-- | Construct Adam state with zero-initialized moment estimates, one
-- zeros-tensor (shaped like each parameter) per moment.
mkAdam ::
  -- | initial iteration counter
  Int ->
  -- | beta1: 1st moment forgetting factor
  Float ->
  -- | beta2: 2nd moment forgetting factor
  Float ->
  -- | parameters the moments are shaped after
  [Parameter] ->
  Adam
mkAdam iter beta1 beta2 parameters =
  Adam
    beta1
    beta2
    (initZeros <$> parameters)
    (initZeros <$> parameters)
    iter
  where
    initZeros = zerosLike . toDependent

-- | Adam step
adam ::
  -- | learning rate
  LearningRate ->
  -- | model parameter gradients
  Gradients ->
  -- | model parameters
  [Tensor] ->
  -- | adam parameters - beta1, beta2, moments, iteration
  Adam ->
  -- | returns new parameters + updated adam parameters
  ([Tensor], Adam)
adam lr (Gradients gradients) parameters Adam {..} =
  (parameters', Adam beta1 beta2 m1' m2' (iter + 1))
  where
    -- decaying averages of 1st & 2nd moments
    f1 m1 dp = mulScalar beta1 m1 + mulScalar (1 - beta1) dp
    f2 m2 dp = mulScalar beta2 m2 + mulScalar (1 - beta2) (dp * dp)
    m1' = zipWith f1 m1 gradients
    m2' = zipWith f2 m2 gradients
    -- bias adjustment: divide by (1 - beta^(t+1))
    a beta = divScalar (1 - beta ^ (iter + 1))
    a1 = fmap (a beta1) m1'
    a2 = fmap (a beta2) m2'
    -- parameter update; eps guards against division by zero
    eps = 1e-37
    update prevParam a1' a2' = prevParam - lr * a1' / (sqrt a2' + eps)
    parameters' = zipWith3 update parameters a1 a2

instance Optimizer Adam where
  step = adam

--
-- Adagrad
--

-- | State representation for Adagrad Optimizer
-- | State representation for Adagrad Optimizer
data Adagrad = Adagrad {gsum :: [Tensor]} -- sum of squared gradients
  deriving (Show)

-- | Adagrad step
adagrad ::
  -- | learning rate
  LearningRate ->
  -- | model parameter gradients
  Gradients ->
  -- | model parameters
  [Tensor] ->
  -- | adagrad parameters - gsum, iteration
  Adagrad ->
  -- | returns new parameters + updated adagrad parameters
  ([Tensor], Adagrad)
adagrad lr (Gradients gradients) parameters Adagrad {..} =
  (parameters', Adagrad gsum')
  where
    -- add gradient squared to running total
    f gsum dp = gsum + dp * dp
    gsum' = zipWith f gsum gradients

    -- parameter update; eps guards against division by zero.
    -- NOTE(review): unlike 'adam', eps is added under the sqrt here.
    eps = 1e-37
    update prevParam a1' a2' = prevParam - lr * a1' / (sqrt (a2' + eps))
    parameters' = zipWith3 update parameters gradients gsum'

instance Optimizer Adagrad where
  step = adagrad

-- | syntactic sugar for looping with foldM
-- | Fold an IO action over the indices @[1 .. count]@, threading an
-- accumulator; @count <= 0@ yields the initial value unchanged.
foldLoop :: a -> Int -> (a -> Int -> IO a) -> IO a
foldLoop x count block = foldM block x [1 .. count]