{-# LANGUAGE DataKinds #-}
module Torch.Internal.Managed.Optim where

import Foreign
import Foreign.C.String
import Foreign.C.Types
import Foreign.ForeignPtr.Unsafe
import Torch.Internal.Cast
import Torch.Internal.Class
import Torch.Internal.Objects
import Torch.Internal.Type
import qualified Torch.Internal.Unmanaged.Optim as Unmanaged
import Control.Concurrent.MVar (MVar(..), newEmptyMVar, putMVar, takeMVar)

-- optimizerWithAdam
--   :: CDouble
--   -> CDouble
--   -> CDouble
--   -> CDouble
--   -> CDouble
--   -> CBool
--   -> ForeignPtr TensorList
--   -> (ForeignPtr TensorList -> IO (ForeignPtr Tensor))
--   -> Int
--   -> IO (ForeignPtr TensorList)
-- optimizerWithAdam adamLr adamBetas0 adamBetas1 adamEps adamWeightDecay adamAmsgrad initParams loss numIter = _cast2 (\i n -> Unmanaged.optimizerWithAdam adamLr adamBetas0 adamBetas1 adamEps adamWeightDecay adamAmsgrad i (trans loss) n) initParams numIter
--   where
--     trans :: (ForeignPtr TensorList -> IO (ForeignPtr Tensor)) -> Ptr TensorList -> IO (Ptr Tensor)
--     trans func inputs = do
--       inputs' <- newForeignPtr_ inputs
--       ret <- func inputs'
--       return $ unsafeForeignPtrToPtr ret

-- | Managed counterpart of 'Unmanaged.adagrad'.
--
-- The five 'CDouble' arguments are forwarded unchanged, in order, to the
-- unmanaged constructor (their individual meanings are defined there and are
-- not visible in this module). '_cast6' converts the 'ForeignPtr' 'TensorList'
-- argument to the 'Ptr' the unmanaged function expects and converts the
-- returned 'Ptr' 'Optimizer' back into a 'ForeignPtr'.
adagrad
  :: CDouble
  -> CDouble
  -> CDouble
  -> CDouble
  -> CDouble
  -> ForeignPtr TensorList
  -> IO (ForeignPtr Optimizer)
adagrad = _cast6 Unmanaged.adagrad

-- | Managed counterpart of 'Unmanaged.rmsprop'.
--
-- The five 'CDouble' arguments and the 'CBool' flag are forwarded unchanged,
-- in order, to the unmanaged constructor (their meanings are defined there
-- and are not visible in this module). '_cast7' converts the 'ForeignPtr'
-- 'TensorList' argument to a 'Ptr' and converts the returned 'Ptr'
-- 'Optimizer' back into a 'ForeignPtr'.
rmsprop
  :: CDouble
  -> CDouble
  -> CDouble
  -> CDouble
  -> CDouble
  -> CBool
  -> ForeignPtr TensorList
  -> IO (ForeignPtr Optimizer)
rmsprop = _cast7 Unmanaged.rmsprop

-- | Managed counterpart of 'Unmanaged.sgd'.
--
-- The four 'CDouble' arguments and the 'CBool' flag are forwarded unchanged,
-- in order, to the unmanaged constructor (their meanings are defined there
-- and are not visible in this module). '_cast6' converts the 'ForeignPtr'
-- 'TensorList' argument to a 'Ptr' and converts the returned 'Ptr'
-- 'Optimizer' back into a 'ForeignPtr'.
sgd
  :: CDouble
  -> CDouble
  -> CDouble
  -> CDouble
  -> CBool
  -> ForeignPtr TensorList
  -> IO (ForeignPtr Optimizer)
sgd = _cast6 Unmanaged.sgd

-- | Managed counterpart of 'Unmanaged.adam'.
--
-- The five 'CDouble' arguments and the 'CBool' flag are forwarded unchanged,
-- in order, to the unmanaged constructor (their meanings are defined there
-- and are not visible in this module). '_cast7' converts the 'ForeignPtr'
-- 'TensorList' argument to a 'Ptr' and converts the returned 'Ptr'
-- 'Optimizer' back into a 'ForeignPtr'.
adam
  :: CDouble
  -> CDouble
  -> CDouble
  -> CDouble
  -> CDouble
  -> CBool
  -> ForeignPtr TensorList
  -> IO (ForeignPtr Optimizer)
adam = _cast7 Unmanaged.adam

-- | Managed counterpart of 'Unmanaged.adamw'.
--
-- The five 'CDouble' arguments and the 'CBool' flag are forwarded unchanged,
-- in order, to the unmanaged constructor (their meanings are defined there
-- and are not visible in this module). '_cast7' converts the 'ForeignPtr'
-- 'TensorList' argument to a 'Ptr' and converts the returned 'Ptr'
-- 'Optimizer' back into a 'ForeignPtr'.
adamw
  :: CDouble
  -> CDouble
  -> CDouble
  -> CDouble
  -> CDouble
  -> CBool
  -> ForeignPtr TensorList
  -> IO (ForeignPtr Optimizer)
adamw = _cast7 Unmanaged.adamw

-- | Managed counterpart of 'Unmanaged.lbfgs'.
--
-- The numeric arguments are forwarded unchanged, in order, to the unmanaged
-- constructor (their meanings are defined there and are not visible in this
-- module). '_cast8' converts the optional 'ForeignPtr' 'StdString' to a
-- @Maybe (Ptr StdString)@ and the 'ForeignPtr' 'TensorList' to a 'Ptr', and
-- converts the returned 'Ptr' 'Optimizer' back into a 'ForeignPtr'.
lbfgs
  :: CDouble
  -> CInt
  -> CInt
  -> CDouble
  -> CDouble
  -> CInt
  -> Maybe (ForeignPtr StdString)
  -> ForeignPtr TensorList
  -> IO (ForeignPtr Optimizer)
lbfgs = _cast8 Unmanaged.lbfgs

-- | Managed counterpart of 'Unmanaged.getParams': takes a managed optimizer
-- and returns the 'TensorList' produced by the unmanaged call, with both
-- pointers converted by '_cast1'.
getParams :: ForeignPtr Optimizer -> IO (ForeignPtr TensorList)
getParams = _cast1 Unmanaged.getParams

-- | Call 'Unmanaged.step' on a managed optimizer, passing @loss@ as the
-- callback, and return the tensor produced by the unmanaged call.
--
-- The managed callback is adapted to the raw-pointer callback type by
-- @trans@. Because @trans@ hands a raw 'Ptr' (obtained with
-- 'unsafeForeignPtrToPtr') back to the unmanaged side, the owning
-- 'ForeignPtr' is stashed in an 'MVar' and touched only after
-- 'Unmanaged.step' has returned, so it stays reachable for the duration of
-- the call.
--
-- NOTE(review): 'takeMVar' blocks if the callback was never run, and the
-- 'putMVar' in @trans@ blocks if the callback runs again while the 'MVar' is
-- still full. This presumably relies on 'Unmanaged.step' invoking the
-- callback exactly once per call -- TODO confirm against the unmanaged
-- implementation.
step :: ForeignPtr Optimizer -> (ForeignPtr TensorList -> IO (ForeignPtr Tensor)) -> IO (ForeignPtr Tensor)
step optimizer loss = do
  ref <- newEmptyMVar
  ret <- cast1 (\opt -> Unmanaged.step opt (trans ref loss)) optimizer
  -- Retrieve the callback's result and keep it alive up to this point.
  v <- takeMVar ref
  touchForeignPtr v
  return ret
  where
    -- Adapt a managed callback to the raw-pointer callback type: wrap the
    -- incoming pointer, run the callback, record its result in @ref@ and
    -- return the underlying raw pointer.
    trans :: MVar (ForeignPtr Tensor) -> (ForeignPtr TensorList -> IO (ForeignPtr Tensor)) -> Ptr TensorList -> IO (Ptr Tensor)
    trans ref func inputs = do
      -- 'newForeignPtr_' attaches no finalizer to the wrapped pointer.
      inputs' <- newForeignPtr_ inputs
      result <- func inputs'
      putMVar ref result
      return $ unsafeForeignPtrToPtr result

-- | Call 'Unmanaged.stepWithGenerator' on a managed optimizer and generator,
-- passing @loss@ as the callback, and return the @(Tensor, Generator)@ tuple
-- produced by the unmanaged call.
--
-- The managed callback is adapted to the raw-pointer callback type by
-- @trans@. Because @trans@ hands a raw 'Ptr' (obtained with
-- 'unsafeForeignPtrToPtr') back to the unmanaged side, the owning
-- 'ForeignPtr' is stashed in an 'MVar' and touched only after
-- 'Unmanaged.stepWithGenerator' has returned, so it stays reachable for the
-- duration of the call.
--
-- NOTE(review): 'takeMVar' blocks if the callback was never run, and the
-- 'putMVar' in @trans@ blocks if the callback runs again while the 'MVar' is
-- still full. This presumably relies on the unmanaged function invoking the
-- callback exactly once per call -- TODO confirm against the unmanaged
-- implementation.
stepWithGenerator
  :: ForeignPtr Optimizer
  -> ForeignPtr Generator
  -> (ForeignPtr TensorList -> ForeignPtr Generator -> IO (ForeignPtr (StdTuple '(Tensor,Generator))))
  -> IO (ForeignPtr (StdTuple '(Tensor,Generator)))
stepWithGenerator optimizer generator loss = do
  ref <- newEmptyMVar
  ret <- cast2 (\opt gen -> Unmanaged.stepWithGenerator opt gen (trans ref loss)) optimizer generator
  -- Retrieve the callback's result and keep it alive up to this point.
  v <- takeMVar ref
  touchForeignPtr v
  return ret
  where
    -- Adapt a managed callback to the raw-pointer callback type: wrap both
    -- incoming pointers, run the callback, record its result in @ref@ and
    -- return the underlying raw pointer.
    trans
      :: MVar (ForeignPtr (StdTuple '(Tensor,Generator)))
      -> (ForeignPtr TensorList -> ForeignPtr Generator -> IO (ForeignPtr (StdTuple '(Tensor,Generator))))
      -> Ptr TensorList
      -> Ptr Generator
      -> IO (Ptr (StdTuple '(Tensor,Generator)))
    trans ref func inputs genPtr = do
      -- 'newForeignPtr_' attaches no finalizer to the wrapped pointers.
      inputs' <- newForeignPtr_ inputs
      genPtr' <- newForeignPtr_ genPtr
      result <- func inputs' genPtr'
      putMVar ref result
      return $ unsafeForeignPtrToPtr result


-- | Managed counterpart of 'Unmanaged.unsafeStep': takes a managed optimizer
-- and a managed tensor and returns the 'TensorList' produced by the unmanaged
-- call, with all pointers converted by '_cast2'.
unsafeStep :: ForeignPtr Optimizer -> ForeignPtr Tensor -> IO (ForeignPtr TensorList)
unsafeStep = _cast2 Unmanaged.unsafeStep

-- | Managed counterpart of 'Unmanaged.save': forwards a managed optimizer and
-- a managed 'StdString' to the unmanaged call, converting both pointers with
-- '_cast2'.
save :: ForeignPtr Optimizer -> ForeignPtr StdString -> IO ()
save = _cast2 Unmanaged.save

-- | Managed counterpart of 'Unmanaged.load': forwards a managed optimizer and
-- a managed 'StdString' to the unmanaged call, converting both pointers with
-- '_cast2'.
load :: ForeignPtr Optimizer -> ForeignPtr StdString -> IO ()
load = _cast2 Unmanaged.load