{-# LANGUAGE RecordWildCards #-}
module Torch.Optim where
import Control.Monad.State
import System.Mem (performGC)
import Torch.Autograd
import Torch.Functional
import Torch.Internal.GC (mallocTrim)
import Torch.NN
import Torch.Tensor
import Torch.TensorFactories
import Prelude hiding (sqrt)
type LearningRate = Tensor
type Loss = Tensor

newtype Gradients = Gradients [Tensor] deriving (Show)
newtype OptimizerState option = OptimizerState option

-- | Compute the gradients of a loss value with respect to a list of parameters.
grad' :: Loss -> [Parameter] -> Gradients
grad' t p = Gradients (grad t p)
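
-- For example, gradients of a loss over all parameters of a 'Parameterized'
-- model (a sketch; 'model' and 'loss' are illustrative names):
--
-- > let Gradients gs = grad' loss (flattenParameters model)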

class Optimizer optimizer where
  step :: LearningRate -> Gradients -> [Tensor] -> optimizer -> ([Tensor], optimizer)

  runStep :: (Parameterized model) => model -> optimizer -> Loss -> LearningRate -> IO (model, optimizer)
  runStep paramState optState lossValue =
    runStep' paramState optState (grad' lossValue $ flattenParameters paramState)

  runStep' :: (Parameterized model) => model -> optimizer -> Gradients -> LearningRate -> IO (model, optimizer)
  runStep' paramState optState gradients lr = do
    -- reclaim native memory before taking the step
    performGC
    mallocTrim 0
    let (flatParameters', optState') = step lr gradients depParameters optState
    newFlatParam <- mapM makeIndependent flatParameters'
    pure (replaceParameters paramState newFlatParam, optState')
    where
      flatParameters = flattenParameters paramState
      depParameters = fmap toDependent flatParameters
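
-- A minimal sketch of one optimization step via the default 'runStep',
-- assuming a 'Parameterized' model 'm' and a scalar loss tensor (names are
-- illustrative):
--
-- > (m', opt') <- runStep m GD loss (asTensor (0.01 :: Float))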

-- | Stateless vanilla gradient descent.
data GD = GD deriving (Show)

gd :: LearningRate -> Gradients -> [Tensor] -> [Tensor]
gd lr (Gradients gradients) parameters = zipWith step parameters gradients
  where
    step p dp = p - (lr * dp)

-- | Gradient descent step with a dummy state variable, for the 'Optimizer' instance.
gd' :: LearningRate -> Gradients -> [Tensor] -> GD -> ([Tensor], GD)
gd' lr gradients depParameters dummy = (gd lr gradients depParameters, dummy)
instance Optimizer GD where
  step = gd'
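
-- A minimal sketch of a single descent step on one parameter tensor, assuming
-- 'asTensor' from Torch.Tensor (values are illustrative):
--
-- > gd (asTensor (0.1 :: Float)) (Gradients [asTensor [1.0, 2.0 :: Float]]) [asTensor [0.5, 0.5 :: Float]]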

-- | Variant of 'gd' that takes the parameters as 'Parameter' values directly.
sgd :: LearningRate -> [Parameter] -> [Tensor] -> [Tensor]
sgd lr parameters = zipWith step depParameters
  where
    step p dp = p - (lr * dp)
    depParameters = map toDependent parameters

-- | Gradient descent with momentum.
data GDM = GDM {beta :: Float, momentum :: [Tensor]} deriving (Show)

gdm ::
  LearningRate ->
  Gradients ->
  [Tensor] ->
  GDM ->
  ([Tensor], GDM)
gdm lr (Gradients gradients) parameters (GDM beta momentum) =
  (fmap fst runStep, GDM beta (fmap snd runStep))
  where
    -- update the momentum (z' = beta * z + dp), then step the parameter against it
    step p dp z = let z' = mulScalar beta z + dp in (p - lr * z', z')
    runStep = zipWith3 step parameters gradients momentum
instance Optimizer GDM where
  step = gdm
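
-- A minimal sketch of initializing GDM state with zeroed momentum, given a
-- model's flattened parameters (names are illustrative):
--
-- > let opt = GDM 0.9 (map (zerosLike . toDependent) (flattenParameters model))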

-- | State for the Adam optimizer.
data Adam = Adam
  { beta1 :: Float, -- first-moment decay factor
    beta2 :: Float, -- second-moment decay factor
    m1 :: [Tensor], -- first moment estimates
    m2 :: [Tensor], -- second moment estimates
    iter :: Int -- iteration counter, used for bias correction
  }
  deriving (Show)

mkAdam ::
  Int ->
  Float ->
  Float ->
  [Parameter] ->
  Adam
mkAdam iter beta1 beta2 parameters =
  Adam
    beta1
    beta2
    (initZeros <$> parameters)
    (initZeros <$> parameters)
    iter
  where
    initZeros = zerosLike . toDependent

adam ::
  LearningRate ->
  Gradients ->
  [Tensor] ->
  Adam ->
  ([Tensor], Adam)
adam lr (Gradients gradients) parameters Adam {..} =
  (parameters', Adam beta1 beta2 m1' m2' (iter + 1))
  where
    -- decaying averages of the first and second moments
    f1 m1 dp = mulScalar beta1 m1 + mulScalar (1 - beta1) dp
    f2 m2 dp = mulScalar beta2 m2 + mulScalar (1 - beta2) (dp * dp)
    m1' = zipWith f1 m1 gradients
    m2' = zipWith f2 m2 gradients
    -- bias correction: divide by (1 - beta ^ (iter + 1))
    a beta = divScalar (1 - beta ^ (iter + 1))
    a1 = fmap (a beta1) m1'
    a2 = fmap (a beta2) m2'
    -- parameter update
    eps = 1e-37
    update prevParam a1' a2' = prevParam - lr * a1' / (sqrt a2' + eps)
    parameters' = zipWith3 update parameters a1 a2
instance Optimizer Adam where
  step = adam
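
-- A minimal sketch of initializing Adam state with common defaults
-- (hyperparameter values are illustrative):
--
-- > let opt = mkAdam 0 0.9 0.999 (flattenParameters model)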

-- | State for the Adagrad optimizer.
data Adagrad = Adagrad {gsum :: [Tensor]} -- running sum of squared gradients
  deriving (Show)

adagrad ::
  LearningRate ->
  Gradients ->
  [Tensor] ->
  Adagrad ->
  ([Tensor], Adagrad)
adagrad lr (Gradients gradients) parameters Adagrad {..} =
  (parameters', Adagrad gsum')
  where
    -- accumulate squared gradients
    f gs dp = gs + dp * dp
    gsum' = zipWith f gsum gradients
    -- scale each step by the accumulated squared gradients
    eps = 1e-37
    update prevParam dp gs = prevParam - lr * dp / sqrt (gs + eps)
    parameters' = zipWith3 update parameters gradients gsum'
instance Optimizer Adagrad where
  step = adagrad
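
-- A minimal sketch of initializing Adagrad state with zeroed accumulators
-- (names are illustrative):
--
-- > let opt = Adagrad (map (zerosLike . toDependent) (flattenParameters model))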

-- | Convenience wrapper around 'foldM' for indexed training loops.
foldLoop :: a -> Int -> (a -> Int -> IO a) -> IO a
foldLoop x count block = foldM block x [1 .. count]
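
-- A minimal sketch of a full training loop combining 'foldLoop' with
-- 'runStep'; 'model0', 'inputs', 'targets', and the 'forward'/'mseLoss'
-- calls (from Torch.NN and Torch.Functional) are assumed for illustration:
--
-- > (trained, _) <- foldLoop (model0, mkAdam 0 0.9 0.999 (flattenParameters model0)) 100 $
-- >   \(m, opt) _ -> do
-- >     let loss = mseLoss (forward m inputs) targets
-- >     runStep m opt loss 1e-3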