{-# LANGUAGE RecordWildCards #-}

module Torch.Optim where

import Control.Monad.State
import System.Mem (performGC)
import Torch.Autograd
import Torch.Functional
import Torch.Internal.GC (mallocTrim)
import Torch.NN
import Torch.Tensor
import Torch.TensorFactories
import Prelude hiding (sqrt)

type LearningRate = Tensor

type Loss = Tensor

newtype Gradients = Gradients [Tensor] deriving (Show)

newtype OptimizerState option = OptimizerState option

grad' :: Loss -> [Parameter] -> Gradients
grad' t p = Gradients (grad t p)
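
-- For example, gradients of a scalar loss with respect to a model's
-- parameters ('loss' and 'model' are assumed to be in scope here):
--
-- > let Gradients gs = grad' loss (flattenParameters model)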

class Optimizer optimizer where
  step :: LearningRate -> Gradients -> [Tensor] -> optimizer -> ([Tensor], optimizer)

  -- | run a single iteration of an optimizer, returning new parameters and updated optimizer state
  runStep :: (Parameterized model) => model -> optimizer -> Loss -> LearningRate -> IO (model, optimizer)
  runStep paramState optState lossValue =
    runStep' paramState optState (grad' lossValue $ flattenParameters paramState)

  -- | run a single iteration of an optimizer, returning new parameters and updated optimizer state
  runStep' :: (Parameterized model) => model -> optimizer -> Gradients -> LearningRate -> IO (model, optimizer)
  runStep' paramState optState gradients lr = do
    performGC
    mallocTrim 0
    let (flatParameters', optState') = step lr gradients depParameters optState
    newFlatParam <- mapM makeIndependent flatParameters'
    pure (replaceParameters paramState newFlatParam, optState')
    where
      flatParameters = flattenParameters paramState
      depParameters = fmap toDependent flatParameters

--
-- Gradient Descent
--

data GD = GD deriving (Show)

-- | Stateless gradient descent step
gd :: LearningRate -> Gradients -> [Tensor] -> [Tensor]
gd lr (Gradients gradients) parameters = zipWith step parameters gradients
  where
    step p dp = p - (lr * dp)

-- | Gradient descent step with a dummy state variable
gd' :: LearningRate -> Gradients -> [Tensor] -> GD -> ([Tensor], GD)
gd' lr gradients depParameters dummy = (gd lr gradients depParameters, dummy)

instance Optimizer GD where
  step = gd'
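
-- A minimal training-loop sketch using 'runStep' with the stateless 'GD'
-- optimizer. The model type 'Linear' and the functions 'linear' and 'mean'
-- are assumed to come from Torch.NN and Torch.Functional; any 'Parameterized'
-- model can be trained the same way.
--
-- > trainLinear :: Linear -> Tensor -> Tensor -> IO Linear
-- > trainLinear model0 xs ys = fst <$> foldLoop (model0, GD) 100 step'
-- >   where
-- >     step' (model, opt) _ = do
-- >       let err = linear model xs - ys
-- >       runStep model opt (mean (err * err)) 1e-3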

sgd :: LearningRate -> [Parameter] -> [Tensor] -> [Tensor]
sgd lr parameters = zipWith step depParameters
  where
    step p dp = p - (lr * dp)
    depParameters = map toDependent parameters
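
-- For example, one manual parameter update (a sketch; 'model' and the gradient
-- list 'grads' are assumed to be in scope):
--
-- > let updated = sgd 1e-3 (flattenParameters model) grads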

--
-- Gradient Descent with Momentum
--

data GDM = GDM {beta :: Float, momentum :: [Tensor]} deriving (Show)
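
-- A hedged initialisation sketch for 'GDM': the momentum buffers must match
-- the shapes of the model parameters, e.g. zero tensors built from a model in
-- scope (0.9 is a common choice of beta, not a default fixed by this module):
--
-- > let gdmOpt = GDM 0.9 (map (zerosLike . toDependent) (flattenParameters model))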

-- | Gradient descent with momentum step
gdm ::
  -- | learning rate
  LearningRate ->
  -- | model parameter gradients
  Gradients ->
  -- | model parameters
  [Tensor] ->
  -- | beta & momentum
  GDM ->
  -- | returns new parameters + updated momentum
  ([Tensor], GDM)
gdm lr (Gradients gradients) parameters (GDM beta momentum) =
  (fmap fst runStep, GDM beta (fmap snd runStep))
  where
    step p dp z = let z' = mulScalar beta z + dp in (p - lr * z', z')
    runStep = zipWith3 step parameters gradients momentum

instance Optimizer GDM where
  step = gdm

--
-- Adam
--

-- | State representation for Adam Optimizer
data Adam = Adam
  { beta1 :: Float, -- 1st moment forgetting factor
    beta2 :: Float, -- 2nd moment forgetting factor
    m1 :: [Tensor], -- 1st moment
    m2 :: [Tensor], -- 2nd moment
    iter :: Int -- iteration
  }
  deriving (Show)

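-- | Construct an 'Adam' optimizer state with both moment estimates
-- zero-initialised to the shapes of the given parameters.
--
-- A typical initialisation for a 'Parameterized' model in scope (0.9 and
-- 0.999 are common defaults for beta1 and beta2, not values fixed by this
-- module):
--
-- > let adamOpt = mkAdam 0 0.9 0.999 (flattenParameters model)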
mkAdam ::
  -- | initial iteration counter
  Int ->
  -- | beta1, 1st moment decay factor
  Float ->
  -- | beta2, 2nd moment decay factor
  Float ->
  -- | parameters whose shapes initialise the zeroed moment estimates
  [Parameter] ->
  Adam
mkAdam iter beta1 beta2 parameters =
  Adam
    beta1
    beta2
    (initZeros <$> parameters)
    (initZeros <$> parameters)
    iter
  where
    initZeros = zerosLike . toDependent

-- | Adam step
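--
-- For each parameter p with gradient g at iteration t, mirroring the code
-- below:
--
-- >   m1' = beta1 * m1 + (1 - beta1) * g
-- >   m2' = beta2 * m2 + (1 - beta2) * g^2
-- >   a1  = m1' / (1 - beta1 ^ (t + 1))   -- bias-corrected 1st moment
-- >   a2  = m2' / (1 - beta2 ^ (t + 1))   -- bias-corrected 2nd moment
-- >   p'  = p - lr * a1 / (sqrt a2 + eps)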
adam ::
  -- | learning rate
  LearningRate ->
  -- | model parameter gradients
  Gradients ->
  -- | model parameters
  [Tensor] ->
  -- | adam parameters - beta1, beta2, moments, iteration
  Adam ->
  -- | returns new parameters + updated adam parameters
  ([Tensor], Adam)
adam lr (Gradients gradients) parameters Adam {..} = (parameters', Adam beta1 beta2 m1' m2' (iter + 1))
  where
    -- decaying averages of 1st & 2nd moments
    f1 m1 dp = mulScalar beta1 m1 + mulScalar (1 - beta1) dp
    f2 m2 dp = mulScalar beta2 m2 + mulScalar (1 - beta2) (dp * dp)
    m1' = zipWith f1 m1 gradients
    m2' = zipWith f2 m2 gradients
    -- bias adjustment
    a beta = divScalar (1 - beta ^ (iter + 1))
    a1 = fmap (a beta1) m1'
    a2 = fmap (a beta2) m2'
    -- parameter update
    eps = 1e-37
    update prevParam a1' a2' = prevParam - lr * a1' / (sqrt a2' + eps)
    parameters' = zipWith3 update parameters a1 a2

instance Optimizer Adam where
  step = adam

--
-- Adagrad
--

-- | State representation for Adagrad Optimizer
data Adagrad = Adagrad {gsum :: [Tensor]} -- sum of squared gradients
  deriving (Show)
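
-- A hedged initialisation sketch for 'Adagrad': the running sums of squared
-- gradients start at zero, one tensor per parameter of a model in scope:
--
-- > let adagradOpt = Adagrad (map (zerosLike . toDependent) (flattenParameters model))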

-- | Adagrad step
adagrad ::
  -- | learning rate
  LearningRate ->
  -- | model parameter gradients
  Gradients ->
  -- | model parameters
  [Tensor] ->
  -- | adagrad parameters - gsum (running sum of squared gradients)
  Adagrad ->
  -- | returns new parameters + updated adagrad parameters
  ([Tensor], Adagrad)
adagrad lr (Gradients gradients) parameters Adagrad {..} = (parameters', Adagrad gsum')
  where
    -- add gradient squared to running total
    f gsum dp = gsum + dp * dp
    gsum' = zipWith f gsum gradients

    -- parameter update
    eps = 1e-37
    update prevParam a1' a2' = prevParam - lr * a1' / sqrt (a2' + eps)
    parameters' = zipWith3 update parameters gradients gsum'

instance Optimizer Adagrad where
  step = adagrad

-- | syntactic sugar for looping with foldM
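--
-- For example, summing the loop indices 1 through 10:
--
-- > total <- foldLoop (0 :: Int) 10 $ \acc i -> pure (acc + i)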
foldLoop :: a -> Int -> (a -> Int -> IO a) -> IO a
foldLoop x count block = foldM block x [1 .. count]