{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeInType #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module Torch.Typed.NN.Sparse where
import Data.Proxy
import GHC.Generics
import GHC.TypeLits
import qualified Torch.DType as D
import qualified Torch.Device as D
import Torch.HList
import Torch.NN (HasForward (..), Randomizable (..))
import Torch.Typed.Auxiliary
import Torch.Typed.Factories
import Torch.Typed.Functional
import Torch.Typed.Parameter
import Torch.Typed.Tensor
-- | Type-level tag selecting how an embedding table is stored.
--
-- Used as the @embeddingType@ index of 'EmbeddingSpec' and 'Embedding'.
data EmbeddingType
  = -- | The table is held as a plain 'Tensor'.
    Constant
  | -- | The table is held as a 'Parameter'.
    Learned
  deriving (Show, Generic)
-- | Specification from which an 'Embedding' is produced by 'sample'.
--
-- Type indices, in order: optional padding index, number of table rows
-- (@numEmbeds@), row width (@embedSize@), storage kind ('EmbeddingType'),
-- element dtype and device.
data
  EmbeddingSpec
    (paddingIdx :: Maybe Nat)
    (numEmbeds :: Nat)
    (embedSize :: Nat)
    (embeddingType :: EmbeddingType)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  where
  -- | A 'Constant' embedding whose table is exactly the given tensor.
  ConstEmbeddingSpec ::
    forall paddingIdx numEmbeds embedSize dtype device.
    Tensor device dtype '[numEmbeds, embedSize] ->
    EmbeddingSpec paddingIdx numEmbeds embedSize 'Constant dtype device
  -- | A 'Learned' embedding whose table 'sample' draws at random.
  LearnedEmbeddingWithRandomInitSpec ::
    forall paddingIdx numEmbeds embedSize dtype device.
    EmbeddingSpec paddingIdx numEmbeds embedSize 'Learned dtype device
  -- | A 'Learned' embedding initialised from the given tensor.
  LearnedEmbeddingWithCustomInitSpec ::
    forall paddingIdx numEmbeds embedSize dtype device.
    Tensor device dtype '[numEmbeds, embedSize] ->
    EmbeddingSpec paddingIdx numEmbeds embedSize 'Learned dtype device

deriving instance
  Show (EmbeddingSpec paddingIdx numEmbeds embedSize embeddingType dtype device)
-- | An embedding table with shape @'[numEmbeds, embedSize]@.
--
-- The @embeddingType@ index fixes the constructor:
-- 'Constant' tables hold a plain 'Tensor', and 'Learned' tables hold a
-- 'Parameter'.
data
  Embedding
    (paddingIdx :: Maybe Nat)
    (numEmbeds :: Nat)
    (embedSize :: Nat)
    (embeddingType :: EmbeddingType)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  where
  -- | Table stored as a plain tensor.
  ConstEmbedding ::
    forall paddingIdx numEmbeds embedSize dtype device.
    {constEmbedWeights :: Tensor device dtype '[numEmbeds, embedSize]} ->
    Embedding paddingIdx numEmbeds embedSize 'Constant dtype device
  -- | Table stored as a parameter.
  LearnedEmbedding ::
    forall paddingIdx numEmbeds embedSize dtype device.
    {learnedEmbedWeights :: Parameter device dtype '[numEmbeds, embedSize]} ->
    Embedding paddingIdx numEmbeds embedSize 'Learned dtype device

deriving instance
  Show (Embedding paddingIdx numEmbeds embedSize embeddingType dtype device)
-- | Hand-written 'Generic' instance. Stock deriving of 'Generic' is not
-- available for GADT constructors like 'ConstEmbedding'.
--
-- The representation is just the single weight tensor. This is presumably
-- consumed by the method-less 'Parameterized' instance for this type;
-- TODO confirm against "Torch.Typed.Parameter".
instance Generic (Embedding paddingIdx numEmbeds embedSize 'Constant dtype device) where
  type
    Rep (Embedding paddingIdx numEmbeds embedSize 'Constant dtype device) =
      Rec0 (Tensor device dtype '[numEmbeds, embedSize])
  from (ConstEmbedding weights) = K1 weights
  to (K1 weights) = ConstEmbedding weights
-- | Hand-written 'Generic' instance. Stock deriving of 'Generic' is not
-- available for GADT constructors like 'LearnedEmbedding'.
--
-- The representation is just the single weight 'Parameter'. This is
-- presumably consumed by the method-less 'Parameterized' instance for this
-- type; TODO confirm against "Torch.Typed.Parameter".
instance Generic (Embedding paddingIdx numEmbeds embedSize 'Learned dtype device) where
  type
    Rep (Embedding paddingIdx numEmbeds embedSize 'Learned dtype device) =
      Rec0 (Parameter device dtype '[numEmbeds, embedSize])
  from (LearnedEmbedding param) = K1 param
  to (K1 param) = LearnedEmbedding param
-- Both instances supply no methods and rely on the class defaults. Those
-- defaults presumably go through the hand-written 'Generic' instances for
-- 'Embedding'; TODO confirm against "Torch.Typed.Parameter".
instance Parameterized (Embedding paddingIdx numEmbeds embedSize 'Constant dtype device)
instance Parameterized (Embedding paddingIdx numEmbeds embedSize 'Learned dtype device)
-- | Look up rows of an embedding table for a tensor of 'D.Int64' indices.
--
-- The output shape is the input shape with @embedSize@ appended as a last
-- dimension. At the type level this is written
-- @Reverse (embedSize ': Reverse shape)@.
--
-- Both table variants delegate to 'embedding', passing the padding index as
-- a type application. A 'Learned' table is first converted with
-- 'toDependent'. The two 'False' flags are passed through unchanged. They
-- are presumably scale-grad-by-frequency and sparse-gradient switches;
-- confirm against "Torch.Typed.Functional".
embed ::
  forall paddingIdx shape numEmbeds embedSize embeddingType dtype device shape'.
  ( KnownMaybeNat paddingIdx,
    PaddingIdxCheck paddingIdx numEmbeds,
    shape' ~ Reverse (embedSize ': (Reverse shape))
  ) =>
  Embedding paddingIdx numEmbeds embedSize embeddingType dtype device ->
  Tensor device 'D.Int64 shape ->
  Tensor device dtype shape'
embed (ConstEmbedding weights) =
  embedding @paddingIdx False False weights
embed (LearnedEmbedding param) =
  embedding @paddingIdx False False (toDependent param)
-- | Running an 'Embedding' on an index tensor is 'embed'.
--
-- The stochastic variant introduces no randomness; it only lifts the
-- deterministic result into 'IO'.
instance
  ( KnownMaybeNat paddingIdx,
    PaddingIdxCheck paddingIdx numEmbeds,
    shape' ~ Reverse (embedSize ': (Reverse shape))
  ) =>
  HasForward
    (Embedding paddingIdx numEmbeds embedSize embeddingType dtype device)
    (Tensor device 'D.Int64 shape)
    (Tensor device dtype shape')
  where
  forward = embed
  forwardStoch model input = pure (forward model input)
-- | Sampling a 'Constant' spec involves no randomness: the supplied tensor
-- becomes the embedding table unchanged.
instance
  Randomizable
    (EmbeddingSpec paddingIdx numEmbeds embedSize 'Constant dtype device)
    (Embedding paddingIdx numEmbeds embedSize 'Constant dtype device)
  where
  sample (ConstEmbeddingSpec tensor) = return $ ConstEmbedding tensor
-- | Sample a 'Learned' embedding that has no padding index.
--
-- * 'LearnedEmbeddingWithRandomInitSpec': the table is drawn with 'randn'.
-- * 'LearnedEmbeddingWithCustomInitSpec': the supplied tensor is used as-is.
--
-- In both cases the table is turned into a 'Parameter' with
-- 'makeIndependent'.
instance
  ( KnownNat numEmbeds,
    KnownNat embedSize,
    KnownDType dtype,
    KnownDevice device,
    RandDTypeIsValid device dtype
  ) =>
  Randomizable
    (EmbeddingSpec 'Nothing numEmbeds embedSize 'Learned dtype device)
    (Embedding 'Nothing numEmbeds embedSize 'Learned dtype device)
  where
  sample LearnedEmbeddingWithRandomInitSpec =
    LearnedEmbedding <$> (makeIndependent =<< randn)
  sample (LearnedEmbeddingWithCustomInitSpec tensor) =
    LearnedEmbedding <$> makeIndependent tensor
-- | Sample a 'Learned' embedding that has a padding index.
--
-- * 'LearnedEmbeddingWithRandomInitSpec': the table is drawn with 'randn',
--   then every entry of row @paddingIdx@ is filled with 0 via 'maskedFill'.
-- * 'LearnedEmbeddingWithCustomInitSpec': the supplied tensor is used as-is;
--   its padding row is /not/ zeroed.
--
-- In both cases the table is turned into a 'Parameter' with
-- 'makeIndependent'.
--
-- The arithmetic constraints guarantee two things. First,
-- @paddingIdx < numEmbeds@, so the padding row exists. Second, the three
-- mask slices (@paddingIdx@ rows, 1 row, and
-- @numEmbeds - paddingIdx - 1@ rows) concatenate back to @numEmbeds@ rows.
instance
  ( paddingIdx <= numEmbeds,
    1 <= numEmbeds - paddingIdx,
    (((numEmbeds - paddingIdx) - 1) + (1 + paddingIdx)) ~ numEmbeds,
    KnownNat paddingIdx,
    KnownNat numEmbeds,
    KnownNat embedSize,
    KnownDType dtype,
    KnownDevice device,
    RandDTypeIsValid device dtype
  ) =>
  Randomizable
    (EmbeddingSpec ('Just paddingIdx) numEmbeds embedSize 'Learned dtype device)
    (Embedding ('Just paddingIdx) numEmbeds embedSize 'Learned dtype device)
  where
  sample LearnedEmbeddingWithRandomInitSpec =
    let -- 'True' only on row @paddingIdx@; the rows before and after it are
        -- all 'False', so 'maskedFill' touches only the padding row.
        mask :: Tensor device 'D.Bool '[numEmbeds, embedSize]
        mask =
          cat @0
            ( zeros @'[paddingIdx, embedSize] @'D.Bool @device
                :. ones @'[1, embedSize] @'D.Bool @device
                :. zeros @'[numEmbeds - paddingIdx - 1, embedSize] @'D.Bool @device
                :. HNil
            )
     in LearnedEmbedding
          <$> ( makeIndependent
                  =<< (maskedFill mask (0 :: Int) <$> randn @'[numEmbeds, embedSize] @dtype @device)
              )
  sample (LearnedEmbeddingWithCustomInitSpec tensor) =
    LearnedEmbedding <$> makeIndependent tensor