Safe Haskell | Safe-Inferred |
---|---|
Language | Haskell2010 |
Synopsis
- train :: forall m model input generatorDevice lossGradient lossLayout lossDataType lossDevice lossShape generatorOutputDevice. (MonadIO m, HasStateDict model, HasForward model input generatorDevice (Tensor lossGradient lossLayout lossDataType lossDevice lossShape) generatorOutputDevice, HasForward model input generatorOutputDevice (Tensor lossGradient lossLayout lossDataType lossDevice lossShape) generatorOutputDevice, SGetGeneratorDevice generatorDevice, SGetGeneratorDevice generatorOutputDevice, SGetGradient lossGradient, SGetShape lossShape, Catch (lossShape <+> 'Shape '[]), Catch (lossGradient <+> 'Gradient 'WithGradient)) => Optimizer model -> ModelSpec model -> ListT m input -> Generator generatorDevice -> m (Either (Generator generatorDevice) (Tensor ('Gradient 'WithoutGradient) lossLayout lossDataType lossDevice ('Shape '[]), Generator generatorOutputDevice))
- eval :: (MonadIO m, HasStateDict model, HasForward model input generatorDevice (Tensor lossGradient lossLayout lossDataType lossDevice lossShape) generatorOutputDevice, HasForward model input generatorOutputDevice (Tensor lossGradient lossLayout lossDataType lossDevice lossShape) generatorOutputDevice, SGetGradient lossGradient, SGetShape lossShape, Catch (lossShape <+> 'Shape '[]), Catch (lossGradient <+> 'Gradient 'WithoutGradient)) => model -> ListT m input -> Generator generatorDevice -> m (Either (Generator generatorDevice) (Tensor ('Gradient 'WithoutGradient) lossLayout lossDataType lossDevice ('Shape '[]), Generator generatorOutputDevice))
Documentation
train :: forall m model input generatorDevice lossGradient lossLayout lossDataType lossDevice lossShape generatorOutputDevice. (MonadIO m, HasStateDict model, HasForward model input generatorDevice (Tensor lossGradient lossLayout lossDataType lossDevice lossShape) generatorOutputDevice, HasForward model input generatorOutputDevice (Tensor lossGradient lossLayout lossDataType lossDevice lossShape) generatorOutputDevice, SGetGeneratorDevice generatorDevice, SGetGeneratorDevice generatorOutputDevice, SGetGradient lossGradient, SGetShape lossShape, Catch (lossShape <+> 'Shape '[]), Catch (lossGradient <+> 'Gradient 'WithGradient))

Argument | Description |
---|---|
=> Optimizer model | optimizer for the model |
-> ModelSpec model | model specification |
-> ListT m input | stream of training examples |
-> Generator generatorDevice | random generator |
-> m (Either (Generator generatorDevice) (Tensor ('Gradient 'WithoutGradient) lossLayout lossDataType lossDevice ('Shape '[]), Generator generatorOutputDevice)) | either the original generator, or the average training loss (a scalar tensor without gradient) paired with the new generator |
Train the model for one epoch.
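The epoch structure described above (one pass over the examples, threading the generator through each step and averaging the loss) can be sketched in plain Haskell. This is a toy illustration under stated assumptions, not Hasktorch's implementation: `Gen`, `Model`, `nextGen`, `lossOn`, `gradOn`, `sgdStep` and `trainEpoch` are hypothetical names; a plain list stands in for Pipes' `ListT m input`; the model is a single weight trained with squared error and plain SGD; and returning `Left` with the untouched generator is assumed to happen when the stream is empty. Because the signature of `train` returns no model, its updated parameters presumably live in the `Optimizer`; this pure sketch returns the model explicitly instead. Load the definitions into GHCi to try them.

```haskell
import Data.List (foldl')

-- Hypothetical stand-in for a random generator: a seed advanced once per step.
newtype Gen = Gen Int deriving (Eq, Show)

-- Hypothetical one-parameter model predicting y = w * x.
newtype Model = Model Double deriving (Eq, Show)

-- Advance the generator (a real generator would also yield random numbers).
nextGen :: Gen -> Gen
nextGen (Gen s) = Gen (s + 1)

-- Squared-error loss of the model on one (x, y) example.
lossOn :: Model -> (Double, Double) -> Double
lossOn (Model w) (x, y) = (w * x - y) ^ (2 :: Int)

-- Derivative of the squared-error loss with respect to w.
gradOn :: Model -> (Double, Double) -> Double
gradOn (Model w) (x, y) = 2 * (w * x - y) * x

-- One plain SGD update with learning rate lr (an assumed optimizer).
sgdStep :: Double -> Model -> (Double, Double) -> Model
sgdStep lr m@(Model w) ex = Model (w - lr * gradOn m ex)

-- One epoch: record each example's loss before updating on it, advance the
-- generator at every step, and return the average loss. An empty stream
-- yields Left with the original generator.
trainEpoch :: Double -> Model -> [(Double, Double)] -> Gen
           -> Either Gen (Double, Gen, Model)
trainEpoch _ _ [] g = Left g
trainEpoch lr m0 examples g0 = Right (total / fromIntegral n, gN, mN)
  where
    (mN, gN, total, n) = foldl' step (m0, g0, 0, 0 :: Int) examples
    step (m, g, acc, k) ex =
      (sgdStep lr m ex, nextGen g, acc + lossOn m ex, k + 1)
```

For example, `trainEpoch 0.25 (Model 0) [(1, 2), (1, 2)] (Gen 0)` records losses 4 and 1 and evaluates to `Right (2.5, Gen 2, Model 1.5)`.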
eval :: (MonadIO m, HasStateDict model, HasForward model input generatorDevice (Tensor lossGradient lossLayout lossDataType lossDevice lossShape) generatorOutputDevice, HasForward model input generatorOutputDevice (Tensor lossGradient lossLayout lossDataType lossDevice lossShape) generatorOutputDevice, SGetGradient lossGradient, SGetShape lossShape, Catch (lossShape <+> 'Shape '[]), Catch (lossGradient <+> 'Gradient 'WithoutGradient))

Argument | Description |
---|---|
=> model | model |
-> ListT m input | stream of examples |
-> Generator generatorDevice | random generator |
-> m (Either (Generator generatorDevice) (Tensor ('Gradient 'WithoutGradient) lossLayout lossDataType lossDevice ('Shape '[]), Generator generatorOutputDevice)) | either the original generator, or the average evaluation loss (a scalar tensor without gradient) paired with the new generator |
Evaluate the model on the given examples.
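Evaluation follows the same pattern as training minus the parameter updates: the model is fixed, the generator is still threaded through every forward pass, and the losses are averaged. The sketch below is a toy illustration, not Hasktorch's code; `Gen`, `Model`, `nextGen`, `lossOn` and `evalEpoch` are hypothetical names, a plain list stands in for `ListT m input`, the model is a single weight with squared-error loss, and returning `Left` on an empty stream is an assumption about when the original generator comes back.

```haskell
import Data.List (foldl')

-- Hypothetical stand-in for a random generator: a seed advanced once per step.
newtype Gen = Gen Int deriving (Eq, Show)

-- Hypothetical one-parameter model predicting y = w * x.
newtype Model = Model Double deriving (Eq, Show)

-- Advance the generator (a real generator would also yield random numbers).
nextGen :: Gen -> Gen
nextGen (Gen s) = Gen (s + 1)

-- Squared-error loss of the model on one (x, y) example.
lossOn :: Model -> (Double, Double) -> Double
lossOn (Model w) (x, y) = (w * x - y) ^ (2 :: Int)

-- Average loss over the examples without changing the model; the generator
-- is advanced once per example. An empty stream yields Left with the
-- original generator.
evalEpoch :: Model -> [(Double, Double)] -> Gen -> Either Gen (Double, Gen)
evalEpoch _ [] g = Left g
evalEpoch m examples g0 = Right (total / fromIntegral n, gN)
  where
    (gN, total, n) = foldl' step (g0, 0, 0 :: Int) examples
    step (g, acc, k) ex = (nextGen g, acc + lossOn m ex, k + 1)
```

For example, `evalEpoch (Model 1) [(1, 2), (2, 2)] (Gen 0)` sees losses 1 and 0 and evaluates to `Right (0.5, Gen 2)`.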