{-# LANGUAGE DataKinds #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module Torch.GraduallyTyped.NN.Functional.Sparse where
import GHC.Natural (Natural)
import GHC.TypeLits (Nat, Symbol, TypeError)
import System.IO.Unsafe (unsafePerformIO)
import Torch.GraduallyTyped.DType (DType (..), DataType (..))
import Torch.GraduallyTyped.Layout (LayoutType (..))
import Torch.GraduallyTyped.Prelude (Catch, Reverse)
import Torch.GraduallyTyped.Shape (Dim (..), Name, Shape (..), Size)
import Torch.GraduallyTyped.Tensor.Type (SGetLayout (..), Tensor)
import Torch.GraduallyTyped.Unify (type (<+>), type (<|>))
import Torch.Internal.Cast (cast5)
import qualified Torch.Internal.Managed.Native as ATen
import Type.Errors.Pretty (type (%), type (<>))
-- | Type-level error message used by 'EmbeddingF' when the embedding weight
-- shape is checked but does not consist of exactly two dimensions.
-- The offending dimension list @dims@ is rendered inside the message.
type EmbedDimsErrorMessage (dims :: [Dim (Name Symbol) (Size Nat)]) =
  "Cannot apply the embedding."
    % "The embedding weight tensor must have exactly two dimensions,"
    % "but the following dimensions were found:"
    % ""
    % " " <> dims <> "."
    % ""
-- | Compute the output shape of 'embedding' from the weight and input shapes.
--
-- The equations are tried in order:
--
-- 1. An unchecked weight shape gives an unchecked result.
-- 2. An unchecked input shape gives an unchecked result.
-- 3. A checked weight shape with exactly two dimensions gives the input
--    dimensions with the second weight dimension appended at the end
--    (assuming 'Reverse' reverses a type-level list).
-- 4. Any other checked weight shape is rejected with 'EmbedDimsErrorMessage'.
type family
  EmbeddingF
    (weightShape :: Shape [Dim (Name Symbol) (Size Nat)])
    (inputShape :: Shape [Dim (Name Symbol) (Size Nat)]) ::
    Shape [Dim (Name Symbol) (Size Nat)]
  where
  EmbeddingF 'UncheckedShape _ = 'UncheckedShape
  EmbeddingF _ 'UncheckedShape = 'UncheckedShape
  EmbeddingF ('Shape '[_dim0, dim1]) ('Shape idxDims) = 'Shape (Reverse (dim1 ': Reverse idxDims))
  EmbeddingF ('Shape weightDims) _ = TypeError (EmbedDimsErrorMessage weightDims)
-- | Apply the ATen embedding operation (@embedding_ttlbb@) to a weight tensor
-- and an input tensor.
--
-- The result unifies the gradient, layout, and device of both arguments,
-- keeps the data type of the weight tensor, and has the shape computed by
-- 'EmbeddingF'. The 'Catch' constraint requires the data type of the input
-- tensor to unify with 'Int64'.
--
-- A missing padding index is forwarded to the backend as @-1@, and the
-- backend's sparse flag is set exactly when the layout type of the weight
-- tensor is 'Sparse'.
embedding ::
  forall gradient layout device dataType shape gradient' layout' device' dataType' shape'.
  (SGetLayout layout, Catch (dataType' <+> 'DataType 'Int64)) =>
  -- | padding index; 'Nothing' is forwarded as @-1@
  Maybe Natural ->
  -- | scale-gradient-by-frequency flag, forwarded unchanged to the backend
  Bool ->
  -- | embedding weight tensor
  Tensor gradient layout device dataType shape ->
  -- | input tensor
  Tensor gradient' layout' device' dataType' shape' ->
  -- | output tensor
  Tensor
    (gradient <|> gradient')
    (layout <+> layout')
    (device <+> device')
    dataType
    (EmbeddingF shape shape')
embedding paddingIdx scaleGradByFreq weight input =
  -- NOTE(review): 'unsafePerformIO' presumes the ATen call is observably pure
  -- (in particular, that it does not mutate its arguments) -- confirm against
  -- the backend before relying on it.
  unsafePerformIO $
    cast5 ATen.embedding_ttlbb weight input paddingIdx' scaleGradByFreq isSparse
  where
    -- The backend's sparse flag is derived from the weight tensor's layout.
    isSparse :: Bool
    isSparse = getLayoutType weight == Sparse
    -- The backend takes a plain integer; -1 stands for "no padding index".
    paddingIdx' :: Int
    paddingIdx' = maybe (-1) fromIntegral paddingIdx