{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

-- | Loss "layers" for gradually typed models.
--
-- Currently this module only provides 'MSELoss', a parameterless model whose
-- forward pass delegates to 'mseLoss'.
module Torch.GraduallyTyped.NN.Loss where

import GHC.Generics (Generic)
import Torch.GraduallyTyped.NN.Class (HasForward (..), HasInitialize (..), HasStateDict (..), ModelSpec)
import Torch.GraduallyTyped.NN.Functional.Loss (mseLoss)
import Torch.GraduallyTyped.Prelude (Catch)
import Torch.GraduallyTyped.Shape.Type (Shape (..))
import Torch.GraduallyTyped.Tensor.Type (Tensor)
import Torch.GraduallyTyped.Unify (type (<+>), type (<|>))

-- | Model wrapper around 'mseLoss'.
--
-- The type has a single nullary constructor and therefore carries no fields
-- of its own.
data MSELoss = MSELoss
  deriving stock (Eq, Ord, Show, Generic)

-- | 'MSELoss' serves as its own specification.
type instance ModelSpec MSELoss = MSELoss

-- | No methods are defined here, so the class defaults are used.
-- NOTE(review): presumably these are 'Generic'-based defaults; confirm in
-- "Torch.GraduallyTyped.NN.Class".
instance HasInitialize MSELoss generatorDevice MSELoss generatorDevice

-- | No methods are defined here, so the class defaults are used.
-- NOTE(review): presumably these are 'Generic'-based defaults; confirm in
-- "Torch.GraduallyTyped.NN.Class".
instance HasStateDict MSELoss

-- | Forward pass of 'MSELoss'.
--
-- The input is a pair @(prediction, target)@ of tensors. The output is the
-- result of @prediction \`mseLoss\` target@, a tensor of shape @'Shape '[]@
-- whose gradient, layout, device and data type are the type-level
-- combinations (via '<|>' and '<+>') of the corresponding properties of the
-- two inputs.
--
-- The @Catch (predShape \<+\> targetShape)@ constraint is the one demanded by
-- 'mseLoss'.
-- NOTE(review): presumably it rejects shapes that fail to unify; confirm in
-- "Torch.GraduallyTyped.Prelude".
--
-- The generator is returned unchanged.
instance
  ( Catch (predShape <+> targetShape),
    output
      ~ Tensor
          (predGradient <|> targetGradient)
          (predLayout <+> targetLayout)
          (predDevice <+> targetDevice)
          (predDataType <+> targetDataType)
          ('Shape '[])
  ) =>
  HasForward
    MSELoss
    ( Tensor predGradient predLayout predDevice predDataType predShape,
      Tensor targetGradient targetLayout targetDevice targetDataType targetShape
    )
    generatorDevice
    output
    generatorDevice
  where
  forward MSELoss (prediction, target) g = do
    loss <- prediction `mseLoss` target
    -- The loss computation does not consume the generator; thread it through.
    pure (loss, g)