{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE RoleAnnotations #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module Torch.GraduallyTyped.Optim where
import Control.Concurrent.STM.TVar (newTVarIO)
import Control.Monad.State (evalStateT, execStateT)
import qualified Data.Map as Map
import Foreign.ForeignPtr (ForeignPtr)
import Torch.GraduallyTyped.NN.Class (HasStateDict (..), ModelSpec, StateDict, StateDictKey)
import Torch.GraduallyTyped.Prelude (Catch)
import Torch.GraduallyTyped.Random (Generator (..), SGetGeneratorDevice, getGenPtr)
import Torch.GraduallyTyped.RequiresGradient (Gradient (Gradient), RequiresGradient (WithGradient))
import Torch.GraduallyTyped.Shape.Type (Shape (Shape))
import Torch.GraduallyTyped.Tensor.Type (Tensor (..))
import Torch.GraduallyTyped.Unify (type (<+>))
import qualified Torch.Internal.Cast as ATen
import qualified Torch.Internal.Class as ATen
import qualified Torch.Internal.Managed.Optim as ATen
import qualified Torch.Internal.Type as ATen
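
-- | Options for configuring the Adam optimizer.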
data AdamOptions = AdamOptions
  { learningRate :: Double,
    beta1 :: Double,
    beta2 :: Double,
    epsilon :: Double,
    weightDecay :: Double,
    amsgrad :: Bool
  }
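
-- | Default Adam options: learning rate 1e-3, betas (0.9, 0.999),
-- epsilon 1e-8, no weight decay, AMSGrad disabled.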
defaultAdamOptions :: AdamOptions
defaultAdamOptions =
  AdamOptions
    { learningRate = 0.001,
      beta1 = 0.9,
      beta2 = 0.999,
      epsilon = 1e-8,
      weightDecay = 0.0,
      amsgrad = False
    }
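
-- | An optimizer for a @model@. It wraps a pointer to a native ATen optimizer
-- together with the state-dictionary keys of the parameters it manages.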
data Optimizer model where
  UnsafeOptimizer ::
    forall model.
    { optimizerStateDictKeys :: [StateDictKey],
      optimizerPtr :: ForeignPtr ATen.Optimizer
    } ->
    Optimizer model

type role Optimizer nominal
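
-- | Reads the optimizer's current parameter tensors and pairs them with the
-- state-dictionary keys recorded when the optimizer was created.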
getStateDict ::
  forall model. Optimizer model -> IO StateDict
getStateDict UnsafeOptimizer {..} =
  do
    tPtrs :: [ForeignPtr ATen.Tensor] <- ATen.cast1 ATen.getParams optimizerPtr
    pure . Map.fromList $ zip optimizerStateDictKeys tPtrs
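
-- | Rebuilds a model from the optimizer's current parameters using the given
-- model specification.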
getModel ::
  forall model. HasStateDict model => ModelSpec model -> Optimizer model -> IO model
getModel modelSpec optimizer = do
  stateDict <- getStateDict optimizer
  flip evalStateT stateDict $ fromStateDict modelSpec mempty
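
-- | Creates an Adam optimizer over all parameters in the model's state
-- dictionary, configured by the given 'AdamOptions'.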
mkAdam ::
  forall model.
  HasStateDict model =>
  AdamOptions ->
  model ->
  IO (Optimizer model)
mkAdam AdamOptions {..} model = do
  stateDict <- flip execStateT Map.empty $ toStateDict mempty model
  let (stateDictKeys, tPtrs) = unzip $ Map.toList stateDict
  UnsafeOptimizer stateDictKeys
    <$> ATen.cast7 ATen.adam learningRate beta1 beta2 epsilon weightDecay amsgrad tPtrs
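
-- | Performs one optimization step: the optimizer's current parameters are
-- used to rebuild the model from its specification, the supplied loss
-- function is evaluated, and the native optimizer updates the parameters.
-- Returns the loss and the updated random generator.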
stepWithGenerator ::
  forall model generatorDevice lossGradient lossLayout lossDataType lossDevice lossShape generatorOutputDevice.
  ( HasStateDict model,
    SGetGeneratorDevice generatorDevice,
    SGetGeneratorDevice generatorOutputDevice,
    Catch (lossShape <+> 'Shape '[]),
    Catch (lossGradient <+> 'Gradient 'WithGradient)
  ) =>
  Optimizer model ->
  ModelSpec model ->
  ( model ->
    Generator generatorDevice ->
    IO
      ( Tensor lossGradient lossLayout lossDataType lossDevice lossShape,
        Generator generatorOutputDevice
      )
  ) ->
  Generator generatorDevice ->
  IO (Tensor lossGradient lossLayout lossDataType lossDevice lossShape, Generator generatorOutputDevice)
stepWithGenerator UnsafeOptimizer {..} modelSpec lossFn (UnsafeGenerator tvar) =
  do
    genPtr <- getGenPtr tvar
    let rawLossFn :: ForeignPtr ATen.TensorList -> ForeignPtr ATen.Generator -> IO (ForeignPtr (ATen.StdTuple '(ATen.Tensor, ATen.Generator)))
        rawLossFn tlPtr genPtr'' = do
          g'' <- UnsafeGenerator <$> newTVarIO (Right genPtr'')
          stateDict' :: StateDict <- ATen.uncast tlPtr (pure . Map.fromList . zip optimizerStateDictKeys)
          model <- flip evalStateT stateDict' $ fromStateDict modelSpec mempty
          (UnsafeTensor tPtr, UnsafeGenerator tvar''') <- lossFn model g''
          genPtr''' <- getGenPtr tvar'''
          ATen.cast (tPtr, genPtr''') pure
    (lossPtr, genPtr' :: ForeignPtr ATen.Generator) <- ATen.cast3 ATen.stepWithGenerator optimizerPtr genPtr rawLossFn
    g' <- newTVarIO (Right genPtr')
    pure (UnsafeTensor lossPtr, UnsafeGenerator g')
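
-- A minimal usage sketch (not part of this module). It assumes a model value
-- @myModel@ with a 'HasStateDict' instance, its specification @mySpec@, a loss
-- function @myLossFn@ of the shape expected by 'stepWithGenerator', and a
-- random generator @gen@; all of these names are hypothetical.
--
-- > trainStep = do
-- >   optim <- mkAdam defaultAdamOptions myModel
-- >   (loss, gen') <- stepWithGenerator optim mySpec myLossFn gen
-- >   model' <- getModel mySpec optim
-- >   pure (loss, model', gen')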