{-# LANGUAGE DataKinds #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

module Torch.GraduallyTyped.NN.Functional.Sparse where

import GHC.Natural (Natural)
import GHC.TypeLits (Nat, Symbol, TypeError)
import System.IO.Unsafe (unsafePerformIO)
import Torch.GraduallyTyped.DType (DType (..), DataType (..))
import Torch.GraduallyTyped.Layout (LayoutType (..))
import Torch.GraduallyTyped.Prelude (Catch, Reverse)
import Torch.GraduallyTyped.Shape (Dim (..), Name, Shape (..), Size)
import Torch.GraduallyTyped.Tensor.Type (SGetLayout (..), Tensor)
import Torch.GraduallyTyped.Unify (type (<+>), type (<|>))
import Torch.Internal.Cast (cast5)
import qualified Torch.Internal.Managed.Native as ATen
import Type.Errors.Pretty (type (%), type (<>))

-- | Compile-time error reported by 'EmbeddingF' when the embedding weight's
-- known shape does not consist of exactly two dimensions. The offending
-- dimension list @dims@ is rendered inside the message.
type EmbedDimsErrorMessage (dims :: [Dim (Name Symbol) (Size Nat)]) =
  "Cannot apply the embedding."
    % "The embedding weight tensor must have exactly two dimensions,"
    % "but the following dimensions were found:"
    % ""
    % ("    " <> dims <> ".")
    % ""

-- | Result shape of an embedding lookup.
--
-- * If either the weight or the input shape is 'UncheckedShape', the result
--   is 'UncheckedShape'.
-- * If the weight has exactly two dimensions, the result is the input's
--   dimensions with the weight's second dimension appended at the end.
-- * Any other known weight shape is rejected with 'EmbedDimsErrorMessage'.
--
-- This is a closed type family, so the equations are matched top to bottom
-- and their order is significant.
type family
  EmbeddingF
    (weightShape :: Shape [Dim (Name Symbol) (Size Nat)])
    (inputShape :: Shape [Dim (Name Symbol) (Size Nat)]) ::
    Shape [Dim (Name Symbol) (Size Nat)]
  where
    EmbeddingF 'UncheckedShape _ = 'UncheckedShape
    EmbeddingF _ 'UncheckedShape = 'UncheckedShape
    EmbeddingF ('Shape '[_numEmbeddings, featureDim]) ('Shape indexDims) =
      -- Append featureDim to the end of indexDims: reverse, cons, reverse back.
      'Shape (Reverse (featureDim ': Reverse indexDims))
    EmbeddingF ('Shape badDims) _ = TypeError (EmbedDimsErrorMessage badDims)

-- | Embedding lookup, implemented by calling ATen's @embedding@ through
-- 'cast5' with the weight tensor and the index (input) tensor.
--
-- * The result has the weight's data type.
-- * The result shape is the input's shape with the weight's second dimension
--   appended (see 'EmbeddingF').
-- * The input's data type must unify with @Int64@.
--
-- The final boolean passed to ATen (@isSparse@) is 'True' exactly when the
-- weight tensor has a 'Sparse' layout.
embedding ::
  forall gradient layout device dataType shape gradient' layout' device' dataType' shape'.
  (SGetLayout layout, Catch (dataType' <+> 'DataType 'Int64)) =>
  -- | padding index
  Maybe Natural ->
  -- | whether or not to scale gradients by the inverse of frequency of the words in the mini-batch
  Bool ->
  -- | weight
  Tensor gradient layout device dataType shape ->
  -- | input
  Tensor gradient' layout' device' dataType' shape' ->
  -- | output
  Tensor
    (gradient <|> gradient')
    (layout <+> layout')
    (device <+> device')
    dataType
    (EmbeddingF shape shape')
embedding paddingIdx scaleGradByFreq weight input =
  let -- Request the sparse variant exactly when the weight itself is sparse.
      isSparse :: Bool
      isSparse = getLayoutType weight == Sparse
      -- A missing padding index is encoded as -1 for the ATen call.
      -- NOTE(review): presumably -1 is ATen's "no padding index" sentinel; confirm.
      -- Values above (maxBound :: Int) are not range-checked before conversion.
      paddingIdx' :: Int = maybe (-1) fromIntegral paddingIdx
   in -- NOTE(review): unsafePerformIO treats the ATen call as pure; this assumes
      -- it only allocates and returns a fresh tensor. Confirm.
      unsafePerformIO $
        cast5 ATen.embedding_ttlbb weight input paddingIdx' scaleGradByFreq isSparse