hasktorch-0.2.0.0: Functional differentiable programming in Haskell
Safe Haskell: Safe-Inferred
Language: Haskell2010

Torch.Typed.Optim

Documentation

type LearningRate device dtype = Tensor device dtype '[] Source #

type Loss device dtype = Tensor device dtype '[] Source #
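
Both LearningRate and Loss are rank-0 (scalar) tensors. A scalar learning rate can be built with a factory such as full from Torch.Typed.Factories; a minimal sketch, where the concrete device and dtype are illustrative assumptions:

    lr :: LearningRate '('CPU, 0) 'Float
    lr = full (0.01 :: Float)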

data ZerosLike Source #

Constructors

ZerosLike 

Instances

Instances details
(parameter ~ Parameter device dtype shape, momentum ~ Tensor device dtype shape, TensorOptions shape dtype device) => Apply' ZerosLike parameter momentum Source # 
Instance details

Defined in Torch.Typed.Optim

Methods

apply' :: ZerosLike -> parameter -> momentum Source #
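
ZerosLike maps each parameter to a zero tensor of the same shape, dtype, and device; mkGDM and mkAdam use it to initialize optimizer momenta. A sketch, assuming hmap' from the HList machinery this module builds on:

    momenta = hmap' ZerosLike parameters  -- one zero tensor per parameter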

class Optimizer optim gradients tensors dtype device where Source #

Methods

step :: LearningRate device dtype -> HList gradients -> HList tensors -> optim -> (HList tensors, optim) Source #

Instances

Instances details
HZipWith (GDStep device dtype) tensors gradients tensors => Optimizer GD (gradients :: [k]) (tensors :: [k]) dtype device Source # 
Instance details

Defined in Torch.Typed.Optim

Methods

step :: LearningRate device dtype -> HList gradients -> HList tensors -> GD -> (HList tensors, GD) Source #

(HZipWith AdamMomentum1Update momenta gradients momenta, HZipWith AdamMomentum2Update momenta gradients momenta, HMap' AdamBiasAdjustment momenta momenta, HZipWith3 (AdamParameterUpdate device dtype) tensors momenta momenta tensors) => Optimizer (Adam momenta) (gradients :: [Type]) (tensors :: [Type]) dtype device Source # 
Instance details

Defined in Torch.Typed.Optim

Methods

step :: LearningRate device dtype -> HList gradients -> HList tensors -> Adam momenta -> (HList tensors, Adam momenta) Source #

(HZipWith3 (GDMStep device dtype) tensors gradients momenta gdmStep, HMap' AFst gdmStep tensors, HMap' ASnd gdmStep momenta) => Optimizer (GDM momenta) (gradients :: [Type]) (tensors :: [Type]) dtype device Source # 
Instance details

Defined in Torch.Typed.Optim

Methods

step :: LearningRate device dtype -> HList gradients -> HList tensors -> GDM momenta -> (HList tensors, GDM momenta) Source #

runStep :: forall model optim parameters gradients tensors dtype device. (Parameterized model, parameters ~ Parameters model, HasGrad (HList parameters) (HList gradients), tensors ~ gradients, HMap' ToDependent parameters tensors, Castable (HList gradients) [ATenTensor], Optimizer optim gradients tensors dtype device, HMapM' IO MakeIndependent tensors parameters) => model -> optim -> Loss device dtype -> LearningRate device dtype -> IO (model, optim) Source #
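
A minimal usage sketch of runStep, in which MyModel, computeLoss, input, and target are hypothetical stand-ins and lr is a scalar learning-rate tensor:

    trainIteration :: MyModel -> GD -> IO (MyModel, GD)
    trainIteration model optim = do
      let loss = computeLoss model input target  -- scalar Loss tensor
      runStep model optim loss lr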

runStep' :: forall model optim parameters gradients tensors dtype device. (Parameterized model, parameters ~ Parameters model, tensors ~ gradients, HMap' ToDependent parameters tensors, Optimizer optim gradients tensors dtype device, HMapM' IO MakeIndependent tensors parameters) => model -> optim -> LearningRate device dtype -> HList gradients -> IO (model, optim) Source #
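
runStep' applies an optimizer step to gradients computed by the caller, for example via grad from Torch.Typed.Autograd; a sketch in which the argument order of grad is an assumption:

    stepWithGrads model optim loss lr =
      runStep' model optim lr (grad loss (flattenParameters model))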

data GD Source #

Dummy state representation for the GD optimizer

Constructors

GD 

Instances

Instances details
Parameterized GD Source # 
Instance details

Defined in Torch.Typed.Optim

Associated Types

type Parameters GD :: [Type] Source #

HZipWith (GDStep device dtype) tensors gradients tensors => Optimizer GD (gradients :: [k]) (tensors :: [k]) dtype device Source # 
Instance details

Defined in Torch.Typed.Optim

Methods

step :: LearningRate device dtype -> HList gradients -> HList tensors -> GD -> (HList tensors, GD) Source #

type Parameters GD Source # 
Instance details

Defined in Torch.Typed.Optim

type Parameters GD = '[] :: [Type]

newtype GDStep device dtype Source #

Constructors

GDStep (LearningRate device dtype) 

Instances

Instances details
(parameter ~ Tensor device dtype shape, gradient ~ Tensor device dtype shape, shape ~ Broadcast ('[] :: [Nat]) shape, BasicArithmeticDTypeIsValid device dtype, KnownDevice device) => Apply' (GDStep device dtype) (parameter, gradient) parameter Source # 
Instance details

Defined in Torch.Typed.Optim

Methods

apply' :: GDStep device dtype -> (parameter, gradient) -> parameter Source #
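
Each (parameter, gradient) pair is mapped to the standard gradient-descent update; as a sketch of the rule rather than the module's literal code:

    parameter' = parameter - learningRate * gradient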

gd :: forall gradients tensors dtype device. HZipWith (GDStep device dtype) tensors gradients tensors => LearningRate device dtype -> HList gradients -> HList tensors -> GD -> (HList tensors, GD) Source #

Gradient descent step with a dummy state variable
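
A bare usage sketch, with gradients and tensors as matching HLists:

    let (tensors', GD) = gd lr gradients tensors GD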

data GDM (momenta :: [Type]) Source #

State representation for the GDM (gradient descent with momentum) optimizer

Constructors

GDM 

Fields

Instances

Instances details
(HZipWith3 (GDMStep device dtype) tensors gradients momenta gdmStep, HMap' AFst gdmStep tensors, HMap' ASnd gdmStep momenta) => Optimizer (GDM momenta) (gradients :: [Type]) (tensors :: [Type]) dtype device Source # 
Instance details

Defined in Torch.Typed.Optim

Methods

step :: LearningRate device dtype -> HList gradients -> HList tensors -> GDM momenta -> (HList tensors, GDM momenta) Source #

Parameterized (GDM momenta) Source # 
Instance details

Defined in Torch.Typed.Optim

Associated Types

type Parameters (GDM momenta) :: [Type] Source #

Methods

flattenParameters :: GDM momenta -> HList (Parameters (GDM momenta)) Source #

replaceParameters :: GDM momenta -> HList (Parameters (GDM momenta)) -> GDM momenta Source #

type Parameters (GDM momenta) Source # 
Instance details

Defined in Torch.Typed.Optim

type Parameters (GDM momenta) = momenta

mkGDM :: forall parameters momenta. HMap' ZerosLike parameters momenta => Float -> HList parameters -> GDM momenta Source #
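
Constructing a GDM optimizer with momentum coefficient 0.9 and zero-initialized momenta (a sketch; flattenParameters comes from the Parameterized class):

    optim = mkGDM 0.9 (flattenParameters model)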

data GDMStep device dtype Source #

Constructors

GDMStep Float (LearningRate device dtype) 

Instances

Instances details
(parameter ~ Tensor device dtype shape, gradient ~ Tensor device dtype shape, momentum ~ Tensor device dtype shape, shape ~ Broadcast ('[] :: [Nat]) shape, KnownDevice device, BasicArithmeticDTypeIsValid device dtype) => Apply' (GDMStep device dtype) (parameter, gradient, momentum) (parameter, momentum) Source # 
Instance details

Defined in Torch.Typed.Optim

Methods

apply' :: GDMStep device dtype -> (parameter, gradient, momentum) -> (parameter, momentum) Source #

gdm Source #

Arguments

:: forall gradients tensors momenta gdmStep dtype device. (HZipWith3 (GDMStep device dtype) tensors gradients momenta gdmStep, HMap' AFst gdmStep tensors, HMap' ASnd gdmStep momenta) 
=> LearningRate device dtype

learning rate

-> HList gradients

model parameter gradient tensors

-> HList tensors

model parameter tensors

-> GDM momenta

beta and model parameter momentum tensors

-> (HList tensors, GDM momenta)

returns updated parameters and momenta

Gradient descent with momentum step
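
Per parameter, this is the conventional momentum update (a sketch of the rule, not the module's literal code):

    momentum'  = beta * momentum + gradient
    parameter' = parameter - learningRate * momentum'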

type AdamIter = Tensor '('CPU, 0) 'Int64 '[] Source #

data Adam (momenta :: [Type]) Source #

State representation for the Adam optimizer

Constructors

Adam 

Fields

Instances

Instances details
(HZipWith AdamMomentum1Update momenta gradients momenta, HZipWith AdamMomentum2Update momenta gradients momenta, HMap' AdamBiasAdjustment momenta momenta, HZipWith3 (AdamParameterUpdate device dtype) tensors momenta momenta tensors) => Optimizer (Adam momenta) (gradients :: [Type]) (tensors :: [Type]) dtype device Source # 
Instance details

Defined in Torch.Typed.Optim

Methods

step :: LearningRate device dtype -> HList gradients -> HList tensors -> Adam momenta -> (HList tensors, Adam momenta) Source #

HAppendFD momenta momenta (momenta ++ momenta) => Parameterized (Adam momenta) Source # 
Instance details

Defined in Torch.Typed.Optim

Associated Types

type Parameters (Adam momenta) :: [Type] Source #

Methods

flattenParameters :: Adam momenta -> HList (Parameters (Adam momenta)) Source #

replaceParameters :: Adam momenta -> HList (Parameters (Adam momenta)) -> Adam momenta Source #

type Parameters (Adam momenta) Source # 
Instance details

Defined in Torch.Typed.Optim

type Parameters (Adam momenta) = AdamIter ': (momenta ++ momenta)

mkAdam :: forall parameters momenta. HMap' ZerosLike parameters momenta => AdamIter -> Float -> Float -> HList parameters -> Adam momenta Source #
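
Constructing an Adam optimizer at iteration zero with the customary beta1 = 0.9 and beta2 = 0.999 (a sketch; zeros from Torch.Typed.Factories is assumed to supply the initial AdamIter):

    optim = mkAdam zeros 0.9 0.999 (flattenParameters model)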

newtype AdamMomentum1Update Source #

Instances

Instances details
(gradient ~ Tensor device dtype shape, momentum1 ~ Tensor device dtype shape, KnownDevice device) => Apply' AdamMomentum1Update (momentum1, gradient) momentum1 Source #

decaying average of the first momenta

Instance details

Defined in Torch.Typed.Optim

Methods

apply' :: AdamMomentum1Update -> (momentum1, gradient) -> momentum1 Source #
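
This is the usual exponentially decaying first-moment estimate (a sketch of the rule):

    momentum1' = beta1 * momentum1 + (1 - beta1) * gradient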

newtype AdamMomentum2Update Source #

Instances

Instances details
(gradient ~ Tensor device dtype shape, momentum2 ~ Tensor device dtype shape, shape ~ Broadcast shape shape, KnownDevice device, BasicArithmeticDTypeIsValid device dtype) => Apply' AdamMomentum2Update (momentum2, gradient) momentum2 Source #

decaying average of the second momenta

Instance details

Defined in Torch.Typed.Optim

Methods

apply' :: AdamMomentum2Update -> (momentum2, gradient) -> momentum2 Source #
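
The second-moment estimate decays over squared gradients (a sketch of the rule):

    momentum2' = beta2 * momentum2 + (1 - beta2) * gradient * gradient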

data AdamBiasAdjustment Source #

Instances

Instances details
(momentum ~ Tensor device dtype shape, KnownDevice device, KnownDType dtype, shape ~ Reverse (Reverse shape), BasicArithmeticDTypeIsValid device dtype) => Apply' AdamBiasAdjustment momentum momentum Source #

bias adjustment

Instance details

Defined in Torch.Typed.Optim

Methods

apply' :: AdamBiasAdjustment -> momentum -> momentum Source #
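
Bias correction rescales a zero-initialized moment estimate to compensate for its decay horizon; with beta and the iteration counter taken from the Adam state (a sketch):

    momentumHat = momentum / (1 - beta ^ iteration)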

data AdamParameterUpdate device dtype Source #

Constructors

AdamParameterUpdate Float (LearningRate device dtype) 

Instances

Instances details
(parameter ~ Tensor device dtype shape, momentum ~ Tensor device dtype shape, shape ~ Broadcast ('[] :: [Nat]) shape, KnownDevice device, BasicArithmeticDTypeIsValid device dtype, StandardFloatingPointDTypeValidation device dtype) => Apply' (AdamParameterUpdate device dtype) (parameter, momentum, momentum) parameter Source #

parameter update

Instance details

Defined in Torch.Typed.Optim

Methods

apply' :: AdamParameterUpdate device dtype -> (parameter, momentum, momentum) -> parameter Source #
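
The final update scales the bias-corrected first moment by the inverse square root of the bias-corrected second moment, stabilized by a small eps (presumably the Float field of AdamParameterUpdate); a sketch:

    parameter' = parameter - learningRate * momentumHat1 / (sqrt momentumHat2 + eps)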

adam Source #

Arguments

:: forall gradients tensors momenta adamStep dtype device. (HZipWith AdamMomentum1Update momenta gradients momenta, HZipWith AdamMomentum2Update momenta gradients momenta, HMap' AdamBiasAdjustment momenta momenta, HZipWith3 (AdamParameterUpdate device dtype) tensors momenta momenta tensors) 
=> LearningRate device dtype

learning rate

-> HList gradients

model parameter gradient tensors

-> HList tensors

model parameter tensors

-> Adam momenta

Adam state: beta1, beta2, first and second momenta, and the iteration counter

-> (HList tensors, Adam momenta)

returns updated parameters and Adam state

Adam step
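
A bare usage sketch, mirroring gd and gdm above:

    let (tensors', adam') = adam lr gradients tensors adamState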