{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

module Torch.GraduallyTyped.NN.Loss where

import GHC.Generics (Generic)
import Torch.GraduallyTyped.NN.Class (HasForward (..), HasInitialize (..), HasStateDict (..), ModelSpec)
import Torch.GraduallyTyped.NN.Functional.Loss (mseLoss)
import Torch.GraduallyTyped.Prelude (Catch)
import Torch.GraduallyTyped.Shape.Type (Shape (..))
import Torch.GraduallyTyped.Tensor.Type (Tensor)
import Torch.GraduallyTyped.Unify (type (<+>), type (<|>))

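-- | Specification for the mean squared error (MSE) loss.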
data MSELoss = MSELoss
  deriving stock (Eq, Ord, Show, Generic)

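-- | 'MSELoss' has no parameters, so it serves as its own model specification.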
type instance ModelSpec MSELoss = MSELoss

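-- | There is nothing to initialize; the empty instance relies on the default
-- methods, and the generator device is left unchanged.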
instance HasInitialize MSELoss generatorDevice MSELoss generatorDevice

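-- | 'MSELoss' contributes no entries to a state dictionary; the default
-- 'HasStateDict' methods suffice.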
instance HasStateDict MSELoss

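-- | The forward pass computes the mean squared error between a prediction
-- tensor and a target tensor. The shapes of the two tensors must unify
-- (witnessed by the @Catch (predShape <+> targetShape)@ constraint), and the
-- result is a scalar tensor (@'Shape '[]@).
--
-- A minimal usage sketch (construction of @prediction@, @target@, and the
-- generator @g@ is elided, since it depends on the surrounding API):
--
-- > (loss, g') <- forward MSELoss (prediction, target) g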
instance
  ( Catch (predShape <+> targetShape),
    output
      ~ Tensor
          (predGradient <|> targetGradient)
          (predLayout <+> targetLayout)
          (predDevice <+> targetDevice)
          (predDataType <+> targetDataType)
          ('Shape '[])
  ) =>
  HasForward
    MSELoss
    ( Tensor predGradient predLayout predDevice predDataType predShape,
      Tensor targetGradient targetLayout targetDevice targetDataType targetShape
    )
    generatorDevice
    output
    generatorDevice
  where
  forward MSELoss (prediction, target) g = do
    loss <- prediction `mseLoss` target
    pure (loss, g)