{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -fplugin TypeLevel.Rewrite
                -fplugin-opt=TypeLevel.Rewrite:Torch.GraduallyTyped.Unify.UnifyRightAssociativeL
                -fplugin-opt=TypeLevel.Rewrite:Torch.GraduallyTyped.Unify.UnifyIdempotenceL2 #-}

module Torch.GraduallyTyped.NN.Transformer.Type where

import Control.Monad.Catch (MonadThrow)
import Data.Singletons.TH (SingKind (fromSing), genSingletons)
import GHC.Float (double2Int)
import GHC.Generics (Generic)
import GHC.TypeLits (Nat, Symbol)
import Torch.GraduallyTyped.DType (DType (..), DataType (..), SDType (..), SDataType (..))
import Torch.GraduallyTyped.Device (SDevice (..))
import Torch.GraduallyTyped.Layout (Layout (..), LayoutType (..), SLayout (..), SLayoutType (..))
import Torch.GraduallyTyped.NN.Class (HasForward (..), HasInitialize (..), HasStateDict (..), ModelSpec)
import Torch.GraduallyTyped.Prelude (Catch, forgetIsChecked, pattern (:|:))
import Torch.GraduallyTyped.Prelude.List (SList (SNil))
import Torch.GraduallyTyped.RequiresGradient (Gradient (..), RequiresGradient (..), SGradient (..), SRequiresGradient (..))
import Torch.GraduallyTyped.Scalar (Scalar)
import Torch.GraduallyTyped.Shape.Class (AddDimF, BroadcastShapesF, ReplaceDimF, sGetDimFromShape, type (!))
import Torch.GraduallyTyped.Shape.Type (By (..), Dim (..), Name (..), SBy (..), SDim (sDimSize), SName (..), SSelectDim (..), SShape (..), SSize (..), SelectDim (..), Shape (..), Size (..), pattern (:&:))
import Torch.GraduallyTyped.Tensor.Creation (sArangeNaturals, sFull, sOnes, sZeros)
import Torch.GraduallyTyped.Tensor.IndexingSlicingJoining (UnsqueezeF, cat, unsqueeze)
import Torch.GraduallyTyped.Tensor.MathOperations.Comparison ((==.))
import Torch.GraduallyTyped.Tensor.MathOperations.Pointwise (addScalar, logicalOr)
import Torch.GraduallyTyped.Tensor.Other (maskedFill, triu)
import Torch.GraduallyTyped.Tensor.Type (SGetDataType (sGetDataType), SGetDevice (..), SGetDim, SGetLayout (..), SGetShape (..), Tensor (..), TensorLike (sToTensor), TensorSpec (..), bool, sCheckedShape)
import Torch.GraduallyTyped.Unify (type (<+>), type (<|>))
import Torch.HList

-- | A data type representing the style of a transformer.
-- Every supported transformer has a constructor of this type.
data TransformerStyle
  = -- | @T5@ transformer style, see https://ai.googleblog.com/2020/02/exploring-transfer-learning-with-t5.html
    T5
  | -- | @ByT5@ transformer style, see https://arxiv.org/abs/2105.13626
    ByT5
  | -- | @BART@ transformer style, see https://arxiv.org/abs/1910.13461
    BART
  | -- | @MBART@ transformer style, see https://arxiv.org/abs/2001.08210
    MBART
  | -- | @Pegasus@ transformer style, see https://ai.googleblog.com/2020/06/pegasus-state-of-art-model-for.html
    Pegasus
  | -- | @BERT@ transformer style, see https://arxiv.org/abs/1810.04805
    BERT
  | -- | @RoBERTa@ transformer style, see https://arxiv.org/abs/1907.11692
    RoBERTa
  | -- | @GPT2@ transformer style, see https://openai.com/blog/better-language-models/
    GPT2
  deriving (Show, Eq)

-- Template Haskell splice: generate the singleton definitions for 'TransformerStyle'.
genSingletons [''TransformerStyle]

-- | The kind of head that is attached to a transformer.
data TransformerHead
  = -- | the transformer has no head
    WithoutHead
  | -- | the transformer has an LM head
    WithLMHead

-- Template Haskell splice: generate the singleton definitions for 'TransformerHead'.
genSingletons [''TransformerHead]

-- | Forces a list to a given length.
--
-- @padded n p xs@ returns exactly @n@ elements: the first @n@ elements of
-- @xs@ if there are that many, otherwise all of @xs@ followed by as many
-- copies of the padding element @p@ as are needed to reach length @n@.
-- A non-positive @n@ yields the empty list.
--
-- The input is traversed lazily and only once, so infinite lists are fine.
--
-- >>> padded 5 0 [1, 2, 3]
-- [1,2,3,0,0]
-- >>> padded 2 0 [1, 2, 3]
-- [1,2]
padded :: Integral n => n -> a -> [a] -> [a]
padded n p xs = take (fromIntegral n) (xs ++ repeat p)

-- | Converts a doubly-nested list of input ids to a batched input tensor.
-- The outer list is over batches, the inner list over sequences.
--
-- The batch size and the sequence length are not inferred from the lists.
-- They are read from the @batchDim@ and @seqDim@ singletons:
--
-- * every sequence is padded on the right with the padding token id,
--   or truncated, to the size of @seqDim@;
-- * the batch is padded with all-padding sequences, or truncated,
--   to the size of @batchDim@.
--
-- The resulting tensor is checked against the shape @[batchDim, seqDim]@;
-- a failed check is reported through 'MonadThrow'.
mkTransformerInput ::
  forall batchDim seqDim device m output.
  ( MonadThrow m,
    SGetDim batchDim,
    SGetDim seqDim,
    Catch
      ( 'Shape
          '[ 'Dim ('Name "*") 'UncheckedSize,
             'Dim ('Name "*") 'UncheckedSize
           ]
          <+> 'Shape '[batchDim, seqDim]
      ),
    output
      ~ Tensor
          ('Gradient 'WithoutGradient)
          ('Layout 'Dense)
          device
          ('DataType 'Int64)
          ('Shape '[batchDim, seqDim])
  ) =>
  -- | padding token id
  Int ->
  -- | batch dimension singleton
  SDim batchDim ->
  -- | sequence dimension singleton
  SDim seqDim ->
  -- | device for the tensor
  SDevice device ->
  -- | batch of input ids
  [[Int]] ->
  -- | input tensor
  m output
mkTransformerInput padTokenId batchDim seqDim device xs =
  sToTensor gradient layout device paddedXs
    >>= sCheckedShape (SShape $ batchDim :|: seqDim :|: SNil)
  where
    gradient = SGradient SWithoutGradient
    layout = SLayout SDense
    -- runtime sizes of the requested dimensions
    batchSize = forgetIsChecked . dimSize $ fromSing batchDim
    seqSize = forgetIsChecked . dimSize $ fromSing seqDim
    -- a sequence that consists of padding tokens only, used to fill up the batch
    emptySeq = replicate (fromIntegral seqSize) padTokenId
    -- first fix the length of every sequence, then the number of sequences
    paddedXs = padded batchSize emptySeq (padded seqSize padTokenId <$> xs)

-- | Constraints shared by 'mkPos' and the 'HasForward' instance of 'MkAbsPos'.
--
-- * the device and the shape of the input tensor can be recovered at runtime,
-- * @seqDim@ is the dimension selected by @shape ! 1@ and has the form
--   @'Dim seqName seqSize@,
-- * @output@ is a dense 'Int64' tensor without gradient on @device@ whose shape
--   is @['Dim ('Name "*") seqSize]@.
type MkPosC device shape seqDim seqName seqSize output =
  ( SGetDevice device,
    SGetShape shape,
    seqDim ~ (shape ! 1),
    seqDim ~ 'Dim seqName seqSize,
    output
      ~ Tensor
          ('Gradient 'WithoutGradient)
          ('Layout 'Dense)
          device
          ('DataType 'Int64)
          ('Shape '[ 'Dim ('Name "*") seqSize])
  )

-- | Computes absolute positions of the input tokens.
-- Given an input tensor of shape @[batchDim, Dim seqName seqSize]@,
-- returns a tensor of shape @[Dim "*" seqSize]@.
--
-- Only the device and the size of the dimension at index 1 of the input are
-- used; the values stored in the input tensor are never inspected.
-- The positions themselves are produced by 'sArangeNaturals'.
mkPos ::
  forall m gradient layout device dataType shape seqDim seqName seqSize output.
  ( MonadThrow m,
    MkPosC device shape seqDim seqName seqSize output
  ) =>
  -- | input tensor
  Tensor gradient layout device dataType shape ->
  -- | positions of the input tokens
  m output
mkPos input = do
  let device = sGetDevice input
      shape = sGetShape input
  -- the sequence dimension is the second dimension (index 1) of the input
  seqDim <- sGetDimFromShape (SSelectDim $ SByIndex @1) shape
  let seqSize = sDimSize seqDim
  sArangeNaturals
    (SGradient SWithoutGradient)
    (SLayout SDense)
    device
    (SDataType SInt64)
    seqSize

-- | Selects how absolute positions are derived from an input tensor.
-- The 'HasForward' instance of this type interprets the two constructors.
data MkAbsPos
  = -- | use the positions computed by 'mkPos' unchanged
    MkAbsPos
  | -- | add a constant offset to every position computed by 'mkPos'
    MkAbsPosWithOffset
      { -- | offset that is added to each position
        absPosOffset :: Int
      }
  deriving stock (Eq, Ord, Show, Generic)

-- | The specification of a 'MkAbsPos' model is the 'MkAbsPos' value itself.
type instance ModelSpec MkAbsPos = MkAbsPos

-- | Initializing a 'MkAbsPos' needs no randomness:
-- the specification is returned as the model and the generator is handed back unchanged.
instance HasInitialize MkAbsPos generatorDevice MkAbsPos generatorDevice where
  initialize spec g = pure (spec, g)

-- | 'MkAbsPos' does not read from or write to the state dictionary.
instance HasStateDict MkAbsPos where
  -- the model is fully described by its specification; the key is ignored
  fromStateDict spec _ = pure spec

  -- nothing is written to the state dictionary
  toStateDict _ _ = pure ()

-- | Running a 'MkAbsPos' on an input tensor yields the absolute positions
-- computed by 'mkPos', shifted by 'absPosOffset' in the case of
-- 'MkAbsPosWithOffset'. The generator is returned unchanged.
instance
  MkPosC device shape seqDim seqName seqSize output =>
  HasForward
    MkAbsPos
    (Tensor gradient layout device dataType shape)
    generatorDevice
    (Tensor ('Gradient 'WithoutGradient) ('Layout 'Dense) device ('DataType 'Int64) ('Shape '[ 'Dim ('Name "*") seqSize]))
    generatorDevice
  where
  forward MkAbsPos input g = do
    pos <- mkPos input
    pure (pos, g)
  forward MkAbsPosWithOffset {..} input g = do
    pos <- mkPos input
    -- 'absPosOffset' is brought into scope by the record wildcard
    pos' <- pos `addScalar` absPosOffset
    pure (pos', g)

-- | Computes relative positions of the input tokens to the encoder.
--
-- The result has one row per query position and one column per key position.
-- Each entry is the bucket of the signed distance @key - query@:
--
-- * one half of the @numBuckets@ buckets is used for keys at or before the
--   query, the other half (offset by @numBuckets \`div\` 2@) for keys after it;
-- * within a half, distances below @maxExact@ (a quarter of @numBuckets@)
--   get a bucket of their own;
-- * larger distances share buckets whose index grows logarithmically with the
--   distance and is capped at the last bucket of the half.
--
-- Note that @numBuckets@ must be at least 4, since @maxExact@ is zero
-- otherwise and the logarithmic branch divides by it.
--
-- >>> mkRelPos' 32 128 21 17
-- [[0,17,18,19,20,21,22,23,24,24,24,24,25,25,25,25,26],[1,0,17,18,19,20,21,22,23,24,24,24,24,25,25,25,25],[2,1,0,17,18,19,20,21,22,23,24,24,24,24,25,25,25],[3,2,1,0,17,18,19,20,21,22,23,24,24,24,24,25,25],[4,3,2,1,0,17,18,19,20,21,22,23,24,24,24,24,25],[5,4,3,2,1,0,17,18,19,20,21,22,23,24,24,24,24],[6,5,4,3,2,1,0,17,18,19,20,21,22,23,24,24,24],[7,6,5,4,3,2,1,0,17,18,19,20,21,22,23,24,24],[8,7,6,5,4,3,2,1,0,17,18,19,20,21,22,23,24],[8,8,7,6,5,4,3,2,1,0,17,18,19,20,21,22,23],[8,8,8,7,6,5,4,3,2,1,0,17,18,19,20,21,22],[8,8,8,8,7,6,5,4,3,2,1,0,17,18,19,20,21],[9,8,8,8,8,7,6,5,4,3,2,1,0,17,18,19,20],[9,9,8,8,8,8,7,6,5,4,3,2,1,0,17,18,19],[9,9,9,8,8,8,8,7,6,5,4,3,2,1,0,17,18],[9,9,9,9,8,8,8,8,7,6,5,4,3,2,1,0,17],[10,9,9,9,9,8,8,8,8,7,6,5,4,3,2,1,0],[10,10,9,9,9,9,8,8,8,8,7,6,5,4,3,2,1],[10,10,10,9,9,9,9,8,8,8,8,7,6,5,4,3,2],[10,10,10,10,9,9,9,9,8,8,8,8,7,6,5,4,3],[10,10,10,10,10,9,9,9,9,8,8,8,8,7,6,5,4]]
mkRelPos' :: Int -> Int -> Int -> Int -> [[Int]]
mkRelPos' numBuckets maxDistance querySize keySize =
  [[bucket (kp - qp) | kp <- keyPos] | qp <- queryPos]
  where
    queryPos = [0 .. querySize - 1]
    keyPos = [0 .. keySize - 1]
    -- number of buckets available per direction
    numBuckets' = numBuckets `div` 2
    -- distances below this threshold are mapped to themselves
    maxExact = numBuckets' `div` 2

    -- bucket of a signed relative position
    bucket :: Int -> Int
    bucket rawRelPos = relBucket + relBucket'
      where
        absRelPos = abs rawRelPos
        -- keys after the query use the upper half of the buckets
        relBucket = if rawRelPos > 0 then numBuckets' else 0
        -- logarithmically growing bucket index for large distances
        relPosIfLarge =
          maxExact
            + double2Int
              ( logBase
                  (fromIntegral maxDistance / fromIntegral maxExact)
                  (fromIntegral absRelPos / fromIntegral maxExact)
                  * fromIntegral (numBuckets' - maxExact)
              )
        relBucket'
          | absRelPos < maxExact = absRelPos
          | otherwise = min relPosIfLarge (numBuckets' - 1)

-- | Constraints required for computing relative positions from an input tensor.
--
-- * the device and the shape of the input tensor can be recovered at runtime,
-- * @seqDim@ is the dimension selected by @shape ! 1@ and has the form
--   @'Dim seqName seqSize@,
-- * two unchecked @"*"@ dimensions can be unified with two @"*"@ dimensions
--   of size @seqSize@,
-- * @output@ is a dense 'Int64' tensor without gradient on @device@ whose shape
--   is @['Dim "*" 1, 'Dim "*" seqSize, 'Dim "*" seqSize]@.
type MkRelPosC device shape seqDim seqName seqSize output =
  ( SGetDevice device,
    SGetShape shape,
    seqDim ~ (shape ! 1),
    seqDim ~ 'Dim seqName seqSize,
    Catch
      ( '[ 'Dim ('Name "*") 'UncheckedSize,
           'Dim ('Name "*") 'UncheckedSize
         ]
          <+> '[ 'Dim ('Name "*") seqSize, 'Dim ('Name "*") seqSize]
      ),
    output
      ~ Tensor
          ('Gradient 'WithoutGradient)
          ('Layout 'Dense)
          device
          ('DataType 'Int64)
          ('Shape '[ 'Dim ('Name "*") ('Size 1), 'Dim ('Name "*") seqSize, 'Dim ('Name "*") seqSize])
  )

-- | Computes relative positions of the input tokens to the encoder.
-- Given an input tensor of shape @[batchDim, Dim seqName seqSize]@,
-- returns a tensor of shape @[1, Dim "*" seqSize, Dim "*" seqSize]@.
--
-- The bucket indices are computed by 'mkRelPos'' with the sequence size used
-- for both the query and the key size. Only the device and the size of the
-- dimension at index 1 of the input are used; its values are never inspected.
mkRelPos ::
  forall m gradient layout device dataType shape relPosEncBucketDim seqDim seqName seqSize output.
  ( MonadThrow m,
    MkRelPosC device shape seqDim seqName seqSize output
  ) =>
  -- | bucket dimension
  SDim relPosEncBucketDim ->
  -- | maximum distance
  Int ->
  -- | input tensor
  Tensor gradient layout device dataType shape ->
  -- | relative positions of the input tokens
  m output
mkRelPos relPosEncBucketDim maxDistance input = do
  -- the sequence dimension is the second dimension (index 1) of the input
  seqDim <- sGetDimFromShape (SSelectDim $ SByIndex @1) $ sGetShape input
  let seqSize = fromInteger . forgetIsChecked . fromSing $ sDimSize seqDim
  -- the singleton list adds the leading dimension of size 1
  sToTensor gradient' layout' device [mkRelPos' relPosEncBucketSize maxDistance seqSize seqSize]
    >>= sCheckedShape (SShape $ SName @"*" :&: SSize @1 :|: SName @"*" :&: sDimSize seqDim :|: SName @"*" :&: sDimSize seqDim :|: SNil)
  where
    gradient' = SGradient SWithoutGradient
    layout' = SLayout SDense
    device = sGetDevice input
    -- runtime size of the bucket dimension
    relPosEncBucketSize = fromInteger . forgetIsChecked . dimSize . fromSing $ relPosEncBucketDim

-- | Computes relative positions of the input tokens to the decoder.
--
-- The result has one row per query position and one column per key position.
-- Keys after the query are all mapped to bucket 0. For keys at or before the
-- query the distance @query - key@ is bucketed:
--
-- * distances below @maxExact@ (half of @numBuckets@) get a bucket of their own;
-- * larger distances share buckets whose index grows logarithmically with the
--   distance and is capped at @numBuckets - 1@.
--
-- Note that @numBuckets@ must be at least 2, since @maxExact@ is zero
-- otherwise and the logarithmic branch divides by it.
--
-- >>> mkDecoderRelPos' 32 128 21 17
-- [[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],[1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],[2,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],[3,2,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0],[4,3,2,1,0,0,0,0,0,0,0,0,0,0,0,0,0],[5,4,3,2,1,0,0,0,0,0,0,0,0,0,0,0,0],[6,5,4,3,2,1,0,0,0,0,0,0,0,0,0,0,0],[7,6,5,4,3,2,1,0,0,0,0,0,0,0,0,0,0],[8,7,6,5,4,3,2,1,0,0,0,0,0,0,0,0,0],[9,8,7,6,5,4,3,2,1,0,0,0,0,0,0,0,0],[10,9,8,7,6,5,4,3,2,1,0,0,0,0,0,0,0],[11,10,9,8,7,6,5,4,3,2,1,0,0,0,0,0,0],[12,11,10,9,8,7,6,5,4,3,2,1,0,0,0,0,0],[13,12,11,10,9,8,7,6,5,4,3,2,1,0,0,0,0],[14,13,12,11,10,9,8,7,6,5,4,3,2,1,0,0,0],[15,14,13,12,11,10,9,8,7,6,5,4,3,2,1,0,0],[16,15,14,13,12,11,10,9,8,7,6,5,4,3,2,1,0],[16,16,15,14,13,12,11,10,9,8,7,6,5,4,3,2,1],[16,16,16,15,14,13,12,11,10,9,8,7,6,5,4,3,2],[17,16,16,16,15,14,13,12,11,10,9,8,7,6,5,4,3],[17,17,16,16,16,15,14,13,12,11,10,9,8,7,6,5,4]]
mkDecoderRelPos' :: Int -> Int -> Int -> Int -> [[Int]]
mkDecoderRelPos' numBuckets maxDistance querySize keySize =
  [[bucket (kp - qp) | kp <- keyPos] | qp <- queryPos]
  where
    queryPos = [0 .. querySize - 1]
    keyPos = [0 .. keySize - 1]
    -- distances below this threshold are mapped to themselves
    maxExact = numBuckets `div` 2

    -- bucket of a signed relative position
    bucket :: Int -> Int
    bucket rawRelPos
      | absRelPos < maxExact = absRelPos
      | otherwise = min relPosIfLarge (numBuckets - 1)
      where
        -- keys after the query (positive raw position) collapse to distance 0
        absRelPos = negate (min 0 rawRelPos)
        -- logarithmically growing bucket index for large distances
        relPosIfLarge =
          maxExact
            + double2Int
              ( logBase
                  (fromIntegral maxDistance / fromIntegral maxExact)
                  (fromIntegral absRelPos / fromIntegral maxExact)
                  * fromIntegral (numBuckets - maxExact)
              )

-- | Computes relative positions of the input tokens to the decoder.
-- Given an input tensor of shape @[batchDim, Dim seqName seqSize]@,
-- returns a tensor of shape @[1, Dim "*" seqSize, Dim "*" seqSize]@.
--
-- The sequence length is read from the dimension at index 1 of the input
-- shape. The bucket table is computed on the Haskell side by
-- 'mkDecoderRelPos'', converted to a dense tensor without gradient on the
-- input's device, and then checked against the expected shape.
mkDecoderRelPos ::
  forall m gradient layout device dataType shape relPosEncBucketDim seqDim seqName seqSize output.
  ( MonadThrow m,
    MkRelPosC device shape seqDim seqName seqSize output
  ) =>
  -- | bucket dimension
  SDim relPosEncBucketDim ->
  -- | maximum distance
  Int ->
  -- | decoder input tensor
  Tensor gradient layout device dataType shape ->
  -- | relative positions of the input tokens
  m output
mkDecoderRelPos relPosEncBucketDim maxDistance input = do
  -- the sequence dimension is the second dimension of the input
  seqDim <- sGetDimFromShape (SSelectDim $ SByIndex @1) $ sGetShape input
  let seqSize = fromInteger . forgetIsChecked . fromSing $ sDimSize seqDim
  -- queries and keys both range over the same decoder sequence
  sToTensor gradient' layout' device [mkDecoderRelPos' relPosEncBucketSize maxDistance seqSize seqSize]
    >>= sCheckedShape
      ( SShape $
          SName @"*" :&: SSize @1
            :|: SName @"*" :&: sDimSize seqDim
            :|: SName @"*" :&: sDimSize seqDim
            :|: SNil
      )
  where
    gradient' = SGradient SWithoutGradient
    layout' = SLayout SDense
    -- the result lives on the same device as the input
    device = sGetDevice input
    -- number of buckets, demoted from the bucket dimension singleton
    relPosEncBucketSize = fromInteger . forgetIsChecked . dimSize . fromSing $ relPosEncBucketDim

-- | Model that turns an input tensor into a tensor of relative position
-- buckets. The constructor selects which bucketing function the 'HasForward'
-- instance calls: 'MkRelPos' uses @mkRelPos@ and 'MkDecoderRelPos' uses
-- 'mkDecoderRelPos'. Both constructors carry the bucket dimension singleton
-- and the maximum distance passed on to that function.
data MkRelPos (relPosEncBucketDim :: Dim (Name Symbol) (Size Nat)) where
  MkRelPos ::
    forall relPosEncBucketDim.
    { -- bucket dimension
      relPosEncBucketDim :: SDim relPosEncBucketDim,
      -- maximum distance
      relPosMaxDistance :: Int
    } ->
    MkRelPos relPosEncBucketDim
  MkDecoderRelPos ::
    forall relPosEncBucketDim.
    { -- bucket dimension (decoder variant)
      decoderRelPosEncBucketDim :: SDim relPosEncBucketDim,
      -- maximum distance (decoder variant)
      decoderRelPosMaxDistance :: Int
    } ->
    MkRelPos relPosEncBucketDim
  deriving stock (Show, Generic)

-- | 'MkRelPos' is its own model specification.
type instance ModelSpec (MkRelPos relPosEncBucketDim) = MkRelPos relPosEncBucketDim

-- | Initializing a 'MkRelPos' returns the specification itself together with
-- the generator it was given, unchanged.
instance HasInitialize (MkRelPos relPosEncBucketDim) generatorDevice (MkRelPos relPosEncBucketDim) generatorDevice where
  initialize spec g = pure (spec, g)

-- | 'MkRelPos' reads nothing from and writes nothing to the state dictionary.
instance HasStateDict (MkRelPos relPosEncBucketDim) where
  -- the model is the specification; the key is ignored
  fromStateDict spec _ = pure spec

  -- nothing to store
  toStateDict _ _ = pure ()

-- | The forward pass computes the relative position buckets for the input
-- tensor and passes the generator through unchanged.
instance
  MkRelPosC device shape seqDim seqName seqSize output =>
  HasForward
    (MkRelPos relPosEncBucketDim)
    (Tensor gradient layout device dataType shape)
    generatorDevice
    ( Tensor
        ('Gradient 'WithoutGradient)
        ('Layout 'Dense)
        device
        ('DataType 'Int64)
        ('Shape '[ 'Dim ('Name "*") ('Size 1), 'Dim ('Name "*") seqSize, 'Dim ('Name "*") seqSize])
    )
    generatorDevice
  where
  -- 'MkRelPos' delegates to 'mkRelPos'
  forward MkRelPos {..} input g = do
    relPos <- mkRelPos relPosEncBucketDim relPosMaxDistance input
    pure (relPos, g)
  -- 'MkDecoderRelPos' delegates to 'mkDecoderRelPos'
  forward MkDecoderRelPos {..} input g = do
    decoderRelPos <- mkDecoderRelPos decoderRelPosEncBucketDim decoderRelPosMaxDistance input
    pure (decoderRelPos, g)

-- | Constraints required by 'mkTransformerPaddingMask'.
--
-- The device singleton must be obtainable from the input ('SGetDevice'),
-- the input data type unified with 'Int64' and the input shape broadcast
-- against the empty (scalar) shape must both satisfy 'Catch', and @output@ is
-- fixed to a boolean tensor without gradient whose shape is that broadcast.
type MkTransformerPaddingMaskC layout device dataType shape output =
  ( SGetDevice device,
    Catch (dataType <+> 'DataType 'Int64),
    Catch (BroadcastShapesF shape ('Shape '[])),
    output
      ~ Tensor
          ('Gradient 'WithoutGradient)
          (layout <+> 'Layout 'Dense)
          device
          ('DataType 'Bool)
          (BroadcastShapesF shape ('Shape '[]))
  )

-- | Computes the padding mask for a transformer.
-- Given an input tensor of shape @[batchDim, Dim seqName seqSize]@,
-- returns a tensor of shape @[batchDim, Dim "*" seqSize]@.
--
-- The padding token id is turned into a scalar tensor on the input's device
-- and compared element-wise with the input using '==.'.
mkTransformerPaddingMask ::
  forall m gradient layout device dataType shape output.
  ( MonadThrow m,
    MkTransformerPaddingMaskC layout device dataType shape output
  ) =>
  -- | padding token id
  Int ->
  -- | input tensor
  Tensor gradient layout device dataType shape ->
  -- | padding mask
  m output
mkTransformerPaddingMask padTokenId input = do
  -- build the comparison scalar on the same device as the input
  let device = sGetDevice input
  padToken <-
    sToTensor
      (SGradient SWithoutGradient)
      (SLayout SDense)
      device
      padTokenId
  input ==. padToken

-- | Model that computes a padding mask. It wraps the id of the padding token
-- that the 'HasForward' instance hands to 'mkTransformerPaddingMask'.
newtype MkTransformerPaddingMask = MkTransformerPaddingMask {padTokenId :: Int}
  deriving stock (Show, Generic)

-- | 'MkTransformerPaddingMask' is its own model specification.
type instance
  ModelSpec MkTransformerPaddingMask =
    MkTransformerPaddingMask

-- | Initializing a 'MkTransformerPaddingMask' returns the specification itself
-- together with the generator it was given, unchanged.
instance
  HasInitialize
    MkTransformerPaddingMask
    generatorDevice
    MkTransformerPaddingMask
    generatorDevice
  where
  initialize spec g = pure (spec, g)

-- | 'MkTransformerPaddingMask' reads nothing from and writes nothing to the
-- state dictionary.
instance HasStateDict MkTransformerPaddingMask where
  -- the model is the specification; the key is ignored
  fromStateDict spec _ = pure spec

  -- nothing to store
  toStateDict _ _ = pure ()

-- | The forward pass computes the padding mask of the input tensor with
-- 'mkTransformerPaddingMask' and passes the generator through unchanged.
instance
  MkTransformerPaddingMaskC layout device dataType shape output =>
  HasForward
    MkTransformerPaddingMask
    (Tensor gradient layout device dataType shape)
    generatorDevice
    output
    generatorDevice
  where
  forward MkTransformerPaddingMask {..} input g = do
    paddingMask <- mkTransformerPaddingMask padTokenId input
    pure (paddingMask, g)

-- | Constraints required by 'mkTransformerAttentionMask'.
--
-- Layout, device, and shape singletons must be obtainable from the padding
-- mask; @seqDim@ is the dimension at index 1 of its shape. The unified
-- gradient and data types, the shape unsqueezed at index 1, and that shape
-- broadcast against @[1, seqDim, seqDim]@ must all satisfy 'Catch'.
-- @output@ is fixed to a tensor without gradient of the transformer's data
-- type whose shape is that broadcast.
type MkTransformerAttentionMaskC transformerDataType gradient layout device dataType shape seqDim output =
  ( SGetLayout layout,
    SGetDevice device,
    SGetShape shape,
    seqDim ~ (shape ! 1),
    Catch (gradient <+> 'Gradient 'WithoutGradient),
    Catch (dataType <+> 'DataType 'Bool),
    Catch (UnsqueezeF ('SelectDim ('ByIndex 1)) shape),
    Catch
      ( BroadcastShapesF
          (UnsqueezeF ('SelectDim ('ByIndex 1)) shape)
          ('Shape '[ 'Dim ('Name "*") ('Size 1), seqDim, seqDim])
      ),
    output
      ~ Tensor
          ('Gradient 'WithoutGradient)
          (layout <+> 'Layout 'Dense)
          device
          transformerDataType
          ( BroadcastShapesF
              (UnsqueezeF ('SelectDim ('ByIndex 1)) shape)
              ('Shape '[ 'Dim ('Name "*") ('Size 1), seqDim, seqDim])
          )
  )

-- | Creates a bidirectional attention mask for a transformer.
-- Given a padding mask of shape @[batchDim, seqDim]@,
-- returns a tensor of shape @[batchDim, seqDim, seqDim]@.
--
-- A zero tensor of shape @[1, seqDim, seqDim]@ is created with the padding
-- mask's layout and device and the transformer's data type. The padding mask
-- is unsqueezed at index 1 and used with 'maskedFill' to write the bias into
-- the zero tensor.
mkTransformerAttentionMask ::
  forall m transformerDataType gradient layout device dataType shape seqDim output.
  ( MonadThrow m,
    MkTransformerAttentionMaskC transformerDataType gradient layout device dataType shape seqDim output
  ) =>
  -- | data type singleton of the transformer
  SDataType transformerDataType ->
  -- | attention mask bias (typically a large negative number)
  Double ->
  -- | encoder padding mask
  Tensor gradient layout device dataType shape ->
  m output
mkTransformerAttentionMask transformerDataType attentionMaskBias paddingMask = do
  let pmLayout = sGetLayout paddingMask
      pmDevice = sGetDevice paddingMask
      pmShape = sGetShape paddingMask
  -- the sequence dimension is the second dimension of the padding mask
  pmSeqDim <- sGetDimFromShape (SSelectDim $ SByIndex @1) pmShape
  -- all-zero mask that the bias is written into
  emptyMask <-
    sZeros $
      TensorSpec
        (SGradient SWithoutGradient)
        pmLayout
        pmDevice
        transformerDataType
        (SShape $ SName @"*" :&: SSize @1 :|: pmSeqDim :|: pmSeqDim :|: SNil)
  -- insert a dimension at index 1 so the mask broadcasts against 'emptyMask'
  paddingMask' <- unsqueeze @('SelectDim ('ByIndex 1)) paddingMask
  maskedFill paddingMask' attentionMaskBias emptyMask

-- | Model that creates an attention mask. It carries the data type singleton
-- for the resulting mask and the bias value.
data MkTransformerAttentionMask (dataType :: DataType DType) where
  MkTransformerAttentionMask ::
    forall dataType.
    { -- data type singleton of the mask
      attentionMaskDataType :: SDataType dataType,
      -- attention mask bias
      attentionMaskBias :: Double
    } ->
    MkTransformerAttentionMask dataType
  deriving stock (Show, Generic)

-- The model specification of 'MkTransformerAttentionMask' is the type itself.
type instance ModelSpec (MkTransformerAttentionMask dataType)
  = MkTransformerAttentionMask dataType

-- | Initialization returns the specification unchanged and passes the
-- generator through untouched.
instance
  HasInitialize
    (MkTransformerAttentionMask dataType)
    generatorDevice
    (MkTransformerAttentionMask dataType)
    generatorDevice
  where
  initialize spec g = pure (spec, g)

-- | The state-dictionary key is ignored in both directions: loading returns
-- the specification as is, and saving is a no-op.
instance HasStateDict (MkTransformerAttentionMask dataType) where
  fromStateDict spec _ = pure spec
  toStateDict _ _ = pure ()

-- | The forward pass delegates to 'mkTransformerAttentionMask' using the
-- data type and bias stored in the record. The generator is returned
-- unchanged.
instance
  MkTransformerAttentionMaskC dataType inputGradient inputLayout inputDevice inputDataType inputShape seqDim output =>
  HasForward
    (MkTransformerAttentionMask dataType)
    (Tensor inputGradient inputLayout inputDevice inputDataType inputShape)
    generatorDevice
    output
    generatorDevice
  where
  forward MkTransformerAttentionMask {..} input g = do
    attentionMask <- mkTransformerAttentionMask attentionMaskDataType attentionMaskBias input
    pure (attentionMask, g)

-- | Constraint bundle used by 'mkTransformerDecoderAttentionMask'.
--
-- It requires 'SGetLayout', 'SGetDevice', and 'SGetShape' for the padding
-- mask's layout, device, and shape, fixes @seqDim@ to the dimension at
-- index 1 of @shape@, and asks for 'Catch' on the intermediate shapes
-- (the padding-mask shape unsqueezed at index 1, and its broadcasts against
-- the @[1, seqDim, seqDim]@ mask shape). The @output@ type is pinned to a
-- tensor without gradient, of data type @transformerDataType@, whose shape
-- is the nested broadcast listed last.
type MkTransformerDecoderAttentionMaskC transformerDataType layout device shape seqDim output =
  ( SGetLayout layout,
    SGetDevice device,
    SGetShape shape,
    seqDim ~ (shape ! 1),
    Catch seqDim,
    Catch (UnsqueezeF ('SelectDim ('ByIndex 1)) shape),
    Catch
      ( BroadcastShapesF
          ('Shape '[ 'Dim ('Name "*") ('Size 1), seqDim, seqDim])
          (UnsqueezeF ('SelectDim ('ByIndex 1)) shape)
      ),
    Catch
      ( BroadcastShapesF
          ( BroadcastShapesF
              ('Shape '[ 'Dim ('Name "*") ('Size 1), seqDim, seqDim])
              (UnsqueezeF ('SelectDim ('ByIndex 1)) shape)
          )
          ('Shape '[ 'Dim ('Name "*") ('Size 1), seqDim, seqDim])
      ),
    output
      ~ Tensor
          ('Gradient 'WithoutGradient)
          (layout <+> 'Layout 'Dense)
          device
          transformerDataType
          ( BroadcastShapesF
              ( BroadcastShapesF
                  ('Shape '[ 'Dim ('Name "*") ('Size 1), seqDim, seqDim])
                  (UnsqueezeF ('SelectDim ('ByIndex 1)) shape)
              )
              ('Shape '[ 'Dim ('Name "*") ('Size 1), seqDim, seqDim])
          )
  )

-- | Creates a causal attention mask for a transformer decoder.
-- Given a padding mask of shape @[batchDim, seqDim]@,
-- returns a tensor of shape @[batchDim, seqDim, seqDim]@.
--
-- The mask is assembled in these steps:
--
-- 1. The layout, device, and shape singletons are read off the padding mask,
--    and the dimension at index 1 of its shape is taken as the sequence
--    dimension.
-- 2. A @[seqDim, seqDim]@ tensor of ones is passed through @triu 1@ and
--    converted with @bool@, then unsqueezed at index 0.
-- 3. The padding mask is unsqueezed at index 1 and combined with the result
--    of step 2 via 'logicalOr'.
-- 4. 'maskedFill' is called with that boolean mask, the bias, and an
--    all-zeros @[1, seqDim, seqDim]@ tensor of the transformer data type
--    (presumably writing the bias wherever the boolean mask is set).
--
-- Any failure of the intermediate operations is reported through
-- 'MonadThrow'.
mkTransformerDecoderAttentionMask ::
  forall m transformerDataType gradient layout device dataType shape seqDim output.
  ( MonadThrow m,
    MkTransformerDecoderAttentionMaskC transformerDataType layout device shape seqDim output
  ) =>
  -- | data type singleton of the transformer
  SDataType transformerDataType ->
  -- | attention mask bias (typically a large negative number)
  Double ->
  -- | decoder padding mask
  Tensor gradient layout device dataType shape ->
  m output
mkTransformerDecoderAttentionMask transformerDataType attentionMaskBias paddingMask = do
  let pmLayout = sGetLayout paddingMask
      pmDevice = sGetDevice paddingMask
      pmShape = sGetShape paddingMask
  -- The sequence dimension is the dimension at index 1 of the padding mask.
  pmSeqDim <- sGetDimFromShape (SSelectDim $ SByIndex @1) pmShape
  -- Square mask derived from the upper triangle (diagonal offset 1) of ones.
  causalMask <-
    bool . triu 1
      =<< sOnes
        ( TensorSpec
            (SGradient SWithoutGradient)
            pmLayout
            pmDevice
            transformerDataType
            (SShape $ pmSeqDim :|: pmSeqDim :|: SNil)
        )
  -- Add a leading dimension so the causal mask can broadcast with the
  -- unsqueezed padding mask.
  causalMask' <- unsqueeze @('SelectDim ('ByIndex 0)) causalMask
  -- All-zeros tensor that receives the bias values.
  emptyMask <-
    sZeros $
      TensorSpec
        (SGradient SWithoutGradient)
        pmLayout
        pmDevice
        transformerDataType
        (SShape $ SName @"*" :&: SSize @1 :|: pmSeqDim :|: pmSeqDim :|: SNil)
  paddingMask' <- unsqueeze @('SelectDim ('ByIndex 1)) paddingMask
  booleanMask <- causalMask' `logicalOr` paddingMask'
  maskedFill
    booleanMask
    attentionMaskBias
    emptyMask

-- | Configuration record for building a transformer decoder attention mask.
--
-- The 'HasForward' instance of this type hands both fields to
-- 'mkTransformerDecoderAttentionMask'.
data MkTransformerDecoderAttentionMask (dataType :: DataType DType) where
  MkTransformerDecoderAttentionMask ::
    forall dataType.
    { -- | data type singleton forwarded to 'mkTransformerDecoderAttentionMask'
      decoderAttentionMaskDataType :: SDataType dataType,
      -- | attention mask bias forwarded to 'mkTransformerDecoderAttentionMask'
      decoderAttentionMaskBias :: Double
    } ->
    MkTransformerDecoderAttentionMask dataType
  deriving stock (Show, Generic)

-- The model specification of 'MkTransformerDecoderAttentionMask' is the type itself.
type instance ModelSpec (MkTransformerDecoderAttentionMask dataType)
  = MkTransformerDecoderAttentionMask dataType

-- | Initialization returns the specification unchanged and passes the
-- generator through untouched.
instance
  HasInitialize
    (MkTransformerDecoderAttentionMask dataType)
    generatorDevice
    (MkTransformerDecoderAttentionMask dataType)
    generatorDevice
  where
  initialize spec g = pure (spec, g)

-- | The state-dictionary key is ignored in both directions: loading returns
-- the specification as is, and saving is a no-op.
instance HasStateDict (MkTransformerDecoderAttentionMask dataType) where
  fromStateDict spec _ = pure spec
  toStateDict _ _ = pure ()

-- | The forward pass delegates to 'mkTransformerDecoderAttentionMask' using
-- the data type and bias stored in the record. The generator is returned
-- unchanged.
instance
  MkTransformerDecoderAttentionMaskC dataType decoderInputLayout decoderInputDevice decoderInputShape seqDim output =>
  HasForward
    (MkTransformerDecoderAttentionMask dataType)
    (Tensor decoderInputGradient decoderInputLayout decoderInputDevice decoderInputDataType decoderInputShape)
    generatorDevice
    output
    generatorDevice
  where
  forward MkTransformerDecoderAttentionMask {..} input g = do
    decoderAttentionMask <- mkTransformerDecoderAttentionMask decoderAttentionMaskDataType decoderAttentionMaskBias input
    pure (decoderAttentionMask, g)

-- | Constraint bundle used by 'mkTransformerCrossAttentionMask'.
--
-- It requires 'SGetLayout', 'SGetDevice', and 'SGetShape' for the encoder
-- padding mask, and 'SGetShape' for the decoder input shape. @seqDim@ and
-- @decoderInputSeqDim@ are fixed to the dimensions at index 1 of @shape@ and
-- @decoderInputShape@, respectively. 'Catch' is requested for the
-- gradient and data type unifications, for the padding-mask shape unsqueezed
-- at index 1, and for its broadcast against the
-- @[1, decoderInputSeqDim, seqDim]@ shape. The @output@ type is pinned to a
-- tensor without gradient, of data type @transformerDataType@, with that
-- broadcast shape.
type MkTransformerCrossAttentionMaskC transformerDataType decoderInputShape decoderInputSeqDim gradient layout device dataType shape seqDim output =
  ( SGetLayout layout,
    SGetDevice device,
    SGetShape shape,
    seqDim ~ (shape ! 1),
    SGetShape decoderInputShape,
    decoderInputSeqDim ~ (decoderInputShape ! 1),
    Catch (gradient <+> 'Gradient 'WithoutGradient),
    Catch (dataType <+> 'DataType 'Bool),
    Catch (UnsqueezeF ('SelectDim ('ByIndex 1)) shape),
    Catch
      ( BroadcastShapesF
          (UnsqueezeF ('SelectDim ('ByIndex 1)) shape)
          ( 'Shape
              '[ 'Dim ('Name "*") ('Size 1), decoderInputSeqDim, seqDim]
          )
      ),
    output
      ~ Tensor
          ('Gradient 'WithoutGradient)
          (layout <+> 'Layout 'Dense)
          device
          transformerDataType
          ( BroadcastShapesF
              (UnsqueezeF ('SelectDim ('ByIndex 1)) shape)
              ('Shape '[ 'Dim ('Name "*") ('Size 1), decoderInputSeqDim, seqDim])
          )
  )

-- | Creates a cross-attention mask for an encoder-decoder transformer.
-- Given an encoder padding mask of shape @[batchDim, seqDim]@,
-- and the shape @[batchDim, decoderSeqDim]@ of the decoder's input,
-- returns a tensor of shape @[batchDim, decoderSeqDim, seqDim]@.
--
-- The mask is assembled in these steps:
--
-- 1. The dimension at index 1 is read from the decoder input shape and
--    from the padding mask's shape.
-- 2. An all-zeros @[1, decoderSeqDim, seqDim]@ tensor of the transformer
--    data type is created with the padding mask's layout and device.
-- 3. The padding mask is unsqueezed at index 1 and handed to 'maskedFill'
--    together with the bias and the zeros tensor (presumably writing the
--    bias wherever the padding mask is set).
--
-- Any failure of the intermediate operations is reported through
-- 'MonadThrow'.
mkTransformerCrossAttentionMask ::
  forall m transformerDataType decoderInputShape decoderInputSeqDim gradient layout device dataType shape seqDim output.
  ( MonadThrow m,
    MkTransformerCrossAttentionMaskC transformerDataType decoderInputShape decoderInputSeqDim gradient layout device dataType shape seqDim output
  ) =>
  -- | data type singleton of the transformer
  SDataType transformerDataType ->
  -- | decoder input shape
  SShape decoderInputShape ->
  -- | attention mask bias (typically a large negative number)
  Double ->
  -- | encoder padding mask
  Tensor gradient layout device dataType shape ->
  m output
mkTransformerCrossAttentionMask transformerDataType decoderInputShape attentionMaskBias paddingMask = do
  -- Sequence dimension of the decoder input (index 1 of its shape).
  decoderInputSeqDim <- sGetDimFromShape (SSelectDim $ SByIndex @1) decoderInputShape
  let pmLayout = sGetLayout paddingMask
      pmDevice = sGetDevice paddingMask
      pmShape = sGetShape paddingMask
  -- Sequence dimension of the encoder padding mask (index 1 of its shape).
  pmSeqDim <- sGetDimFromShape (SSelectDim $ SByIndex @1) pmShape
  -- All-zeros tensor that receives the bias values.
  emptyMask <-
    sZeros $
      TensorSpec
        (SGradient SWithoutGradient)
        pmLayout
        pmDevice
        transformerDataType
        (SShape $ SName @"*" :&: SSize @1 :|: decoderInputSeqDim :|: pmSeqDim :|: SNil)
  paddingMask' <- unsqueeze @('SelectDim ('ByIndex 1)) paddingMask
  maskedFill paddingMask' attentionMaskBias emptyMask

-- | Configuration record for building a transformer cross-attention mask.
data MkTransformerCrossAttentionMask (dataType :: DataType DType) where
  MkTransformerCrossAttentionMask ::
    forall dataType.
    { -- | data type singleton of the transformer
      crossAttentionMaskDataType :: SDataType dataType,
      -- | attention mask bias
      crossAttentionMaskBias :: Double
    } ->
    MkTransformerCrossAttentionMask dataType
  deriving stock (Show, Generic)

-- | 'MkTransformerCrossAttentionMask' is its own model specification.
type instance ModelSpec (MkTransformerCrossAttentionMask dataType) = MkTransformerCrossAttentionMask dataType

-- | Initialization is trivial: the specification already is the model, so it
-- is returned unchanged together with the untouched generator.
instance
  HasInitialize
    (MkTransformerCrossAttentionMask dataType)
    generatorDevice
    (MkTransformerCrossAttentionMask dataType)
    generatorDevice
  where
  initialize spec g = pure (spec, g)

-- | 'MkTransformerCrossAttentionMask' does not read from or write to the
-- state dictionary.
instance HasStateDict (MkTransformerCrossAttentionMask dataType) where
  -- The state-dict key is ignored; the specification is returned as the model.
  fromStateDict spec _ = pure spec

  -- Nothing is stored in the state dictionary.
  toStateDict _ _ = pure ()

-- | Forward pass that builds a cross-attention mask.
--
-- The input is a pair of the decoder input tensor and the input padding mask.
-- Only the shape of the decoder input is used. That shape, the input padding
-- mask, and the two record fields of 'MkTransformerCrossAttentionMask' are
-- passed to @mkTransformerCrossAttentionMask@. The generator is returned
-- unchanged.
instance
  MkTransformerCrossAttentionMaskC dataType decoderInputShape decoderInputSeqDim inputPaddingMaskGradient inputPaddingMaskLayout inputPaddingMaskDevice inputPaddingMaskDataType inputPaddingMaskShape seqDim output =>
  HasForward
    (MkTransformerCrossAttentionMask dataType)
    ( Tensor decoderInputGradient decoderInputLayout decoderInputDevice decoderInputDataType decoderInputShape,
      Tensor inputPaddingMaskGradient inputPaddingMaskLayout inputPaddingMaskDevice inputPaddingMaskDataType inputPaddingMaskShape
    )
    generatorDevice
    output
    generatorDevice
  where
  forward MkTransformerCrossAttentionMask {..} (decoderInput, inputPaddingMask) g = do
    -- Only the singleton shape of the decoder input is needed, not its values.
    let decoderInputShape = sGetShape decoderInput
    crossAttentionMask <-
      mkTransformerCrossAttentionMask
        crossAttentionMaskDataType
        decoderInputShape
        crossAttentionMaskBias
        inputPaddingMask
    pure (crossAttentionMask, g)

-- | Model that prepends a fill value to its input along the dimension with
-- index 1 (see the 'HasForward' instance).
data ShiftRight fillValue where
  ShiftRight ::
    forall fillValue.
    -- | fill value for shift right
    fillValue ->
    ShiftRight fillValue
  deriving stock (Eq, Ord, Show, Generic)

-- | 'ShiftRight' is its own model specification.
type instance
  ModelSpec (ShiftRight fillValue) =
    ShiftRight fillValue

-- | Initialization is trivial: the specification already is the model, so it
-- is paired with the untouched generator and returned.
instance
  HasInitialize
    (ShiftRight fillValue)
    generatorDevice
    (ShiftRight fillValue)
    generatorDevice
  where
  -- The tuple section @(spec,)@ receives the generator as its second component.
  initialize spec = pure . (spec,)

-- | 'ShiftRight' does not read from or write to the state dictionary.
instance HasStateDict (ShiftRight fillValue) where
  -- The state-dict key is ignored; the specification is returned as the model.
  fromStateDict spec _ = pure spec

  -- Nothing is stored in the state dictionary.
  toStateDict _ _ = pure ()

-- | Forward pass of 'ShiftRight'.
--
-- A filler tensor with shape @[inputBatchDim, 1]@ is created, holding the fill
-- value and sharing the input's layout, device, and data type. It is
-- concatenated in front of the input along the dimension with index 1.
-- Nothing is removed from the input, so, per the result type, the size of
-- that dimension grows by one. The generator is returned unchanged.
instance
  ( input
      ~ Tensor
          inputGradient
          inputLayout
          inputDevice
          inputDataType
          inputShape,
    SGetLayout inputLayout,
    SGetDevice inputDevice,
    SGetDataType inputDataType,
    SGetShape inputShape,
    inputBatchDim ~ (inputShape ! 0),
    inputSeqDim ~ (inputShape ! 1),
    Scalar fillValue,
    rightShiftedInput
      ~ Tensor
          (inputGradient <|> 'Gradient 'WithoutGradient)
          inputLayout
          inputDevice
          inputDataType
          ( ReplaceDimF
              ('SelectDim ('ByIndex 1))
              (inputShape <+> 'Shape '[inputBatchDim, inputSeqDim])
              (AddDimF inputSeqDim ('Dim ('Name "*") ('Size 1)))
          )
  ) =>
  HasForward (ShiftRight fillValue) input generator rightShiftedInput generator
  where
  forward (ShiftRight fillValue) input g = do
    -- Recover the singleton layout, device, data type, and shape of the input
    -- so that the filler tensor can be created to match.
    let inputLayout = sGetLayout input
        inputDevice = sGetDevice input
        inputDataType = sGetDataType input
        inputShape = sGetShape input
    -- The first dimension (index 0) is reused as the filler's first dimension.
    inputBatchDim <- sGetDimFromShape (SSelectDim $ SByIndex @0) inputShape
    -- Filler of shape [inputBatchDim, 1], without gradient, filled with the
    -- fill value.
    filler <-
      sFull
        ( TensorSpec
            (SGradient SWithoutGradient)
            inputLayout
            inputDevice
            inputDataType
            (SShape $ inputBatchDim :|: SName @"*" :&: SSize @1 :|: SNil)
        )
        fillValue
    -- The filler comes first, so the fill value ends up at position 0 of
    -- dimension 1.
    shifted <- cat @('SelectDim ('ByIndex 1)) (filler :. input :. HNil)
    pure (shifted, g)