{-# LANGUAGE DataKinds #-}
module Torch.Internal.Managed.Optim where
import Control.Concurrent.MVar (MVar(..), newEmptyMVar, putMVar, takeMVar)
import Control.Concurrent.MVar (modifyMVar_, newMVar)
import Foreign
import Foreign.C.String
import Foreign.C.Types
import Foreign.ForeignPtr.Unsafe

import Torch.Internal.Cast
import Torch.Internal.Class
import Torch.Internal.Objects
import Torch.Internal.Type
import qualified Torch.Internal.Unmanaged.Optim as Unmanaged
-- | Managed wrapper around 'Unmanaged.adagrad'.
--
-- Arguments are marshalled with '_cast6': the 'ForeignPtr' argument is passed
-- down as a raw 'Ptr', and the resulting @Ptr Optimizer@ is lifted back to a
-- 'ForeignPtr' by the 'Castable' instances (NOTE(review): finalizer behaviour
-- lives in those instances, not here — confirm in "Torch.Internal.Objects").
--
-- NOTE(review): the five 'CDouble's line up with libtorch's @AdagradOptions@
-- (lr, lr_decay, weight_decay, initial_accumulator_value, eps); confirm the
-- exact order against 'Unmanaged.adagrad'.
adagrad
  :: CDouble
  -> CDouble
  -> CDouble
  -> CDouble
  -> CDouble
  -> ForeignPtr TensorList
  -> IO (ForeignPtr Optimizer)
adagrad = _cast6 Unmanaged.adagrad
-- | Managed wrapper around 'Unmanaged.rmsprop'.
--
-- Arguments are marshalled with '_cast7': the 'ForeignPtr' argument is passed
-- down as a raw 'Ptr', and the resulting @Ptr Optimizer@ is lifted back to a
-- 'ForeignPtr' by the 'Castable' instances.
--
-- NOTE(review): the argument types line up with libtorch's @RMSpropOptions@
-- (lr, alpha, eps, weight_decay, momentum, centered); confirm the exact order
-- against 'Unmanaged.rmsprop'.
rmsprop
  :: CDouble
  -> CDouble
  -> CDouble
  -> CDouble
  -> CDouble
  -> CBool
  -> ForeignPtr TensorList
  -> IO (ForeignPtr Optimizer)
rmsprop = _cast7 Unmanaged.rmsprop
-- | Managed wrapper around 'Unmanaged.sgd'.
--
-- Arguments are marshalled with '_cast6': the 'ForeignPtr' argument is passed
-- down as a raw 'Ptr', and the resulting @Ptr Optimizer@ is lifted back to a
-- 'ForeignPtr' by the 'Castable' instances.
--
-- NOTE(review): the argument types line up with libtorch's @SGDOptions@
-- (lr, momentum, dampening, weight_decay, nesterov); confirm the exact order
-- against 'Unmanaged.sgd'.
sgd
  :: CDouble
  -> CDouble
  -> CDouble
  -> CDouble
  -> CBool
  -> ForeignPtr TensorList
  -> IO (ForeignPtr Optimizer)
sgd = _cast6 Unmanaged.sgd
-- | Managed wrapper around 'Unmanaged.adam'.
--
-- Arguments are marshalled with '_cast7': the 'ForeignPtr' argument is passed
-- down as a raw 'Ptr', and the resulting @Ptr Optimizer@ is lifted back to a
-- 'ForeignPtr' by the 'Castable' instances.
--
-- NOTE(review): the argument types line up with libtorch's @AdamOptions@
-- (lr, beta1, beta2, eps, weight_decay, amsgrad); confirm the exact order
-- against 'Unmanaged.adam'.
adam
  :: CDouble
  -> CDouble
  -> CDouble
  -> CDouble
  -> CDouble
  -> CBool
  -> ForeignPtr TensorList
  -> IO (ForeignPtr Optimizer)
adam = _cast7 Unmanaged.adam
-- | Managed wrapper around 'Unmanaged.adamw'.
--
-- Arguments are marshalled with '_cast7': the 'ForeignPtr' argument is passed
-- down as a raw 'Ptr', and the resulting @Ptr Optimizer@ is lifted back to a
-- 'ForeignPtr' by the 'Castable' instances.
--
-- NOTE(review): the argument types line up with libtorch's @AdamWOptions@
-- (lr, beta1, beta2, eps, weight_decay, amsgrad); confirm the exact order
-- against 'Unmanaged.adamw'.
adamw
  :: CDouble
  -> CDouble
  -> CDouble
  -> CDouble
  -> CDouble
  -> CBool
  -> ForeignPtr TensorList
  -> IO (ForeignPtr Optimizer)
adamw = _cast7 Unmanaged.adamw
-- | Managed wrapper around 'Unmanaged.lbfgs'.
--
-- Arguments are marshalled with '_cast8': the 'ForeignPtr' arguments
-- (including the one inside the 'Maybe') are passed down as raw 'Ptr's, and
-- the resulting @Ptr Optimizer@ is lifted back to a 'ForeignPtr' by the
-- 'Castable' instances.
--
-- NOTE(review): the argument types line up with libtorch's @LBFGSOptions@
-- (lr, max_iter, max_eval, tolerance_grad, tolerance_change, history_size,
-- optional line_search_fn); confirm the exact order against 'Unmanaged.lbfgs'.
lbfgs
  :: CDouble
  -> CInt
  -> CInt
  -> CDouble
  -> CDouble
  -> CInt
  -> Maybe (ForeignPtr StdString)
  -> ForeignPtr TensorList
  -> IO (ForeignPtr Optimizer)
lbfgs = _cast8 Unmanaged.lbfgs
-- | Managed wrapper around 'Unmanaged.getParams': passes the optimizer down as
-- a raw 'Ptr' and lifts the returned @Ptr TensorList@ back to a 'ForeignPtr'.
--
-- NOTE(review): presumably returns the parameter tensors the optimizer holds;
-- confirm against 'Unmanaged.getParams'.
getParams :: ForeignPtr Optimizer -> IO (ForeignPtr TensorList)
getParams = _cast1 Unmanaged.getParams
-- | Run one optimizer step, using a Haskell action as the loss closure.
--
-- The loss action is wrapped into a @Ptr TensorList -> IO (Ptr Tensor)@
-- callback and handed to 'Unmanaged.step'. The native side may invoke that
-- callback zero, one or several times. NOTE(review): an L-BFGS optimizer (see
-- 'lbfgs') re-evaluates its closure repeatedly; confirm against
-- 'Unmanaged.step'.
--
-- The callback returns only a raw pointer ('unsafeForeignPtrToPtr'), so each
-- loss 'ForeignPtr' must stay reachable until native code is done with it.
-- Every returned 'ForeignPtr' is recorded and touched after 'Unmanaged.step'
-- returns.
--
-- The previous version stored the loss with 'putMVar' into an empty 'MVar'
-- that was only drained after the native call. That design had two failure
-- modes:
--
--   * a second callback invocation blocked forever on the full 'MVar';
--   * a step that never invoked the callback blocked forever on 'takeMVar'.
step
  :: ForeignPtr Optimizer
  -> (ForeignPtr TensorList -> IO (ForeignPtr Tensor))
  -> IO (ForeignPtr Tensor)
step optimizer loss = do
  -- Created full (with an empty list) so neither zero nor repeated callback
  -- invocations can block.
  keepAlive <- newMVar []
  ret <- cast1 (\opt -> Unmanaged.step opt (trans keepAlive loss)) optimizer
  -- The native call has returned; the loss tensors may now be collected.
  takeMVar keepAlive >>= mapM_ touchForeignPtr
  return ret
  where
    -- Adapt the managed loss action to the raw-pointer callback shape and
    -- record each result so it outlives the native call.
    trans
      :: MVar [ForeignPtr Tensor]
      -> (ForeignPtr TensorList -> IO (ForeignPtr Tensor))
      -> Ptr TensorList
      -> IO (Ptr Tensor)
    trans keepAlive func inputs = do
      -- No finalizer is attached. NOTE(review): presumably the native side
      -- owns this list; confirm.
      inputs' <- newForeignPtr_ inputs
      lossTensor <- func inputs'
      modifyMVar_ keepAlive (return . (lossTensor :))
      return $ unsafeForeignPtrToPtr lossTensor
-- | Run one optimizer step with a loss closure that also threads a
-- 'Generator'.
--
-- The loss action is wrapped into a raw-pointer callback and handed to
-- 'Unmanaged.stepWithGenerator'. The native side may invoke that callback
-- zero, one or several times. NOTE(review): confirm the invocation count
-- against 'Unmanaged.stepWithGenerator'.
--
-- The callback returns only a raw pointer ('unsafeForeignPtrToPtr'), so every
-- result 'ForeignPtr' is recorded and touched after the native call returns.
--
-- The previous version used a single 'putMVar' into an empty 'MVar', which
-- had two failure modes:
--
--   * it blocked forever on a second callback invocation;
--   * it blocked forever in 'takeMVar' when the callback was never invoked.
stepWithGenerator
  :: ForeignPtr Optimizer
  -> ForeignPtr Generator
  -> (ForeignPtr TensorList -> ForeignPtr Generator -> IO (ForeignPtr (StdTuple '(Tensor, Generator))))
  -> IO (ForeignPtr (StdTuple '(Tensor, Generator)))
stepWithGenerator optimizer generator loss = do
  -- Created full (with an empty list) so neither zero nor repeated callback
  -- invocations can block.
  keepAlive <- newMVar []
  ret <-
    cast2
      (\opt gen -> Unmanaged.stepWithGenerator opt gen (trans keepAlive loss))
      optimizer
      generator
  -- The native call has returned; the results may now be collected.
  takeMVar keepAlive >>= mapM_ touchForeignPtr
  return ret
  where
    -- Adapt the managed loss action to the raw-pointer callback shape and
    -- record each result so it outlives the native call.
    trans
      :: MVar [ForeignPtr (StdTuple '(Tensor, Generator))]
      -> (ForeignPtr TensorList -> ForeignPtr Generator -> IO (ForeignPtr (StdTuple '(Tensor, Generator))))
      -> Ptr TensorList
      -> Ptr Generator
      -> IO (Ptr (StdTuple '(Tensor, Generator)))
    trans keepAlive func inputs gen = do
      -- No finalizers are attached. NOTE(review): presumably the native side
      -- owns both objects; confirm.
      inputs' <- newForeignPtr_ inputs
      gen' <- newForeignPtr_ gen
      result <- func inputs' gen'
      modifyMVar_ keepAlive (return . (result :))
      return $ unsafeForeignPtrToPtr result
-- | Managed wrapper around 'Unmanaged.unsafeStep': passes the optimizer and
-- tensor down as raw 'Ptr's and lifts the returned @Ptr TensorList@ back to a
-- 'ForeignPtr'.
--
-- NOTE(review): what makes this step "unsafe" is decided in
-- 'Unmanaged.unsafeStep' and is not visible here; document it there.
unsafeStep :: ForeignPtr Optimizer -> ForeignPtr Tensor -> IO (ForeignPtr TensorList)
unsafeStep = _cast2 Unmanaged.unsafeStep
-- | Managed wrapper around 'Unmanaged.save': passes both arguments down as
-- raw 'Ptr's.
--
-- NOTE(review): the 'StdString' is presumably a file path for the optimizer
-- state; confirm against 'Unmanaged.save'.
save :: ForeignPtr Optimizer -> ForeignPtr StdString -> IO ()
save = _cast2 Unmanaged.save
-- | Managed wrapper around 'Unmanaged.load': passes both arguments down as
-- raw 'Ptr's.
--
-- NOTE(review): the 'StdString' is presumably a file path to read optimizer
-- state from; confirm against 'Unmanaged.load'.
load :: ForeignPtr Optimizer -> ForeignPtr StdString -> IO ()
load = _cast2 Unmanaged.load