{-# LANGUAGE DataKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}

module Torch.GraduallyTyped.NN.Functional.Loss where

import Control.Monad.Catch (MonadThrow)
import Torch.GraduallyTyped.Prelude (Catch)
import Torch.GraduallyTyped.Shape.Type (Shape (..))
import Torch.GraduallyTyped.Tensor.Type (Tensor)
import Torch.GraduallyTyped.Unify (type (<+>), type (<|>))
import qualified Torch.Internal.Cast as ATen
import Torch.Internal.GC (unsafeThrowableIO)
import qualified Torch.Internal.Managed.Native as ATen

-- | Compute the mean squared error between a prediction and a target tensor.
--
-- The per-element squared differences are reduced with a mean (ATen
-- reduction code @1@), so the result is a scalar tensor of shape
-- @'Shape '[]@.
--
-- The shapes of the two inputs must unify (the @Catch (shape <+> shape')@
-- constraint). The output's layout, device and data type indices are the
-- unifications (@<+>@) of the inputs' indices. Its gradient index is
-- @gradient <|> gradient'@.
--
-- The ATen kernel is run through 'unsafeThrowableIO', which is why the
-- result lives in an arbitrary 'MonadThrow' monad @m@.
mseLoss ::
  forall m gradient layout device dataType shape gradient' layout' device' dataType' shape'.
  (MonadThrow m, Catch (shape <+> shape')) =>
  -- | prediction tensor
  Tensor gradient layout device dataType shape ->
  -- | target tensor
  Tensor gradient' layout' device' dataType' shape' ->
  -- | scalar (rank-0) loss tensor
  m
    ( Tensor
        (gradient <|> gradient')
        (layout <+> layout')
        (device <+> device')
        (dataType <+> dataType')
        ('Shape '[])
    )
mseLoss prediction target =
  unsafeThrowableIO $
    ATen.cast3 ATen.mse_loss_ttl prediction target reduceMean
  where
    -- Reduction code passed as the last argument of ATen's @mse_loss@.
    -- @1@ selects the mean reduction (at::Reduction::Mean); it is cast to
    -- the kernel's Int64 parameter by 'ATen.cast3'.
    reduceMean :: Int
    reduceMean = 1