Safe Haskell | Safe-Inferred |
---|---|
Language | Haskell2010 |
Synopsis
- data AdamOptions = AdamOptions {}
- defaultAdamOptions :: AdamOptions
- data Optimizer model where
- UnsafeOptimizer :: forall model. {..} -> Optimizer model
- getStateDict :: forall model. Optimizer model -> IO StateDict
- getModel :: forall model. HasStateDict model => ModelSpec model -> Optimizer model -> IO model
- mkAdam :: forall model. HasStateDict model => AdamOptions -> model -> IO (Optimizer model)
- stepWithGenerator :: forall model generatorDevice lossGradient lossLayout lossDataType lossDevice lossShape generatorOutputDevice. (HasStateDict model, SGetGeneratorDevice generatorDevice, SGetGeneratorDevice generatorOutputDevice, Catch (lossShape <+> 'Shape '[]), Catch (lossGradient <+> 'Gradient 'WithGradient)) => Optimizer model -> ModelSpec model -> (model -> Generator generatorDevice -> IO (Tensor lossGradient lossLayout lossDataType lossDevice lossShape, Generator generatorOutputDevice)) -> Generator generatorDevice -> IO (Tensor lossGradient lossLayout lossDataType lossDevice lossShape, Generator generatorOutputDevice)
Documentation
defaultAdamOptions :: AdamOptions Source #
Default Adam options.
data Optimizer model where Source #
Optimizer data type, indexed by the type of the model it optimizes.
UnsafeOptimizer | |
|
getStateDict :: forall model. Optimizer model -> IO StateDict Source #
Get the model state dictionary from an optimizer.
getModel :: forall model. HasStateDict model => ModelSpec model -> Optimizer model -> IO model Source #
Extract a model from an optimizer, using the given model specification.
mkAdam Source #
:: forall model. HasStateDict model | |
---|---|
=> AdamOptions | Adam options |
-> model | initial model |
-> IO (Optimizer model) | Adam optimizer |
Create a new Adam optimizer from a model.
stepWithGenerator Source #
:: forall model generatorDevice lossGradient lossLayout lossDataType lossDevice lossShape generatorOutputDevice. (HasStateDict model, SGetGeneratorDevice generatorDevice, SGetGeneratorDevice generatorOutputDevice, Catch (lossShape <+> 'Shape '[]), Catch (lossGradient <+> 'Gradient 'WithGradient)) | |
---|---|
=> Optimizer model | optimizer for the model |
-> ModelSpec model | model specification |
-> (model -> Generator generatorDevice -> IO (Tensor lossGradient lossLayout lossDataType lossDevice lossShape, Generator generatorOutputDevice)) | loss function to minimize |
-> Generator generatorDevice | random generator |
-> IO (Tensor lossGradient lossLayout lossDataType lossDevice lossShape, Generator generatorOutputDevice) | loss and updated generator |
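Perform one optimization step on the optimizer's model, minimizing the given loss function. Returns the loss together with the updated random generator.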