{-# LANGUAGE DataKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module Torch.GraduallyTyped.NN.Functional.Loss where
import Control.Monad.Catch (MonadThrow)
import Torch.GraduallyTyped.Prelude (Catch)
import Torch.GraduallyTyped.Shape.Type (Shape (..))
import Torch.GraduallyTyped.Tensor.Type (Tensor)
import Torch.GraduallyTyped.Unify (type (<+>), type (<|>))
import qualified Torch.Internal.Cast as ATen
import Torch.Internal.GC (unsafeThrowableIO)
import qualified Torch.Internal.Managed.Native as ATen
-- | Mean squared error between a prediction tensor and a target tensor,
-- returned as a tensor of the scalar shape @'Shape '[]@.
--
-- The result's type parameters combine those of the two inputs:
--
-- * gradient: @gradient '<|>' gradient'@
-- * layout, device and data type: unified with '<+>'
--
-- The @Catch (shape '<+>' shape')@ constraint requires the two input shapes
-- to be combinable. Presumably this means that incompatible shapes are
-- rejected at the type level; see 'Catch' to confirm.
--
-- The work is delegated to the native @mse_loss@ binding
-- ('ATen.mse_loss_ttl', type
-- @ForeignPtr Tensor -> ForeignPtr Tensor -> Int64 -> IO (ForeignPtr Tensor)@).
-- Arguments and result are marshalled with 'ATen.cast3'. The resulting 'IO'
-- action is lifted into @m@ with 'unsafeThrowableIO'
-- (@MonadThrow m => IO a -> m a@).
mseLoss ::
  forall m gradient layout device dataType shape gradient' layout' device' dataType' shape'.
  (MonadThrow m, Catch (shape <+> shape')) =>
  -- | prediction
  Tensor gradient layout device dataType shape ->
  -- | target
  Tensor gradient' layout' device' dataType' shape' ->
  -- | scalar loss
  m
    ( Tensor
        (gradient <|> gradient')
        (layout <+> layout')
        (device <+> device')
        (dataType <+> dataType')
        ('Shape '[])
    )
prediction `mseLoss` target =
  unsafeThrowableIO $
    ATen.cast3 ATen.mse_loss_ttl prediction target reduction
  where
    -- ATen reduction code passed as the native call's third (Int64) argument.
    -- NOTE(review): 1 is assumed to be at::Reduction::Mean (0 = None,
    -- 2 = Sum). The scalar 'Shape '[] result type is consistent with a full
    -- reduction; confirm against ATen's Reduction.h.
    reduction :: Int
    reduction = 1