{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE RoleAnnotations #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}

module Torch.GraduallyTyped.Optim where

import Control.Concurrent.STM.TVar (newTVarIO)
import Control.Monad.State (evalStateT, execStateT)
import qualified Data.Map as Map
import Foreign.ForeignPtr (ForeignPtr)
import Torch.GraduallyTyped.NN.Class (HasStateDict (..), ModelSpec, StateDict, StateDictKey)
import Torch.GraduallyTyped.Prelude (Catch)
import Torch.GraduallyTyped.Random (Generator (..), SGetGeneratorDevice, getGenPtr)
import Torch.GraduallyTyped.RequiresGradient (Gradient (Gradient), RequiresGradient (WithGradient))
import Torch.GraduallyTyped.Shape.Type (Shape (Shape))
import Torch.GraduallyTyped.Tensor.Type (Tensor (..))
import Torch.GraduallyTyped.Unify (type (<+>))
import qualified Torch.Internal.Cast as ATen
import qualified Torch.Internal.Class as ATen
import qualified Torch.Internal.Managed.Optim as ATen
import qualified Torch.Internal.Type as ATen

-- | Options for the Adam optimizer.
data AdamOptions = AdamOptions
  { -- | learning rate
    learningRate :: Double,
    -- | beta1
    beta1 :: Double,
    -- | beta2
    beta2 :: Double,
    -- | epsilon
    epsilon :: Double,
    -- | weight decay
    weightDecay :: Double,
    -- | use amsgrad
    amsgrad :: Bool
  }

-- | Default Adam options.
defaultAdamOptions :: AdamOptions
defaultAdamOptions =
  AdamOptions
    { learningRate = 0.001,
      beta1 = 0.9,
      beta2 = 0.999,
      epsilon = 1e-8,
      weightDecay = 0.0,
      amsgrad = False
    }
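
-- A minimal sketch of overriding the defaults via record update (the values
-- below are illustrative, not recommendations):
--
-- > myAdamOptions :: AdamOptions
-- > myAdamOptions = defaultAdamOptions {learningRate = 1e-4, weightDecay = 1e-2}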

-- | Optimizer data type.
data Optimizer model where
  UnsafeOptimizer ::
    forall model.
    { optimizerStateDictKeys :: [StateDictKey],
      optimizerPtr :: ForeignPtr ATen.Optimizer
    } ->
    Optimizer model

type role Optimizer nominal

-- | Get the model state dictionary from an optimizer.
getStateDict ::
  forall model. Optimizer model -> IO StateDict
getStateDict UnsafeOptimizer {..} =
  do
    tPtrs :: [ForeignPtr ATen.Tensor] <- ATen.cast1 ATen.getParams optimizerPtr
    pure . Map.fromList $ zip optimizerStateDictKeys tPtrs

-- | Extract a model from an optimizer.
getModel ::
  forall model. HasStateDict model => ModelSpec model -> Optimizer model -> IO model
getModel modelSpec optimizer = do
  stateDict <- getStateDict optimizer
  flip evalStateT stateDict $ fromStateDict modelSpec mempty

-- | Create a new Adam optimizer from a model.
mkAdam ::
  forall model.
  HasStateDict model =>
  -- | Adam options
  AdamOptions ->
  -- | initial model
  model ->
  -- | Adam optimizer
  IO (Optimizer model)
mkAdam AdamOptions {..} model = do
  stateDict <- flip execStateT Map.empty $ toStateDict mempty model
  let (stateDictKeys, tPtrs) = unzip $ Map.toList stateDict
  UnsafeOptimizer stateDictKeys
    <$> ATen.cast7 ATen.adam learningRate beta1 beta2 epsilon weightDecay amsgrad tPtrs
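
-- A minimal usage sketch, assuming some value @model@ whose type has a
-- 'HasStateDict' instance (the name is illustrative):
--
-- > optim <- mkAdam defaultAdamOptions model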

-- | Perform one step of optimization.
stepWithGenerator ::
  forall model generatorDevice lossGradient lossLayout lossDataType lossDevice lossShape generatorOutputDevice.
  ( HasStateDict model,
    SGetGeneratorDevice generatorDevice,
    SGetGeneratorDevice generatorOutputDevice,
    Catch (lossShape <+> 'Shape '[]),
    Catch (lossGradient <+> 'Gradient 'WithGradient)
  ) =>
  -- | optimizer for the model
  Optimizer model ->
  -- | model specification
  ModelSpec model ->
  -- | loss function to minimize
  ( model ->
    Generator generatorDevice ->
    IO
      ( Tensor lossGradient lossLayout lossDataType lossDevice lossShape,
        Generator generatorOutputDevice
      )
  ) ->
  -- | random generator
  Generator generatorDevice ->
  -- | loss and updated generator
  IO (Tensor lossGradient lossLayout lossDataType lossDevice lossShape, Generator generatorOutputDevice)
stepWithGenerator UnsafeOptimizer {..} modelSpec lossFn (UnsafeGenerator tvar) =
  do
    genPtr <- getGenPtr tvar
    let rawLossFn :: ForeignPtr ATen.TensorList -> ForeignPtr ATen.Generator -> IO (ForeignPtr (ATen.StdTuple '(ATen.Tensor, ATen.Generator)))
        rawLossFn tlPtr genPtr'' = do
          g'' <- UnsafeGenerator <$> newTVarIO (Right genPtr'')
          stateDict' :: StateDict <- ATen.uncast tlPtr (pure . Map.fromList . zip optimizerStateDictKeys)
          model <- flip evalStateT stateDict' $ fromStateDict modelSpec mempty
          (UnsafeTensor tPtr, UnsafeGenerator tvar''') <- lossFn model g''
          genPtr''' <- getGenPtr tvar'''
          ATen.cast (tPtr, genPtr''') pure
    (lossPtr, genPtr' :: ForeignPtr ATen.Generator) <- ATen.cast3 ATen.stepWithGenerator optimizerPtr genPtr rawLossFn
    g' <- newTVarIO (Right genPtr')
    pure (UnsafeTensor lossPtr, UnsafeGenerator g')
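
-- A minimal sketch of one training step, assuming an optimizer @optim@ built
-- with 'mkAdam', a matching @modelSpec@, a loss function @lossFn@, and a
-- generator @gen@ satisfying the constraints above (all names are illustrative):
--
-- > (loss, gen') <- stepWithGenerator optim modelSpec lossFn gen
-- > model' <- getModel modelSpec optim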