{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeInType #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

module Torch.Typed.NN.Sparse where

import Data.Proxy
import GHC.Generics
import GHC.TypeLits
import qualified Torch.DType as D
import qualified Torch.Device as D
import Torch.HList
import Torch.NN (HasForward (..), Randomizable (..))
import Torch.Typed.Auxiliary
import Torch.Typed.Factories
import Torch.Typed.Functional
import Torch.Typed.Parameter
import Torch.Typed.Tensor

-- | Tag distinguishing the two flavours of embedding layer.
--
-- It is used promoted (as @'Constant@ / @'Learned@) to index 'EmbeddingSpec'
-- and 'Embedding'.
data EmbeddingType
  = -- | The embedding weights are stored as a plain 'Tensor'.
    Constant
  | -- | The embedding weights are stored as a 'Parameter'.
    Learned
  deriving (Show, Generic)

-- | Specification from which an 'Embedding' layer is sampled.
--
-- The type indices carry, in order: the optional padding index, the number
-- of embeddings, the size of each embedding vector, the 'EmbeddingType',
-- and the dtype and device of the weight table.
data
  EmbeddingSpec
    (paddingIdx :: Maybe Nat)
    (numEmbeds :: Nat)
    (embedSize :: Nat)
    (embeddingType :: EmbeddingType)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  where
  -- | Constant embedding built from the supplied weight tensor.
  ConstEmbeddingSpec ::
    forall paddingIdx numEmbeds embedSize dtype device.
    Tensor device dtype '[numEmbeds, embedSize] ->
    EmbeddingSpec paddingIdx numEmbeds embedSize 'Constant dtype device
  -- | Learned embedding whose initial weights are drawn at sampling time.
  LearnedEmbeddingWithRandomInitSpec ::
    forall paddingIdx numEmbeds embedSize dtype device.
    EmbeddingSpec paddingIdx numEmbeds embedSize 'Learned dtype device
  -- | Learned embedding whose initial weights are the supplied tensor.
  LearnedEmbeddingWithCustomInitSpec ::
    forall paddingIdx numEmbeds embedSize dtype device.
    Tensor device dtype '[numEmbeds, embedSize] ->
    EmbeddingSpec paddingIdx numEmbeds embedSize 'Learned dtype device

-- Standalone deriving: the constructors refine the @embeddingType@ index,
-- so an attached @deriving@ clause is not available for this GADT.
deriving instance
  Show (EmbeddingSpec paddingIdx numEmbeds embedSize embeddingType dtype device)

-- | An embedding layer: a @[numEmbeds, embedSize]@ weight table indexed by
-- the same type-level parameters as 'EmbeddingSpec'.
--
-- The @embeddingType@ index selects the constructor, and with it whether the
-- weights are held as a 'Tensor' or as a 'Parameter'.
data
  Embedding
    (paddingIdx :: Maybe Nat)
    (numEmbeds :: Nat)
    (embedSize :: Nat)
    (embeddingType :: EmbeddingType)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  where
  -- | Embedding whose weight table is a plain tensor.
  ConstEmbedding ::
    forall paddingIdx numEmbeds embedSize dtype device.
    --  . (PaddingIdxCheck paddingIdx numEmbeds)
    {constEmbedWeights :: Tensor device dtype '[numEmbeds, embedSize]} ->
    Embedding paddingIdx numEmbeds embedSize 'Constant dtype device
  -- | Embedding whose weight table is a 'Parameter'.
  LearnedEmbedding ::
    forall paddingIdx numEmbeds embedSize dtype device.
    --  . (PaddingIdxCheck paddingIdx numEmbeds)
    {learnedEmbedWeights :: Parameter device dtype '[numEmbeds, embedSize]} ->
    Embedding paddingIdx numEmbeds embedSize 'Learned dtype device

-- Standalone deriving: the constructors refine the @embeddingType@ index,
-- so an attached @deriving@ clause is not available for this GADT.
deriving instance Show (Embedding paddingIdx numEmbeds embedSize embeddingType dtype device)

-- | Hand-written 'Generic' instance for constant embeddings.
--
-- The representation is only the weight tensor wrapped in 'K1'; no
-- datatype, constructor or selector metadata is attached.
instance Generic (Embedding paddingIdx numEmbeds embedSize 'Constant dtype device) where
  type
    Rep (Embedding paddingIdx numEmbeds embedSize 'Constant dtype device) =
      Rec0 (Tensor device dtype '[numEmbeds, embedSize])

  -- 'ConstEmbedding' is the only constructor that can inhabit the
  -- @'Constant@ index, so this single equation is exhaustive.
  from (ConstEmbedding weights) = K1 weights
  to = ConstEmbedding . unK1

-- | Hand-written 'Generic' instance for learned embeddings.
--
-- The representation is only the weight 'Parameter' wrapped in 'K1'; no
-- datatype, constructor or selector metadata is attached.
instance Generic (Embedding paddingIdx numEmbeds embedSize 'Learned dtype device) where
  type
    Rep (Embedding paddingIdx numEmbeds embedSize 'Learned dtype device) =
      Rec0 (Parameter device dtype '[numEmbeds, embedSize])

  -- 'LearnedEmbedding' is the only constructor that can inhabit the
  -- @'Learned@ index, so this single equation is exhaustive.
  from (LearnedEmbedding weights) = K1 weights
  to = LearnedEmbedding . unK1

-- No methods are defined here, so the class's default definitions apply.
-- NOTE(review): presumably those defaults are driven by the hand-written
-- 'Generic' instance for the @'Constant@ case (whose representation holds a
-- plain 'Tensor') -- confirm against "Torch.Typed.Parameter".
instance Parameterized (Embedding paddingIdx numEmbeds embedSize 'Constant dtype device)

-- No methods are defined here, so the class's default definitions apply.
-- NOTE(review): presumably those defaults are driven by the hand-written
-- 'Generic' instance for the @'Learned@ case (whose representation holds a
-- 'Parameter') -- confirm against "Torch.Typed.Parameter".
instance Parameterized (Embedding paddingIdx numEmbeds embedSize 'Learned dtype device)

-- | Apply an embedding layer to a tensor of 'D.Int64' indices.
--
-- The result has the shape of the index tensor with @embedSize@ appended as
-- the last dimension (@shape' ~ Reverse (embedSize ': Reverse shape)@).
--
-- Both constructors delegate to 'embedding' with the two 'Bool' arguments
-- set to 'False'. NOTE(review): these flags presumably select
-- scale-gradient-by-frequency and sparse gradients -- confirm against the
-- definition of 'embedding' in "Torch.Typed.Functional".
embed ::
  forall paddingIdx shape numEmbeds embedSize embeddingType dtype device shape'.
  ( KnownMaybeNat paddingIdx,
    PaddingIdxCheck paddingIdx numEmbeds,
    shape' ~ Reverse (embedSize ': (Reverse shape))
  ) =>
  Embedding paddingIdx numEmbeds embedSize embeddingType dtype device ->
  Tensor device 'D.Int64 shape ->
  Tensor device dtype shape'
embed (ConstEmbedding weights) input =
  embedding @paddingIdx
    False
    False
    weights
    input
embed (LearnedEmbedding weights) input =
  embedding @paddingIdx
    False
    False
    -- 'embedding' takes a tensor, so the parameter is converted first.
    (toDependent weights)
    input

-- | Forward pass of an embedding layer: index lookup via 'embed'.
--
-- 'forwardStoch' performs no sampling of its own; it returns the result of
-- 'forward' wrapped in 'IO'.
instance
  ( KnownMaybeNat paddingIdx,
    PaddingIdxCheck paddingIdx numEmbeds,
    shape' ~ Reverse (embedSize ': (Reverse shape))
  ) =>
  HasForward (Embedding paddingIdx numEmbeds embedSize embeddingType dtype device) (Tensor device 'D.Int64 shape) (Tensor device dtype shape')
  where
  forward = embed
  forwardStoch layer input = pure (forward layer input)

-- | Sampling a constant embedding involves no randomness: the tensor carried
-- by the spec becomes the layer's weight table unchanged.
instance
  Randomizable
    (EmbeddingSpec paddingIdx numEmbeds embedSize 'Constant dtype device)
    (Embedding paddingIdx numEmbeds embedSize 'Constant dtype device)
  where
  sample (ConstEmbeddingSpec tensor) = pure (ConstEmbedding tensor)

-- | Sampling a learned embedding that has no padding index.
--
-- * 'LearnedEmbeddingWithRandomInitSpec': the initial weights come from
--   'randn'.
-- * 'LearnedEmbeddingWithCustomInitSpec': the supplied tensor is used as the
--   initial weights.
--
-- In both cases the tensor is turned into a 'Parameter' with
-- 'makeIndependent'.
instance
  ( KnownNat numEmbeds,
    KnownNat embedSize,
    KnownDType dtype,
    KnownDevice device,
    RandDTypeIsValid device dtype
  ) =>
  Randomizable
    (EmbeddingSpec 'Nothing numEmbeds embedSize 'Learned dtype device)
    (Embedding 'Nothing numEmbeds embedSize 'Learned dtype device)
  where
  sample LearnedEmbeddingWithRandomInitSpec =
    LearnedEmbedding <$> (makeIndependent =<< randn)
  sample (LearnedEmbeddingWithCustomInitSpec tensor) =
    LearnedEmbedding <$> makeIndependent tensor

-- | Sampling a learned embedding that has a padding index.
--
-- * 'LearnedEmbeddingWithRandomInitSpec': the weights are drawn with 'randn'
--   and then passed through 'maskedFill' with a mask that is set only on row
--   @paddingIdx@ and the fill value @0@.
-- * 'LearnedEmbeddingWithCustomInitSpec': the supplied tensor is used as
--   given. NOTE(review): the padding row is not masked in this case --
--   confirm that callers are expected to zero it themselves if required.
--
-- The arithmetic constraints in the context state that @paddingIdx@ rows,
-- one row, and @numEmbeds - paddingIdx - 1@ rows add up to @numEmbeds@,
-- which is the shape of the concatenated mask.
instance
  ( paddingIdx <= numEmbeds,
    1 <= numEmbeds - paddingIdx,
    (((numEmbeds - paddingIdx) - 1) + (1 + paddingIdx)) ~ numEmbeds,
    KnownNat paddingIdx,
    KnownNat numEmbeds,
    KnownNat embedSize,
    KnownDType dtype,
    KnownDevice device,
    RandDTypeIsValid device dtype
  ) =>
  Randomizable
    (EmbeddingSpec ('Just paddingIdx) numEmbeds embedSize 'Learned dtype device)
    (Embedding ('Just paddingIdx) numEmbeds embedSize 'Learned dtype device)
  where
  sample LearnedEmbeddingWithRandomInitSpec = do
    -- Boolean mask over the weight table: zeros for the rows before the
    -- padding index, ones for the padding row, zeros for the rows after it.
    let paddingMask =
          cat @0
            ( zeros @'[paddingIdx, embedSize] @'D.Bool @device
                :. ones @'[1, embedSize] @'D.Bool @device
                :. zeros @'[numEmbeds - paddingIdx - 1, embedSize] @'D.Bool @device
                :. HNil
            )
    initial <- randn @'[numEmbeds, embedSize] @dtype @device
    -- Assuming (as with PyTorch's masked_fill) that set mask entries are the
    -- ones overwritten, this zeroes the padding row before it is wrapped.
    LearnedEmbedding <$> makeIndependent (maskedFill paddingMask (0 :: Int) initial)
  sample (LearnedEmbeddingWithCustomInitSpec tensor) =
    LearnedEmbedding <$> makeIndependent tensor