{-# LANGUAGE DataKinds #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
module Torch.Internal.Unmanaged.Optim where
import Control.Exception.Safe (bracket)
import Foreign
import Foreign.C.String
import Foreign.C.Types
import Foreign.Ptr
import qualified Language.C.Inline.Context as C
import qualified Language.C.Inline.Cpp as C
import qualified Language.C.Inline.Cpp.Unsafe as C
import qualified Language.C.Types as C
import Torch.Internal.Type
import Torch.Internal.Unmanaged.Helper
C.context $ C.cppCtx <> mempty {C.ctxTypesTable = typeTable}
C.include "<vector>"
C.include "<tuple>"
C.include "<torch/types.h>"
C.include "<torch/optim.h>"
C.include "<torch/serialize.h>"
adagrad
  :: CDouble
  -> CDouble
  -> CDouble
  -> CDouble
  -> CDouble
  -> Ptr TensorList
  -> IO (Ptr Optimizer)
adagrad lr lr_decay weight_decay initial_accumulator_value eps initParams =
  [C.throwBlock| torch::optim::Optimizer* {
    std::vector<at::Tensor>* init_params = $(std::vector<at::Tensor>* initParams);
    std::vector<at::Tensor> params;
    // Detach the incoming tensors and re-enable gradient tracking so the
    // optimizer works on fresh leaf tensors.
    for(int i=0;i<init_params->size();i++){
      params.push_back((*init_params)[i].detach().set_requires_grad(true));
    }
    auto options = torch::optim::AdagradOptions()
      .lr($(double lr))
      .lr_decay($(double lr_decay))
      .weight_decay($(double weight_decay))
      .initial_accumulator_value($(double initial_accumulator_value))
      .eps($(double eps));
    torch::optim::Adagrad* optimizer = new torch::optim::Adagrad(params, options);
    optimizer->zero_grad();
    return dynamic_cast<torch::optim::Optimizer*>(optimizer);
  }|]
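
-- | Build an RMSprop optimizer from learning rate, alpha, eps, weight decay,
-- momentum, the centered flag, and the initial parameters.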
rmsprop
  :: CDouble
  -> CDouble
  -> CDouble
  -> CDouble
  -> CDouble
  -> CBool
  -> Ptr TensorList
  -> IO (Ptr Optimizer)
rmsprop lr alpha eps weight_decay momentum centered initParams =
  [C.throwBlock| torch::optim::Optimizer* {
    std::vector<at::Tensor>* init_params = $(std::vector<at::Tensor>* initParams);
    std::vector<at::Tensor> params;
    for(int i=0;i<init_params->size();i++){
      params.push_back((*init_params)[i].detach().set_requires_grad(true));
    }
    auto options = torch::optim::RMSpropOptions()
      .lr($(double lr))
      .alpha($(double alpha))
      .eps($(double eps))
      .weight_decay($(double weight_decay))
      .momentum($(double momentum))
      .centered($(bool centered));
    torch::optim::RMSprop* optimizer = new torch::optim::RMSprop(params, options);
    optimizer->zero_grad();
    return dynamic_cast<torch::optim::Optimizer*>(optimizer);
  }|]
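
-- | Build an SGD optimizer from learning rate, momentum, dampening,
-- weight decay, the nesterov flag, and the initial parameters.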
sgd
  :: CDouble
  -> CDouble
  -> CDouble
  -> CDouble
  -> CBool
  -> Ptr TensorList
  -> IO (Ptr Optimizer)
sgd lr momentum dampening weight_decay nesterov initParams =
  [C.throwBlock| torch::optim::Optimizer* {
    std::vector<at::Tensor>* init_params = $(std::vector<at::Tensor>* initParams);
    std::vector<at::Tensor> params;
    for(int i=0;i<init_params->size();i++){
      params.push_back((*init_params)[i].detach().set_requires_grad(true));
    }
    auto options = torch::optim::SGDOptions($(double lr))
      .momentum($(double momentum))
      .dampening($(double dampening))
      .weight_decay($(double weight_decay))
      .nesterov($(bool nesterov));
    torch::optim::SGD* optimizer = new torch::optim::SGD(params, options);
    optimizer->zero_grad();
    return dynamic_cast<torch::optim::Optimizer*>(optimizer);
  }|]
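
-- | Build an Adam optimizer from learning rate, the two beta coefficients,
-- eps, weight decay, the amsgrad flag, and the initial parameters.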
adam
  :: CDouble
  -> CDouble
  -> CDouble
  -> CDouble
  -> CDouble
  -> CBool
  -> Ptr TensorList
  -> IO (Ptr Optimizer)
adam adamLr adamBetas0 adamBetas1 adamEps adamWeightDecay adamAmsgrad initParams =
  [C.throwBlock| torch::optim::Optimizer* {
    std::vector<at::Tensor>* init_params = $(std::vector<at::Tensor>* initParams);
    std::vector<at::Tensor> params;
    for(int i=0;i<init_params->size();i++){
      params.push_back((*init_params)[i].detach().set_requires_grad(true));
    }
    auto options = torch::optim::AdamOptions()
      .lr($(double adamLr))
      .betas(std::make_tuple($(double adamBetas0),$(double adamBetas1)))
      .eps($(double adamEps))
      .weight_decay($(double adamWeightDecay))
      .amsgrad($(bool adamAmsgrad));
    torch::optim::Adam* optimizer = new torch::optim::Adam(params, options);
    optimizer->zero_grad();
    return dynamic_cast<torch::optim::Optimizer*>(optimizer);
  }|]
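
-- | Build an AdamW optimizer from learning rate, the two beta coefficients,
-- eps, weight decay, the amsgrad flag, and the initial parameters.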
adamw
  :: CDouble
  -> CDouble
  -> CDouble
  -> CDouble
  -> CDouble
  -> CBool
  -> Ptr TensorList
  -> IO (Ptr Optimizer)
adamw adamLr adamBetas0 adamBetas1 adamEps adamWeightDecay adamAmsgrad initParams =
  [C.throwBlock| torch::optim::Optimizer* {
    std::vector<at::Tensor>* init_params = $(std::vector<at::Tensor>* initParams);
    std::vector<at::Tensor> params;
    for(int i=0;i<init_params->size();i++){
      params.push_back((*init_params)[i].detach().set_requires_grad(true));
    }
    auto options = torch::optim::AdamWOptions()
      .lr($(double adamLr))
      .betas(std::make_tuple($(double adamBetas0),$(double adamBetas1)))
      .eps($(double adamEps))
      .weight_decay($(double adamWeightDecay))
      .amsgrad($(bool adamAmsgrad));
    torch::optim::AdamW* optimizer = new torch::optim::AdamW(params, options);
    optimizer->zero_grad();
    return dynamic_cast<torch::optim::Optimizer*>(optimizer);
  }|]
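
-- | Build an L-BFGS optimizer from learning rate, max iterations, max function
-- evaluations, gradient and change tolerances, history size, an optional
-- line-search strategy ('Nothing' keeps the libtorch default), and the initial
-- parameters.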
lbfgs
  :: CDouble
  -> CInt
  -> CInt
  -> CDouble
  -> CDouble
  -> CInt
  -> Maybe (Ptr StdString)
  -> Ptr TensorList
  -> IO (Ptr Optimizer)
lbfgs lr max_iter max_eval tolerance_grad tolerance_change history_size Nothing initParams =
  [C.throwBlock| torch::optim::Optimizer* {
    std::vector<at::Tensor>* init_params = $(std::vector<at::Tensor>* initParams);
    std::vector<at::Tensor> params;
    for(int i=0;i<init_params->size();i++){
      params.push_back((*init_params)[i].detach().set_requires_grad(true));
    }
    auto options = torch::optim::LBFGSOptions()
      .lr($(double lr))
      .max_iter($(int max_iter))
      .max_eval($(int max_eval))
      .tolerance_grad($(double tolerance_grad))
      .tolerance_change($(double tolerance_change))
      .history_size($(int history_size));
    torch::optim::LBFGS* optimizer = new torch::optim::LBFGS(params, options);
    optimizer->zero_grad();
    return dynamic_cast<torch::optim::Optimizer*>(optimizer);
  }|]
lbfgs lr max_iter max_eval tolerance_grad tolerance_change history_size (Just line_search_fn) initParams =
  [C.throwBlock| torch::optim::Optimizer* {
    std::vector<at::Tensor>* init_params = $(std::vector<at::Tensor>* initParams);
    std::vector<at::Tensor> params;
    for(int i=0;i<init_params->size();i++){
      params.push_back((*init_params)[i].detach().set_requires_grad(true));
    }
    auto options = torch::optim::LBFGSOptions()
      .lr($(double lr))
      .max_iter($(int max_iter))
      .max_eval($(int max_eval))
      .tolerance_grad($(double tolerance_grad))
      .tolerance_change($(double tolerance_change))
      .history_size($(int history_size))
      .line_search_fn(*$(std::string* line_search_fn));
    torch::optim::LBFGS* optimizer = new torch::optim::LBFGS(params, options);
    optimizer->zero_grad();
    return dynamic_cast<torch::optim::Optimizer*>(optimizer);
  }|]
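
-- | Return a fresh copy of the parameters held in the optimizer's first
-- parameter group.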
getParams :: Ptr Optimizer -> IO (Ptr TensorList)
getParams optimizer =
  [C.throwBlock| std::vector<at::Tensor>* {
    return new std::vector<at::Tensor>($(torch::optim::Optimizer* optimizer)->param_groups().at(0).params());
  }|]
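
-- | Run a single optimization step. The Haskell loss function receives the
-- optimizer's current parameters and returns the loss tensor; gradients are
-- zeroed and the loss backpropagated inside the closure passed to
-- @optimizer->step@.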
step :: Ptr Optimizer -> (Ptr TensorList -> IO (Ptr Tensor)) -> IO (Ptr Tensor)
step optimizer lossFunc =
  bracket
    (callbackHelper lossFunc')
    freeHaskellFunPtr
    $ \funcPtr ->
      [C.throwBlock| at::Tensor* {
        // Cast the Haskell callback back into a function over the optimizer's
        // parameters, then run one zero_grad / backward / step cycle around it.
        auto tfunc = $(void* (*funcPtr)(void*));
        auto optimizer = $(torch::optim::Optimizer* optimizer);
        typedef at::Tensor* (*Func)(std::vector<at::Tensor>*);
        auto func = (Func)tfunc;
        auto v = optimizer->step([&]{
          optimizer->zero_grad();
          auto loss = func(&(optimizer->param_groups().at(0).params()));
          loss->backward();
          return *loss;
        });
        return new at::Tensor(v);
      }|]
  where
    lossFunc' :: Ptr () -> IO (Ptr ())
    lossFunc' params = castPtr <$> lossFunc (castPtr params)
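
-- | Like 'step', but also threads a random-number 'Generator' through the
-- loss function and returns the updated generator together with the loss.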
stepWithGenerator :: Ptr Optimizer -> Ptr Generator -> (Ptr TensorList -> Ptr Generator -> IO (Ptr (StdTuple '(Tensor,Generator)))) -> IO (Ptr (StdTuple '(Tensor,Generator)))
stepWithGenerator optimizer generator lossFunc =
  bracket
    (callbackHelper2 lossFunc')
    freeHaskellFunPtr
    $ \funcPtr ->
      [C.throwBlock| std::tuple<at::Tensor,at::Generator>* {
        auto tfunc = $(void* (*funcPtr)(void*,void*));
        auto optimizer = $(torch::optim::Optimizer* optimizer);
        typedef std::tuple<at::Tensor,at::Generator>* (*Func)(std::vector<at::Tensor>*,at::Generator*);
        // Work on a clone of the generator so the caller's generator is left
        // untouched; the updated clone is returned in the result tuple.
        auto generator = $(at::Generator* generator)->clone();
        auto func = (Func)tfunc;
        auto v = optimizer->step([&]{
          optimizer->zero_grad();
          auto lossWithGenerator = func(&(optimizer->param_groups().at(0).params()),&generator);
          auto loss = std::get<0>(*lossWithGenerator);
          generator = std::get<1>(*lossWithGenerator);
          loss.backward();
          return loss;
        });
        return new std::tuple<at::Tensor,at::Generator>(std::make_tuple(v,generator));
      }|]
  where
    lossFunc' :: Ptr () -> Ptr () -> IO (Ptr ())
    lossFunc' params generator = castPtr <$> lossFunc (castPtr params) (castPtr generator)
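
-- | Step the optimizer with an already computed loss: zero gradients,
-- backpropagate the loss, step, and return the updated parameters.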
unsafeStep :: Ptr Optimizer -> Ptr Tensor -> IO (Ptr TensorList)
unsafeStep optimizer loss =
  [C.throwBlock| std::vector<at::Tensor>* {
    auto optimizer = $(torch::optim::Optimizer* optimizer);
    auto loss = $(at::Tensor* loss);
    optimizer->zero_grad();
    loss->backward();
    optimizer->step();
    return new std::vector<at::Tensor>(optimizer->param_groups().at(0).params());
  }|]
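
-- | Serialize the optimizer state to the given file.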
save :: Ptr Optimizer -> Ptr StdString -> IO ()
save optimizer filename =
  [C.throwBlock| void {
    std::ofstream output(*$(std::string* filename));
    torch::save(*$(torch::optim::Optimizer* optimizer),output);
  }|]
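
-- | Restore the optimizer state from the given file.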
load :: Ptr Optimizer -> Ptr StdString -> IO ()
load optimizer filename =
  [C.throwBlock| void {
    std::ifstream input(*$(std::string* filename));
    torch::load(*$(torch::optim::Optimizer* optimizer),input);
  }|]
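
-- A minimal usage sketch, not part of this module: it assumes a
-- @Ptr TensorList@ with the initial parameters and a loss callback are
-- obtained elsewhere (for example via the managed wrappers). Both
-- @initParams@ and @computeLoss@ below are hypothetical placeholders.
--
-- > trainOnce :: Ptr TensorList -> (Ptr TensorList -> IO (Ptr Tensor)) -> IO (Ptr Tensor)
-- > trainOnce initParams computeLoss = do
-- >   -- lr, momentum, dampening, weight_decay, nesterov
-- >   opt <- sgd 0.01 0.9 0.0 0.0 0 initParams
-- >   -- one zero_grad / backward / step cycle driven by the callback
-- >   step opt computeLoss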