Safe Haskell | Safe-Inferred |
---|---|
Language | Haskell2010 |
Synopsis
- type LearningRate device dtype = Tensor device dtype '[]
- type Loss device dtype = Tensor device dtype '[]
- data ZerosLike = ZerosLike
- class Optimizer optim gradients tensors dtype device where
- step :: LearningRate device dtype -> HList gradients -> HList tensors -> optim -> (HList tensors, optim)
- runStep :: forall model optim parameters gradients tensors dtype device. (Parameterized model, parameters ~ Parameters model, HasGrad (HList parameters) (HList gradients), tensors ~ gradients, HMap' ToDependent parameters tensors, Castable (HList gradients) [ATenTensor], Optimizer optim gradients tensors dtype device, HMapM' IO MakeIndependent tensors parameters) => model -> optim -> Loss device dtype -> LearningRate device dtype -> IO (model, optim)
- runStep' :: forall model optim parameters gradients tensors dtype device. (Parameterized model, parameters ~ Parameters model, tensors ~ gradients, HMap' ToDependent parameters tensors, Optimizer optim gradients tensors dtype device, HMapM' IO MakeIndependent tensors parameters) => model -> optim -> LearningRate device dtype -> HList gradients -> IO (model, optim)
- data GD = GD
- mkGD :: GD
- newtype GDStep device dtype = GDStep (LearningRate device dtype)
- gd :: forall gradients tensors dtype device. HZipWith (GDStep device dtype) tensors gradients tensors => LearningRate device dtype -> HList gradients -> HList tensors -> GD -> (HList tensors, GD)
- data GDM (momenta :: [Type]) = GDM {}
- mkGDM :: forall parameters momenta. HMap' ZerosLike parameters momenta => Float -> HList parameters -> GDM momenta
- data GDMStep device dtype = GDMStep Float (LearningRate device dtype)
- gdm :: forall gradients tensors momenta gdmStep dtype device. (HZipWith3 (GDMStep device dtype) tensors gradients momenta gdmStep, HMap' AFst gdmStep tensors, HMap' ASnd gdmStep momenta) => LearningRate device dtype -> HList gradients -> HList tensors -> GDM momenta -> (HList tensors, GDM momenta)
- type AdamIter = Tensor '('CPU, 0) 'Int64 '[]
- data Adam (momenta :: [Type]) = Adam {}
- mkAdam :: forall parameters momenta. HMap' ZerosLike parameters momenta => AdamIter -> Float -> Float -> HList parameters -> Adam momenta
- newtype AdamMomentum1Update = AdamMomentum1Update Float
- newtype AdamMomentum2Update = AdamMomentum2Update Float
- data AdamBiasAdjustment = AdamBiasAdjustment AdamIter Float
- data AdamParameterUpdate device dtype = AdamParameterUpdate Float (LearningRate device dtype)
- adam :: forall gradients tensors momenta adamStep dtype device. (HZipWith AdamMomentum1Update momenta gradients momenta, HZipWith AdamMomentum2Update momenta gradients momenta, HMap' AdamBiasAdjustment momenta momenta, HZipWith3 (AdamParameterUpdate device dtype) tensors momenta momenta tensors) => LearningRate device dtype -> HList gradients -> HList tensors -> Adam momenta -> (HList tensors, Adam momenta)
Documentation
type LearningRate device dtype = Tensor device dtype '[] Source #
type Loss device dtype = Tensor device dtype '[] Source #
class Optimizer optim gradients tensors dtype device where Source #
step :: LearningRate device dtype -> HList gradients -> HList tensors -> optim -> (HList tensors, optim) Source #
Instances
HZipWith (GDStep device dtype) tensors gradients tensors => Optimizer GD (gradients :: [k]) (tensors :: [k]) dtype device Source # | |
Defined in Torch.Typed.Optim | |
(HZipWith AdamMomentum1Update momenta gradients momenta, HZipWith AdamMomentum2Update momenta gradients momenta, HMap' AdamBiasAdjustment momenta momenta, HZipWith3 (AdamParameterUpdate device dtype) tensors momenta momenta tensors) => Optimizer (Adam momenta) (gradients :: [Type]) (tensors :: [Type]) dtype device Source # | |
Defined in Torch.Typed.Optim | |
(HZipWith3 (GDMStep device dtype) tensors gradients momenta gdmStep, HMap' AFst gdmStep tensors, HMap' ASnd gdmStep momenta) => Optimizer (GDM momenta) (gradients :: [Type]) (tensors :: [Type]) dtype device Source # | |
Defined in Torch.Typed.Optim |
runStep :: forall model optim parameters gradients tensors dtype device. (Parameterized model, parameters ~ Parameters model, HasGrad (HList parameters) (HList gradients), tensors ~ gradients, HMap' ToDependent parameters tensors, Castable (HList gradients) [ATenTensor], Optimizer optim gradients tensors dtype device, HMapM' IO MakeIndependent tensors parameters) => model -> optim -> Loss device dtype -> LearningRate device dtype -> IO (model, optim) Source #
runStep' :: forall model optim parameters gradients tensors dtype device. (Parameterized model, parameters ~ Parameters model, tensors ~ gradients, HMap' ToDependent parameters tensors, Optimizer optim gradients tensors dtype device, HMapM' IO MakeIndependent tensors parameters) => model -> optim -> LearningRate device dtype -> HList gradients -> IO (model, optim) Source #
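Both functions thread the optimizer state through one update: `runStep` computes gradients from the loss before delegating to `step`, while `runStep'` accepts precomputed gradients. The same pattern can be sketched at the value level over plain lists of `Double` rather than typed tensors (the names below are illustrative only and not part of this API):

```haskell
-- Value-level sketch of the optimizer-step pattern that
-- `step`/`runStep` follow: take a learning rate, gradients,
-- parameters, and optimizer state; return updated parameters
-- and state. Illustrative only, not the hasktorch API.

type Params = [Double]
type Grads  = [Double]

type Step s = Double -> Grads -> Params -> s -> (Params, s)

-- Run several training iterations, threading the optimizer
-- state through each step, as a training loop built on
-- `runStep'` would.
trainLoop :: Step s -> Double -> (Params -> Grads) -> Int -> Params -> s -> (Params, s)
trainLoop stepFn lr gradOf n ps s
  | n <= 0    = (ps, s)
  | otherwise =
      let (ps', s') = stepFn lr (gradOf ps) ps s
      in trainLoop stepFn lr gradOf (n - 1) ps' s'
```

The typed API replaces `[Double]` with heterogeneous lists of differently shaped tensors, which is why the constraints above relate `parameters`, `gradients`, and `tensors` rather than fixing one element type.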
data GD Source #
Dummy state representation for the GD optimizer
Instances
Parameterized GD Source # | |
Defined in Torch.Typed.Optim type Parameters GD :: [Type] Source # flattenParameters :: GD -> HList (Parameters GD) Source # replaceParameters :: GD -> HList (Parameters GD) -> GD Source # | |
HZipWith (GDStep device dtype) tensors gradients tensors => Optimizer GD (gradients :: [k]) (tensors :: [k]) dtype device Source # | |
Defined in Torch.Typed.Optim | |
type Parameters GD Source # | |
Defined in Torch.Typed.Optim |
newtype GDStep device dtype Source #
GDStep (LearningRate device dtype) |
Instances
(parameter ~ Tensor device dtype shape, gradient ~ Tensor device dtype shape, shape ~ Broadcast ('[] :: [Nat]) shape, BasicArithmeticDTypeIsValid device dtype, KnownDevice device) => Apply' (GDStep device dtype) (parameter, gradient) parameter Source # | |
Defined in Torch.Typed.Optim |
gd :: forall gradients tensors dtype device. HZipWith (GDStep device dtype) tensors gradients tensors => LearningRate device dtype -> HList gradients -> HList tensors -> GD -> (HList tensors, GD) Source #
Gradient descent step with a dummy state variable
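Per parameter, `GDStep` applies the plain gradient-descent rule θ' = θ − η·g; `gd` maps it across all parameters via `HZipWith`. A value-level sketch over `Double` (illustrative only; the typed implementation operates on tensors):

```haskell
-- Per-parameter gradient descent update: theta' = theta - lr * grad.
gdStep :: Double -> Double -> Double -> Double
gdStep lr theta grad = theta - lr * grad

-- Apply the update elementwise across all parameters, as the
-- HZipWith over GDStep does across the heterogeneous list.
gdUpdate :: Double -> [Double] -> [Double] -> [Double]
gdUpdate lr thetas grads = zipWith (gdStep lr) thetas grads
```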
data GDM (momenta :: [Type]) Source #
State representation for GDM Optimizer
Instances
(HZipWith3 (GDMStep device dtype) tensors gradients momenta gdmStep, HMap' AFst gdmStep tensors, HMap' ASnd gdmStep momenta) => Optimizer (GDM momenta) (gradients :: [Type]) (tensors :: [Type]) dtype device Source # | |
Defined in Torch.Typed.Optim | |
Parameterized (GDM momenta) Source # | |
Defined in Torch.Typed.Optim type Parameters (GDM momenta) :: [Type] Source # flattenParameters :: GDM momenta -> HList (Parameters (GDM momenta)) Source # replaceParameters :: GDM momenta -> HList (Parameters (GDM momenta)) -> GDM momenta Source # | |
type Parameters (GDM momenta) Source # | |
Defined in Torch.Typed.Optim |
mkGDM :: forall parameters momenta. HMap' ZerosLike parameters momenta => Float -> HList parameters -> GDM momenta Source #
data GDMStep device dtype Source #
GDMStep Float (LearningRate device dtype) |
Instances
(parameter ~ Tensor device dtype shape, gradient ~ Tensor device dtype shape, momentum ~ Tensor device dtype shape, shape ~ Broadcast ('[] :: [Nat]) shape, KnownDevice device, BasicArithmeticDTypeIsValid device dtype) => Apply' (GDMStep device dtype) (parameter, gradient, momentum) (parameter, momentum) Source # | |
Defined in Torch.Typed.Optim |
gdm :: forall gradients tensors momenta gdmStep dtype device. (HZipWith3 (GDMStep device dtype) tensors gradients momenta gdmStep, HMap' AFst gdmStep tensors, HMap' ASnd gdmStep momenta) | |
=> LearningRate device dtype | learning rate |
-> HList gradients | model parameter gradient tensors |
-> HList tensors | model parameter tensors |
-> GDM momenta | beta and model parameter momentum tensors |
-> (HList tensors, GDM momenta) | returns updated parameters and momenta |
Gradient descent with momentum step
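`GDMStep` pairs each parameter with its gradient and momentum buffer and returns the updated parameter and momentum. Assuming the classic momentum formulation (v' = β·v + g, then θ' = θ − η·v'), the per-parameter update can be sketched over `Double` as follows (illustrative only):

```haskell
-- Classic gradient descent with momentum, per parameter:
--   v'     = beta * v + g        -- accumulate velocity
--   theta' = theta - lr * v'     -- step along the velocity
-- A value-level sketch; the typed GDMStep instance applies the
-- same recurrence to whole tensors.
gdmStep :: Double -> Double -> (Double, Double, Double) -> (Double, Double)
gdmStep beta lr (theta, grad, v) =
  let v'     = beta * v + grad
      theta' = theta - lr * v'
  in (theta', v')
```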
data Adam (momenta :: [Type]) Source #
State representation for Adam Optimizer
Instances
(HZipWith AdamMomentum1Update momenta gradients momenta, HZipWith AdamMomentum2Update momenta gradients momenta, HMap' AdamBiasAdjustment momenta momenta, HZipWith3 (AdamParameterUpdate device dtype) tensors momenta momenta tensors) => Optimizer (Adam momenta) (gradients :: [Type]) (tensors :: [Type]) dtype device Source # | |
Defined in Torch.Typed.Optim | |
HAppendFD momenta momenta (momenta ++ momenta) => Parameterized (Adam momenta) Source # | |
Defined in Torch.Typed.Optim type Parameters (Adam momenta) :: [Type] Source # flattenParameters :: Adam momenta -> HList (Parameters (Adam momenta)) Source # replaceParameters :: Adam momenta -> HList (Parameters (Adam momenta)) -> Adam momenta Source # | |
type Parameters (Adam momenta) Source # | |
Defined in Torch.Typed.Optim |
mkAdam :: forall parameters momenta. HMap' ZerosLike parameters momenta => AdamIter -> Float -> Float -> HList parameters -> Adam momenta Source #
newtype AdamMomentum1Update Source #
Instances
(gradient ~ Tensor device dtype shape, momentum1 ~ Tensor device dtype shape, KnownDevice device) => Apply' AdamMomentum1Update (momentum1, gradient) momentum1 Source # | decaying average of the first momenta |
Defined in Torch.Typed.Optim apply' :: AdamMomentum1Update -> (momentum1, gradient) -> momentum1 Source # |
newtype AdamMomentum2Update Source #
Instances
(gradient ~ Tensor device dtype shape, momentum2 ~ Tensor device dtype shape, shape ~ Broadcast shape shape, KnownDevice device, BasicArithmeticDTypeIsValid device dtype) => Apply' AdamMomentum2Update (momentum2, gradient) momentum2 Source # | decaying average of the second momenta |
Defined in Torch.Typed.Optim apply' :: AdamMomentum2Update -> (momentum2, gradient) -> momentum2 Source # |
data AdamBiasAdjustment Source #
Instances
(momentum ~ Tensor device dtype shape, KnownDevice device, KnownDType dtype, shape ~ Reverse (Reverse shape), BasicArithmeticDTypeIsValid device dtype) => Apply' AdamBiasAdjustment momentum momentum Source # | bias adjustment |
Defined in Torch.Typed.Optim apply' :: AdamBiasAdjustment -> momentum -> momentum Source # |
data AdamParameterUpdate device dtype Source #
AdamParameterUpdate Float (LearningRate device dtype) |
Instances
(parameter ~ Tensor device dtype shape, momentum ~ Tensor device dtype shape, shape ~ Broadcast ('[] :: [Nat]) shape, KnownDevice device, BasicArithmeticDTypeIsValid device dtype, StandardFloatingPointDTypeValidation device dtype) => Apply' (AdamParameterUpdate device dtype) (parameter, momentum, momentum) parameter Source # | parameter update |
Defined in Torch.Typed.Optim apply' :: AdamParameterUpdate device dtype -> (parameter, momentum, momentum) -> parameter Source # |
adam :: forall gradients tensors momenta adamStep dtype device. (HZipWith AdamMomentum1Update momenta gradients momenta, HZipWith AdamMomentum2Update momenta gradients momenta, HMap' AdamBiasAdjustment momenta momenta, HZipWith3 (AdamParameterUpdate device dtype) tensors momenta momenta tensors) | |
=> LearningRate device dtype | learning rate |
-> HList gradients | model parameter gradient tensors |
-> HList tensors | model parameter tensors |
-> Adam momenta | Adam state: beta1, beta2, first and second momenta, iteration counter |
-> (HList tensors, Adam momenta) | returns updated parameters and updated Adam state |
Adam step
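The four `Apply'` instances above implement the standard Adam recurrences stage by stage: the two exponential moving averages (`AdamMomentum1Update`, `AdamMomentum2Update`), bias correction (`AdamBiasAdjustment`), and the final parameter update (`AdamParameterUpdate`). Assuming the standard Adam formulation, a single-parameter sketch over `Double` (illustrative only, not the typed-tensor implementation):

```haskell
-- Standard Adam recurrences for one scalar parameter:
--   m'     = beta1 * m + (1 - beta1) * g         -- first-moment EMA
--   v'     = beta2 * v + (1 - beta2) * g^2       -- second-moment EMA
--   mHat   = m' / (1 - beta1^t)                  -- bias correction
--   vHat   = v' / (1 - beta2^t)
--   theta' = theta - lr * mHat / (sqrt vHat + eps)
adamStep
  :: Double -> Double -> Double -> Double   -- beta1, beta2, eps, lr
  -> Int                                    -- iteration t (1-based)
  -> Double -> Double -> Double             -- theta, m, v
  -> Double                                 -- gradient g
  -> (Double, Double, Double)               -- (theta', m', v')
adamStep beta1 beta2 eps lr t theta m v g =
  let m'   = beta1 * m + (1 - beta1) * g
      v'   = beta2 * v + (1 - beta2) * g * g
      mHat = m' / (1 - beta1 ^ t)
      vHat = v' / (1 - beta2 ^ t)
  in (theta - lr * mHat / (sqrt vHat + eps), m', v')
```

On the first iteration with zeroed momenta, bias correction makes the step size roughly `lr` regardless of the gradient's magnitude, which is why `mkAdam` seeds the momenta with `ZerosLike` and carries the iteration counter (`AdamIter`) in the state.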