{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PartialTypeSignatures #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}

module Torch.Typed.Optim.CppOptim
  ( module Torch.Typed.Optim.CppOptim,
    AdagradOptions (..),
    AdamOptions (..),
    AdamwOptions (..),
    LbfgsOptions (..),
    RmspropOptions (..),
    SGDOptions (..),
  )
where

import Data.Default.Class
import Data.Foldable (for_)
import qualified Debug.Trace as Debug
import Foreign.ForeignPtr
import System.Mem (performGC)
import qualified Torch as TD
import Torch.HList
import Torch.Internal.Cast
import Torch.Internal.Class (Castable (..), CppObject (..), CppTuple2 (..), CppTuple3 (..), CppTuple4 (..))
import Torch.Internal.GC (mallocTrim)
import qualified Torch.Internal.Managed.Optim as LibTorch
import qualified Torch.Internal.Type as ATen
import Torch.Optim.CppOptim
  ( AdagradOptions (..),
    AdamOptions (..),
    AdamwOptions (..),
    LbfgsOptions (..),
    RmspropOptions (..),
    SGDOptions (..),
  )
import Torch.Typed.Autograd
import Torch.Typed.NN
import qualified Torch.Typed.Optim as Optim
import Torch.Typed.Parameter
import Torch.Typed.Tensor

-- | Handle to an optimizer object on the C++ side: a 'ForeignPtr' to an
-- 'ATen.Optimizer'.
type CppOptimizerRef = ForeignPtr ATen.Optimizer

-- | State of a C++-backed optimizer: the @option@ value it was created from
-- together with the 'CppOptimizerRef' handle.
--
-- @params@ is a phantom type-level list: it does not appear in any
-- constructor field and only tags the state at the type level.
data CppOptimizerState option (params :: [*])
  = CppOptimizerState option CppOptimizerRef

-- | Tag value passed to @hmap'@ to select the 'Apply'' instance that turns a
-- typed 'Tensor' into a typed 'Parameter'.
data ToParameter = ToParameter

-- | Re-wrap a typed 'Tensor' as a typed 'Parameter' with the same device,
-- dtype and shape indices: the underlying untyped tensor is wrapped in
-- 'TD.IndependentTensor' and then in 'UnsafeMkParameter'.
instance Apply' ToParameter (Tensor dev dtype shape) (Parameter dev dtype shape) where
  apply' _ (UnsafeMkTensor untyped) =
    UnsafeMkParameter (TD.IndependentTensor untyped)

-- | Optimizers that are implemented by an optimizer object on the C++ side,
-- selected by the type of their @option@ record.
class CppOptimizer option where
  -- | Create the C++ optimizer for the parameters of @model@ and return it
  -- wrapped in a 'CppOptimizerState' together with the options used.
  initOptimizer ::
    forall model tensors.
    ( Parameterized model,
      HMap' ToDependent (Parameters model) tensors,
      Castable (HList tensors) [TD.ATenTensor]
    ) =>
    option ->
    model ->
    IO (CppOptimizerState option (Parameters model))

  -- | Perform one optimizer step for the given loss.
  --
  -- The loss is handed to the C++ optimizer through 'LibTorch.unsafeStep';
  -- the tensors it returns are converted back into parameters and written
  -- into a copy of @model@ with 'replaceParameters'. The optimizer state is
  -- returned as the very same value that was passed in.
  unsafeStep ::
    forall model dev dtype lossShape tensors res.
    ( Parameterized model,
      HMap' ToDependent (Parameters model) tensors,
      HMap' ToParameter tensors (Parameters model),
      Castable (HList tensors) [TD.ATenTensor]
    ) =>
    model ->
    CppOptimizerState option (Parameters model) ->
    Tensor dev dtype lossShape ->
    IO (model, CppOptimizerState option (Parameters model))
  unsafeStep model optimState@(CppOptimizerState _ optimizerRef) loss = do
    -- let deps :: HList tensors
    --    deps = hmap' ToDependent $ flattenParameters model

    -- Debug.traceIO $ "Tensors in: "
    -- cast deps (Debug.traceIO . show . map (TD.shape . TD.Unsafe))

    -- Untyped tensors returned by the C++ step.
    v :: [TD.ATenTensor] <- cast2 LibTorch.unsafeStep optimizerRef loss
    -- Debug.traceIO $ "Params returned by unsafeStep: "<>show (length v)

    -- Recover the typed heterogeneous list from the untyped tensor list.
    newParamTensors :: HList tensors <- uncast v pure
    -- Debug.traceIO $ "Tensors out: "
    -- cast newParamTensors (Debug.traceIO . show . map (TD.shape . TD.Unsafe))
    let newParams = hmap' ToParameter newParamTensors
        newModel = replaceParameters model newParams
    return (newModel, optimState)

-- | Adam, created through 'LibTorch.adam'.
--
-- Arguments are passed in the order: learning rate, first and second
-- component of 'adamBetas', epsilon, weight decay, amsgrad flag, and finally
-- the model's parameter tensors.
instance CppOptimizer AdamOptions where
  initOptimizer opt@AdamOptions {..} model = do
    optimizerRef <-
      cast7
        LibTorch.adam
        adamLr
        (fst adamBetas)
        (snd adamBetas)
        adamEps
        adamWeightDecay
        adamAmsgrad
        initParams'
    return $ CppOptimizerState opt optimizerRef
    where
      -- The model's parameters mapped with 'ToDependent', as handed to the
      -- C++ constructor.
      initParams' = hmap' ToDependent $ flattenParameters model

-- | AdamW, created through 'LibTorch.adamw'.
--
-- Arguments are passed in the order: learning rate, first and second
-- component of 'adamwBetas', epsilon, weight decay, amsgrad flag, and
-- finally the model's parameter tensors.
instance CppOptimizer AdamwOptions where
  initOptimizer opt@AdamwOptions {..} model = do
    optimizerRef <-
      cast7
        LibTorch.adamw
        adamwLr
        (fst adamwBetas)
        (snd adamwBetas)
        adamwEps
        adamwWeightDecay
        adamwAmsgrad
        initParams'
    return $ CppOptimizerState opt optimizerRef
    where
      -- The model's parameters mapped with 'ToDependent', as handed to the
      -- C++ constructor.
      initParams' = hmap' ToDependent $ flattenParameters model

-- | L-BFGS, created through 'LibTorch.lbfgs'.
--
-- Arguments are passed in the order: learning rate, maximum iterations,
-- maximum evaluations, gradient tolerance, change tolerance, history size,
-- optional line-search function name, and finally the model's parameter
-- tensors.
instance CppOptimizer LbfgsOptions where
  initOptimizer opt@LbfgsOptions {..} model = do
    optimizerRef <-
      cast8
        LibTorch.lbfgs
        lbfgsLr
        lbfgsMaxIter
        lbfgsMaxEval
        lbfgsToleranceGrad
        lbfgsToleranceChange
        lbfgsHistorySize
        lbfgsLineSearchFn
        initParams'
    return $ CppOptimizerState opt optimizerRef
    where
      -- The model's parameters mapped with 'ToDependent', as handed to the
      -- C++ constructor.
      initParams' = hmap' ToDependent $ flattenParameters model

-- | RMSprop, created through 'LibTorch.rmsprop'.
--
-- Arguments are passed in the order: learning rate, alpha, epsilon, weight
-- decay, momentum, centered flag, and finally the model's parameter tensors.
instance CppOptimizer RmspropOptions where
  initOptimizer opt@RmspropOptions {..} model = do
    optimizerRef <-
      cast7
        LibTorch.rmsprop
        rmspropLr
        rmspropAlpha
        rmspropEps
        rmspropWeightDecay
        rmspropMomentum
        rmspropCentered
        initParams'
    return $ CppOptimizerState opt optimizerRef
    where
      -- The model's parameters mapped with 'ToDependent', as handed to the
      -- C++ constructor.
      initParams' = hmap' ToDependent $ flattenParameters model

-- | SGD, created through 'LibTorch.sgd'.
--
-- Arguments are passed in the order: learning rate, momentum, dampening,
-- weight decay, Nesterov flag, and finally the model's parameter tensors.
instance CppOptimizer SGDOptions where
  initOptimizer opt@SGDOptions {..} model = do
    optimizerRef <-
      cast6
        LibTorch.sgd
        sgdLr
        sgdMomentum
        sgdDampening
        sgdWeightDecay
        sgdNesterov
        initParams'
    return $ CppOptimizerState opt optimizerRef
    where
      -- The model's parameters mapped with 'ToDependent', as handed to the
      -- C++ constructor.
      initParams' = hmap' ToDependent $ flattenParameters model

-- | Run a single optimization step for @loss@ and return the updated model
-- together with the optimizer state.
--
-- Before delegating to 'unsafeStep' this forces a garbage collection with
-- 'performGC' and then calls @'mallocTrim' 0@.
-- NOTE(review): presumably this is done so that memory held by unreachable
-- foreign tensors is released before the step allocates more -- confirm.
runStep ::
  ( CppOptimizer option,
    Parameterized model,
    HMap' ToDependent (Parameters model) tensors,
    HMap' ToParameter tensors (Parameters model),
    Castable (HList tensors) [TD.ATenTensor]
  ) =>
  model ->
  CppOptimizerState option (Parameters model) ->
  Optim.Loss dev dtype ->
  IO (model, CppOptimizerState option (Parameters model))
runStep model optim loss = do
  performGC
  mallocTrim 0
  unsafeStep model optim loss