{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}

module Torch.Optim.CppOptim where

import Data.Default.Class
import Foreign.ForeignPtr
import System.Mem (performGC)
import Torch.Autograd
import Torch.Internal.Cast
import Torch.Internal.Class (Castable (..), CppObject (..), CppTuple2 (..), CppTuple3 (..), CppTuple4 (..))
import Torch.Internal.GC (mallocTrim)
import qualified Torch.Internal.Managed.Optim as LibTorch
import qualified Torch.Internal.Type as ATen
import Torch.NN
import qualified Torch.Optim as Optim
import Torch.Tensor

-- | Raw pointer to a native LibTorch @torch::optim::Optimizer@ object.
type CppOptimizerRef = ForeignPtr ATen.Optimizer

-- | A native C++ optimizer paired with the Haskell-side option record it
-- was constructed from.
data CppOptimizerState option = CppOptimizerState option CppOptimizerRef

-- class Optimizer option where
--   initOptimizer :: Parameterized model => option -> model -> IO (OptimizerState option model)
--   step :: Parameterized model => OptimizerState option model -> (model -> IO Tensor) -> IO Tensor
--   -- Returned d depends on the state of optimizer.
--   -- Do not call step function after this function is called.
--   getParams :: Parameterized model => OptimizerState option model -> IO model
--   step (OptimizerState _ optimizer initParams) loss = cast0 (LibTorch.step optimizer trans)
--     where
--       trans :: ForeignPtr ATen.TensorList -> IO (ForeignPtr ATen.Tensor)
--       trans inputs =
--         uncast inputs $ \inputs' -> do
--           (Unsafe ret) <- loss $ replaceParameters initParams $  map (IndependentTensor . Unsafe) inputs'
--           cast ret return
--   getParams (OptimizerState _ optimizer initParams) = fmap (replaceParameters initParams . map (IndependentTensor . Unsafe)) $ cast0 (LibTorch.getParams optimizer)

-- | Run one optimizer step with an explicit RNG generator threaded through.
--
-- The loss closure receives the optimizer's current parameters and the
-- generator, and returns the loss tensor together with the (possibly
-- advanced) generator. The native optimizer drives the closure; 'loss''
-- adapts it to raw ATen pointer types at the FFI boundary.
stepWithGenerator ::
  CppOptimizerState option ->
  ForeignPtr ATen.Generator ->
  ([Tensor] -> ForeignPtr ATen.Generator -> IO (Tensor, ForeignPtr ATen.Generator)) ->
  IO (Tensor, ForeignPtr ATen.Generator)
stepWithGenerator (CppOptimizerState _ ref) generator loss =
  cast3 LibTorch.stepWithGenerator ref generator loss'
  where
    -- Bridge the Haskell loss closure to the C++ calling convention:
    -- uncast the parameter list, run the loss, and cast the result back
    -- into an ATen std::tuple of (tensor, generator).
    loss' :: ForeignPtr ATen.TensorList -> ForeignPtr ATen.Generator -> IO (ForeignPtr (ATen.StdTuple '(ATen.Tensor, ATen.Generator)))
    loss' params gen = do
      (v :: Tensor, gen') <- uncast params $ \params' -> loss params' gen
      v' <- cast v pure :: IO (ForeignPtr ATen.Tensor)
      cast (v', gen') pure

-- | Optimizers backed by LibTorch's native (C++) implementations.
class CppOptimizer option where
  -- | Allocate a native optimizer over the flattened parameters of @model@.
  initOptimizer :: Parameterized model => option -> model -> IO (CppOptimizerState option)

  -- | Perform one optimization step for the given loss value.
  --
  -- NOTE(review): the \"unsafe\" naming and the 'Unsafe' tensor wrapping
  -- suggest the C++ side mutates parameter tensors in place — confirm
  -- against the bindings before relying on the old model after a step.
  unsafeStep :: Parameterized model => model -> CppOptimizerState option -> Tensor -> IO (model, CppOptimizerState option)
  unsafeStep model o@(CppOptimizerState _ optimizer) loss = do
    updated <- cast2 LibTorch.unsafeStep optimizer loss
    -- Rewrap the raw tensors returned by C++ as parameters of the model.
    let newModel = replaceParameters model $ map (IndependentTensor . Unsafe) updated
    return (newModel, o)

-- | Adapter exposing C++ optimizers through the 'Optim.Optimizer' class.
-- Only 'Optim.runStep' is supported; the pure 'Optim.step' and the
-- gradient-based 'Optim.runStep'' raise errors because the C++ optimizers
-- compute and apply updates natively.
instance {-# OVERLAPS #-} CppOptimizer option => Optim.Optimizer (CppOptimizerState option) where
  step = error "step is not implemented for CppOptimizer."

  -- The learning-rate argument is ignored: the rate is fixed inside the
  -- C++ optimizer's options at 'initOptimizer' time.
  runStep paramState optState lossValue _lr = do
    -- Reclaim dead tensors before stepping so native memory stays bounded.
    performGC
    mallocTrim 0
    unsafeStep paramState optState lossValue

  runStep' = error "runStep' is not implemented for CppOptimizer."

-- | Hyperparameters for the native Adagrad optimizer
-- (mirrors @torch::optim::AdagradOptions@).
data AdagradOptions = AdagradOptions
  { adagradLr :: Double,
    adagradLrDecay :: Double,
    adagradWeightDecay :: Double,
    adagradInitialAccumulatorValue :: Double,
    adagradEps :: Double
  }
  deriving (Show, Eq)

-- | Default Adagrad hyperparameters.
instance Default AdagradOptions where
  def =
    AdagradOptions
      { adagradLr = 1e-2,
        adagradLrDecay = 0,
        adagradWeightDecay = 0,
        adagradInitialAccumulatorValue = 0,
        adagradEps = 1e-10
      }

instance CppOptimizer AdagradOptions where
  -- Builds the native Adagrad optimizer over the model's parameters.
  initOptimizer opt@AdagradOptions {..} initParams = do
    ref <-
      cast6
        LibTorch.adagrad
        adagradLr
        adagradLrDecay
        adagradWeightDecay
        adagradInitialAccumulatorValue
        adagradEps
        initParams'
    return $ CppOptimizerState opt ref
    where
      -- Unwrap the independent parameters to plain tensors for the FFI call.
      initParams' = map toDependent $ flattenParameters initParams

-- | Hyperparameters for the native Adam optimizer
-- (mirrors @torch::optim::AdamOptions@).
data AdamOptions = AdamOptions
  { adamLr :: Double,
    -- | (beta1, beta2) exponential-decay coefficients.
    adamBetas :: (Double, Double),
    adamEps :: Double,
    adamWeightDecay :: Double,
    adamAmsgrad :: Bool
  }
  deriving (Show, Eq)

-- | Default Adam hyperparameters.
instance Default AdamOptions where
  def =
    AdamOptions
      { adamLr = 1e-3,
        adamBetas = (0.9, 0.999),
        adamEps = 1e-8,
        adamWeightDecay = 0,
        adamAmsgrad = False
      }

instance CppOptimizer AdamOptions where
  -- Builds the native Adam optimizer; the betas pair is split into the two
  -- scalar arguments the C binding expects.
  initOptimizer opt@AdamOptions {..} initParams = do
    ref <-
      cast7
        LibTorch.adam
        adamLr
        (fst adamBetas)
        (snd adamBetas)
        adamEps
        adamWeightDecay
        adamAmsgrad
        initParams'
    return $ CppOptimizerState opt ref
    where
      -- Unwrap the independent parameters to plain tensors for the FFI call.
      initParams' = map toDependent $ flattenParameters initParams

-- | Hyperparameters for the native AdamW optimizer
-- (mirrors @torch::optim::AdamWOptions@).
data AdamwOptions = AdamwOptions
  { adamwLr :: Double,
    -- | (beta1, beta2) exponential-decay coefficients.
    adamwBetas :: (Double, Double),
    adamwEps :: Double,
    adamwWeightDecay :: Double,
    adamwAmsgrad :: Bool
  }
  deriving (Show, Eq)

-- | Default AdamW hyperparameters. Note the nonzero weight decay, which is
-- the distinguishing default versus plain Adam in this module.
instance Default AdamwOptions where
  def =
    AdamwOptions
      { adamwLr = 1e-3,
        adamwBetas = (0.9, 0.999),
        adamwEps = 1e-8,
        adamwWeightDecay = 1e-2,
        adamwAmsgrad = False
      }

instance CppOptimizer AdamwOptions where
  -- Builds the native AdamW optimizer; the betas pair is split into the two
  -- scalar arguments the C binding expects.
  initOptimizer opt@AdamwOptions {..} initParams = do
    ref <-
      cast7
        LibTorch.adamw
        adamwLr
        (fst adamwBetas)
        (snd adamwBetas)
        adamwEps
        adamwWeightDecay
        adamwAmsgrad
        initParams'
    return $ CppOptimizerState opt ref
    where
      -- Unwrap the independent parameters to plain tensors for the FFI call.
      initParams' = map toDependent $ flattenParameters initParams

-- | Hyperparameters for the native L-BFGS optimizer
-- (mirrors @torch::optim::LBFGSOptions@).
data LbfgsOptions = LbfgsOptions
  { lbfgsLr :: Double,
    lbfgsMaxIter :: Int,
    lbfgsMaxEval :: Int,
    lbfgsToleranceGrad :: Double,
    lbfgsToleranceChange :: Double,
    lbfgsHistorySize :: Int,
    -- | Optional line-search algorithm name passed through to LibTorch.
    lbfgsLineSearchFn :: Maybe String
  }
  deriving (Show, Eq)

-- | Default L-BFGS hyperparameters. @lbfgsMaxEval@ is derived from
-- @lbfgsMaxIter@ as @maxIter * 5 \`div\` 4@.
instance Default LbfgsOptions where
  def =
    LbfgsOptions
      { lbfgsLr = 1,
        lbfgsMaxIter = 20,
        lbfgsMaxEval = (20 * 5) `div` 4,
        lbfgsToleranceGrad = 1e-7,
        lbfgsToleranceChange = 1e-9,
        lbfgsHistorySize = 100,
        lbfgsLineSearchFn = Nothing
      }

instance CppOptimizer LbfgsOptions where
  -- Builds the native L-BFGS optimizer over the model's parameters.
  initOptimizer opt@LbfgsOptions {..} initParams = do
    ref <-
      cast8
        LibTorch.lbfgs
        lbfgsLr
        lbfgsMaxIter
        lbfgsMaxEval
        lbfgsToleranceGrad
        lbfgsToleranceChange
        lbfgsHistorySize
        lbfgsLineSearchFn
        initParams'
    return $ CppOptimizerState opt ref
    where
      -- Unwrap the independent parameters to plain tensors for the FFI call.
      initParams' = map toDependent $ flattenParameters initParams

-- | Hyperparameters for the native RMSprop optimizer
-- (mirrors @torch::optim::RMSpropOptions@).
data RmspropOptions = RmspropOptions
  { rmspropLr :: Double,
    rmspropAlpha :: Double,
    rmspropEps :: Double,
    rmspropWeightDecay :: Double,
    rmspropMomentum :: Double,
    rmspropCentered :: Bool
  }
  deriving (Show, Eq)

-- | Default RMSprop hyperparameters.
instance Default RmspropOptions where
  def =
    RmspropOptions
      { rmspropLr = 1e-2,
        rmspropAlpha = 0.99,
        rmspropEps = 1e-8,
        rmspropWeightDecay = 0,
        rmspropMomentum = 0,
        rmspropCentered = False
      }

instance CppOptimizer RmspropOptions where
  -- Builds the native RMSprop optimizer over the model's parameters.
  initOptimizer opt@RmspropOptions {..} initParams = do
    ref <-
      cast7
        LibTorch.rmsprop
        rmspropLr
        rmspropAlpha
        rmspropEps
        rmspropWeightDecay
        rmspropMomentum
        rmspropCentered
        initParams'
    return $ CppOptimizerState opt ref
    where
      -- Unwrap the independent parameters to plain tensors for the FFI call.
      initParams' = map toDependent $ flattenParameters initParams

-- | Hyperparameters for the native SGD optimizer
-- (mirrors @torch::optim::SGDOptions@).
data SGDOptions = SGDOptions
  { sgdLr :: Double,
    sgdMomentum :: Double,
    sgdDampening :: Double,
    sgdWeightDecay :: Double,
    sgdNesterov :: Bool
  }
  deriving (Show, Eq)

-- | Default SGD hyperparameters (plain gradient descent: no momentum,
-- dampening, weight decay, or Nesterov acceleration).
instance Default SGDOptions where
  def =
    SGDOptions
      { sgdLr = 1e-3,
        sgdMomentum = 0,
        sgdDampening = 0,
        sgdWeightDecay = 0,
        sgdNesterov = False
      }

instance CppOptimizer SGDOptions where
  -- Builds the native SGD optimizer over the model's parameters.
  initOptimizer opt@SGDOptions {..} initParams = do
    ref <-
      cast6
        LibTorch.sgd
        sgdLr
        sgdMomentum
        sgdDampening
        sgdWeightDecay
        sgdNesterov
        initParams'
    return $ CppOptimizerState opt ref
    where
      -- Unwrap the independent parameters to plain tensors for the FFI call.
      initParams' = map toDependent $ flattenParameters initParams

-- | Serialize the native optimizer's state (e.g. accumulated moments) to
-- @file@ via LibTorch's save routine.
saveState :: CppOptimizerState option -> FilePath -> IO ()
saveState (CppOptimizerState _ optimizer) file = cast2 LibTorch.save optimizer file

-- | Restore the native optimizer's state from @file@. The optimizer must
-- already be initialized (this mutates the existing C++ object in place —
-- NOTE(review): confirm against the binding's semantics).
loadState :: CppOptimizerState option -> FilePath -> IO ()
loadState (CppOptimizerState _ optimizer) file = cast2 LibTorch.load optimizer file