{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PartialTypeSignatures #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
module Torch.Typed.Optim.CppOptim
( module Torch.Typed.Optim.CppOptim,
AdagradOptions (..),
AdamOptions (..),
AdamwOptions (..),
LbfgsOptions (..),
RmspropOptions (..),
SGDOptions (..),
)
where
import Data.Default.Class
import Data.Foldable (for_)
import qualified Debug.Trace as Debug
import Foreign.ForeignPtr
import System.Mem (performGC)
import qualified Torch as TD
import Torch.HList
import Torch.Internal.Cast
import Torch.Internal.Class (Castable (..), CppObject (..), CppTuple2 (..), CppTuple3 (..), CppTuple4 (..))
import Torch.Internal.GC (mallocTrim)
import qualified Torch.Internal.Managed.Optim as LibTorch
import qualified Torch.Internal.Type as ATen
import Torch.Optim.CppOptim
( AdagradOptions (..),
AdamOptions (..),
AdamwOptions (..),
LbfgsOptions (..),
RmspropOptions (..),
SGDOptions (..),
)
import Torch.Typed.Autograd
import Torch.Typed.NN
import qualified Torch.Typed.Optim as Optim
import Torch.Typed.Parameter
import Torch.Typed.Tensor
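
-- | An opaque reference to the underlying C++ (ATen) optimizer object.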
type CppOptimizerRef = ForeignPtr ATen.Optimizer
data CppOptimizerState option (params :: [*])
= CppOptimizerState option CppOptimizerRef
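
-- | Mapping tag used with 'hmap'': rewraps the raw tensors returned by the
-- C++ optimizer step as 'Parameter's (see the 'Apply'' instance below).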
data ToParameter = ToParameter
instance Apply' ToParameter (Tensor dev dtype shape) (Parameter dev dtype shape) where
  apply' _ (UnsafeMkTensor tensor) = UnsafeMkParameter . TD.IndependentTensor $ tensor
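
-- | Option types that can construct and drive a LibTorch (C++) optimizer.
-- 'initOptimizer' allocates the C++ optimizer from the option record and the
-- model's flattened parameter tensors; 'unsafeStep' performs one optimization
-- step against the given loss, letting the C++ optimizer update the parameter
-- tensors in place and rebuilding the model from the updated tensors.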
class CppOptimizer option where
initOptimizer ::
forall model tensors.
( Parameterized model,
HMap' ToDependent (Parameters model) tensors,
Castable (HList tensors) [TD.ATenTensor]
) =>
option ->
model ->
IO (CppOptimizerState option (Parameters model))
unsafeStep ::
forall model dev dtype lossShape tensors res.
( Parameterized model,
HMap' ToDependent (Parameters model) tensors,
HMap' ToParameter tensors (Parameters model),
Castable (HList tensors) [TD.ATenTensor]
) =>
model ->
CppOptimizerState option (Parameters model) ->
Tensor dev dtype lossShape ->
IO (model, CppOptimizerState option (Parameters model))
  unsafeStep model o@(CppOptimizerState _ optimizer) loss = do
    v :: [TD.ATenTensor] <- cast2 LibTorch.unsafeStep optimizer loss
    newParamTensors :: HList tensors <- uncast v pure
    let newParams = hmap' ToParameter newParamTensors
    let newModel = replaceParameters model newParams
    return (newModel, o)
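
-- Each instance below unpacks its option record and passes the fields,
-- together with the model's dependent parameter tensors, to the corresponding
-- LibTorch constructor via 'cast6' / 'cast7' / 'cast8'.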
instance CppOptimizer AdamOptions where
  initOptimizer opt@AdamOptions {..} model = do
    v <-
      cast7
        LibTorch.adam
        adamLr
        (fst adamBetas)
        (snd adamBetas)
        adamEps
        adamWeightDecay
        adamAmsgrad
        initParams'
    return $ CppOptimizerState opt v
    where
      initParams' = hmap' ToDependent $ flattenParameters model
instance CppOptimizer AdamwOptions where
  initOptimizer opt@AdamwOptions {..} model = do
    v <- cast7 LibTorch.adamw adamwLr (fst adamwBetas) (snd adamwBetas) adamwEps adamwWeightDecay adamwAmsgrad initParams'
    return $ CppOptimizerState opt v
    where
      initParams' = hmap' ToDependent $ flattenParameters model
instance CppOptimizer LbfgsOptions where
  initOptimizer opt@LbfgsOptions {..} model = do
    v <- cast8 LibTorch.lbfgs lbfgsLr lbfgsMaxIter lbfgsMaxEval lbfgsToleranceGrad lbfgsToleranceChange lbfgsHistorySize lbfgsLineSearchFn initParams'
    return $ CppOptimizerState opt v
    where
      initParams' = hmap' ToDependent $ flattenParameters model
instance CppOptimizer RmspropOptions where
  initOptimizer opt@RmspropOptions {..} model = do
    v <- cast7 LibTorch.rmsprop rmspropLr rmspropAlpha rmspropEps rmspropWeightDecay rmspropMomentum rmspropCentered initParams'
    return $ CppOptimizerState opt v
    where
      initParams' = hmap' ToDependent $ flattenParameters model
instance CppOptimizer SGDOptions where
  initOptimizer opt@SGDOptions {..} model = do
    v <- cast6 LibTorch.sgd sgdLr sgdMomentum sgdDampening sgdWeightDecay sgdNesterov initParams'
    return $ CppOptimizerState opt v
    where
      initParams' = hmap' ToDependent $ flattenParameters model
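
-- | Run one optimization step for the given loss. Runs 'performGC' and
-- 'mallocTrim' first, to release memory held by finalized tensors, before
-- delegating to 'unsafeStep'.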
runStep ::
( CppOptimizer option,
Parameterized model,
HMap' ToDependent (Parameters model) tensors,
HMap' ToParameter tensors (Parameters model),
Castable (HList tensors) [TD.ATenTensor]
) =>
model ->
CppOptimizerState option (Parameters model) ->
Optim.Loss dev dtype ->
IO (model, CppOptimizerState option (Parameters model))
runStep model optim loss = do
  performGC
  mallocTrim 0
  unsafeStep model optim loss
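
-- A hedged usage sketch (illustrative only, not part of this module's API):
-- assuming the model type satisfies the constraints above and that
-- 'AdamOptions' has a 'Default' instance, a training step might look roughly
-- like the following. The names @model@, @input@, @target@ and @computeLoss@
-- are hypothetical placeholders.
--
-- @
-- example = do
--   optim <- initOptimizer (def :: AdamOptions) model
--   let loss = computeLoss (forward model input) target
--   (model', optim') <- runStep model optim loss
--   ...
-- @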