{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -fplugin TypeLevel.Rewrite
-fplugin-opt=TypeLevel.Rewrite:Torch.GraduallyTyped.Unify.UnifyRightAssociativeL
-fplugin-opt=TypeLevel.Rewrite:Torch.GraduallyTyped.Unify.UnifyIdempotenceL2 #-}
module Torch.GraduallyTyped.NN.Transformer.Type where
import Control.Monad.Catch (MonadThrow)
import Data.Singletons.TH (SingKind (fromSing), genSingletons)
import GHC.Float (double2Int)
import GHC.Generics (Generic)
import GHC.TypeLits (Nat, Symbol)
import Torch.GraduallyTyped.DType (DType (..), DataType (..), SDType (..), SDataType (..))
import Torch.GraduallyTyped.Device (SDevice (..))
import Torch.GraduallyTyped.Layout (Layout (..), LayoutType (..), SLayout (..), SLayoutType (..))
import Torch.GraduallyTyped.NN.Class (HasForward (..), HasInitialize (..), HasStateDict (..), ModelSpec)
import Torch.GraduallyTyped.Prelude (Catch, forgetIsChecked, pattern (:|:))
import Torch.GraduallyTyped.Prelude.List (SList (SNil))
import Torch.GraduallyTyped.RequiresGradient (Gradient (..), RequiresGradient (..), SGradient (..), SRequiresGradient (..))
import Torch.GraduallyTyped.Scalar (Scalar)
import Torch.GraduallyTyped.Shape.Class (AddDimF, BroadcastShapesF, ReplaceDimF, sGetDimFromShape, type (!))
import Torch.GraduallyTyped.Shape.Type (By (..), Dim (..), Name (..), SBy (..), SDim (sDimSize), SName (..), SSelectDim (..), SShape (..), SSize (..), SelectDim (..), Shape (..), Size (..), pattern (:&:))
import Torch.GraduallyTyped.Tensor.Creation (sArangeNaturals, sFull, sOnes, sZeros)
import Torch.GraduallyTyped.Tensor.IndexingSlicingJoining (UnsqueezeF, cat, unsqueeze)
import Torch.GraduallyTyped.Tensor.MathOperations.Comparison ((==.))
import Torch.GraduallyTyped.Tensor.MathOperations.Pointwise (addScalar, logicalOr)
import Torch.GraduallyTyped.Tensor.Other (maskedFill, triu)
import Torch.GraduallyTyped.Tensor.Type (SGetDataType (sGetDataType), SGetDevice (..), SGetDim, SGetLayout (..), SGetShape (..), Tensor (..), TensorLike (sToTensor), TensorSpec (..), bool, sCheckedShape)
import Torch.GraduallyTyped.Unify (type (<+>), type (<|>))
import Torch.HList
-- | The transformer model families this module distinguishes between.
--
-- Singleton definitions for this type are generated with Template Haskell
-- so that it can also be used at the type level.
data TransformerStyle
  = T5
  | ByT5
  | BART
  | MBART
  | Pegasus
  | BERT
  | RoBERTa
  | GPT2
  deriving stock (Show, Eq)
-- Template Haskell: generate the singletons-library definitions for
-- 'TransformerStyle' (must come after the data declaration above it).
$(genSingletons [''TransformerStyle])
-- | Which output head a transformer model carries: none ('WithoutHead') or
-- an LM head ('WithLMHead').
data TransformerHead
  = WithoutHead
  | WithLMHead

-- Template Haskell: generate the singletons-library definitions for
-- 'TransformerHead'.
$(genSingletons [''TransformerHead])
-- | @padded n p xs@ fits @xs@ to exactly @n@ elements.
--
-- A list longer than @n@ is cut down to its first @n@ elements. A shorter
-- list is right-padded with copies of @p@. For @n <= 0@ the result is empty.
--
-- >>> padded (5 :: Int) 0 [1, 2, 3]
-- [1,2,3,0,0]
padded :: Integral n => n -> a -> [a] -> [a]
padded n p xs = take len xs <> replicate (len - length xs) p
  where
    -- target length as an 'Int', the type 'take' and 'replicate' expect
    len :: Int
    len = fromIntegral n
-- | Build a dense, gradient-free @Int64@ tensor of shape
-- @[batchDim, seqDim]@ from a ragged list of token-id sequences.
--
-- Each sequence is truncated, or right-padded with the padding token, to the
-- size of @seqDim@. The list of sequences is then truncated, or padded with
-- rows made entirely of padding tokens, to the size of @batchDim@.
--
-- The tensor conversion and the final shape check ('sCheckedShape') run in
-- @m@ and may throw through 'MonadThrow'.
mkTransformerInput ::
  forall batchDim seqDim device m output.
  ( MonadThrow m,
    SGetDim batchDim,
    SGetDim seqDim,
    Catch
      ( 'Shape
          '[ 'Dim ('Name "*") 'UncheckedSize,
             'Dim ('Name "*") 'UncheckedSize
           ]
          <+> 'Shape '[batchDim, seqDim]
      ),
    output
      ~ Tensor
          ('Gradient 'WithoutGradient)
          ('Layout 'Dense)
          device
          ('DataType 'Int64)
          ('Shape '[batchDim, seqDim])
  ) =>
  -- | padding token id
  Int ->
  -- | batch dimension
  SDim batchDim ->
  -- | sequence dimension
  SDim seqDim ->
  -- | device the tensor is created on
  SDevice device ->
  -- | token-id sequences, one per batch entry
  [[Int]] ->
  m output
mkTransformerInput padTokenId batchDim seqDim device xs = do
  unchecked <- sToTensor noGradient denseLayout device rows
  sCheckedShape (SShape $ batchDim :|: seqDim :|: SNil) unchecked
  where
    noGradient = SGradient SWithoutGradient
    denseLayout = SLayout SDense
    -- sizes read back from the dimension singletons
    batchSize = forgetIsChecked (dimSize (fromSing batchDim))
    seqSize = forgetIsChecked (dimSize (fromSing seqDim))
    -- filler row used when there are fewer sequences than the batch size
    padRow = replicate (fromIntegral seqSize) padTokenId
    rows = padded batchSize padRow (map (padded seqSize padTokenId) xs)
-- | Constraints used by 'mkPos' and by the 'MkAbsPos' forward pass.
--
-- The input's dimension at index 1 (@seqDim@, with name @seqName@ and size
-- @seqSize@) determines @output@. @output@ is a gradient-free, dense
-- @Int64@ tensor of shape @[seqSize]@ on the input's device.
type MkPosC device shape seqDim seqName seqSize output =
  ( SGetDevice device
  , SGetShape shape
  , seqDim ~ (shape ! 1)
  , seqDim ~ 'Dim seqName seqSize
  , output
      ~ Tensor
          ('Gradient 'WithoutGradient)
          ('Layout 'Dense)
          device
          ('DataType 'Int64)
          ('Shape '[ 'Dim ('Name "*") seqSize])
  )
-- | Create a position-index tensor for an input tensor.
--
-- The length is the size of the input's dimension at index 1. The values
-- come from 'sArangeNaturals', presumably @0 .. seqSize - 1@.
--
-- The result is a dense, gradient-free @Int64@ tensor on the input's
-- device. Looking up the dimension runs in @m@ and may throw.
mkPos ::
  forall m gradient layout device dataType shape seqDim seqName seqSize output.
  ( MonadThrow m,
    MkPosC device shape seqDim seqName seqSize output
  ) =>
  Tensor gradient layout device dataType shape ->
  m output
mkPos input = do
  seqDim <- sGetDimFromShape (SSelectDim (SByIndex @1)) (sGetShape input)
  sArangeNaturals
    (SGradient SWithoutGradient)
    (SLayout SDense)
    (sGetDevice input)
    (SDataType SInt64)
    (sDimSize seqDim)
-- | Absolute position-id layer, which also serves as its own model spec.
--
-- * 'MkAbsPos' returns the positions computed by 'mkPos' unchanged.
-- * 'MkAbsPosWithOffset' adds 'absPosOffset' to every position.
--
-- Note: 'absPosOffset' is a partial record selector. It is only defined for
-- 'MkAbsPosWithOffset'.
data MkAbsPos
  = MkAbsPos
  | MkAbsPosWithOffset
      { -- | constant added to each position id
        absPosOffset :: Int
      }
  deriving stock (Eq, Ord, Show, Generic)

-- The layer has no parameters, so its spec and its model are the same value.
type instance ModelSpec MkAbsPos = MkAbsPos
-- | Initialisation returns the spec itself as the model and passes the
-- generator through untouched.
instance HasInitialize MkAbsPos generatorDevice MkAbsPos generatorDevice where
  initialize spec = pure . (spec,)
-- | 'MkAbsPos' carries no tensors. Loading ignores the state-dict key and
-- returns the spec as the model. Saving does nothing.
instance HasStateDict MkAbsPos where
  fromStateDict spec = const (pure spec)
  toStateDict _ = const (pure ())
-- | Compute position ids for the input with 'mkPos'. For
-- 'MkAbsPosWithOffset', every id is then shifted by 'absPosOffset'. The
-- generator is returned unchanged.
instance
  MkPosC device shape seqDim seqName seqSize output =>
  HasForward
    MkAbsPos
    (Tensor gradient layout device dataType shape)
    generatorDevice
    (Tensor ('Gradient 'WithoutGradient) ('Layout 'Dense) device ('DataType 'Int64) ('Shape '[ 'Dim ('Name "*") seqSize]))
    generatorDevice
  where
  forward spec input g = do
    positions <- mkPos input
    shifted <- case spec of
      MkAbsPos -> pure positions
      MkAbsPosWithOffset {absPosOffset = offset} -> positions `addScalar` offset
    pure (shifted, g)
-- | Bidirectional relative-position bucket ids for a
-- @querySize × keySize@ grid. This appears to follow T5's relative-position
-- bucketing.
--
-- Arguments: number of buckets, maximum distance, query length, key length.
-- The result has one row per query position and one column per key
-- position.
--
-- The buckets are split into two halves of @numBuckets `div` 2@ each.
-- Keys after the query (positive offset) use the upper half, and keys at or
-- before the query use the lower half. Within a half:
--
-- * a magnitude below @maxExact@ (half of a half) is its own bucket;
-- * larger magnitudes are placed on a logarithmic scale relative to
--   @maxDistance@, capped at the last bucket of the half.
mkRelPos' :: Int -> Int -> Int -> Int -> [[Int]]
mkRelPos' numBuckets maxDistance querySize keySize =
  [ [bucket (k - q) | k <- [0 .. keySize - 1]]
    | q <- [0 .. querySize - 1]
  ]
  where
    halfBuckets = numBuckets `div` 2
    maxExact = halfBuckets `div` 2
    -- bucket id for offset = keyPos - queryPos
    bucket offset =
      let magnitude = abs offset
          -- positive offsets are shifted into the upper half
          direction = if offset > 0 then halfBuckets else 0
          -- only forced when the magnitude is not small
          logScaled =
            maxExact
              + double2Int
                ( logBase
                    (fromIntegral maxDistance / fromIntegral maxExact)
                    (fromIntegral magnitude / fromIntegral maxExact)
                    * fromIntegral (halfBuckets - maxExact)
                )
          withinHalf
            | magnitude < maxExact = magnitude
            | otherwise = min logScaled (halfBuckets - 1)
       in direction + withinHalf
-- | Constraints used by the relative-position builders ('mkRelPos' and its
-- decoder variant).
--
-- The input's dimension at index 1 supplies @seqSize@. @output@ is a
-- gradient-free, dense @Int64@ tensor of shape @[1, seqSize, seqSize]@ on
-- the input's device. The 'Catch' constraint lets the trailing two
-- unchecked dimensions be checked against @seqSize@.
type MkRelPosC device shape seqDim seqName seqSize output =
  ( SGetDevice device
  , SGetShape shape
  , seqDim ~ (shape ! 1)
  , seqDim ~ 'Dim seqName seqSize
  , Catch
      ( '[ 'Dim ('Name "*") 'UncheckedSize,
           'Dim ('Name "*") 'UncheckedSize
         ]
          <+> '[ 'Dim ('Name "*") seqSize, 'Dim ('Name "*") seqSize]
      )
  , output
      ~ Tensor
          ('Gradient 'WithoutGradient)
          ('Layout 'Dense)
          device
          ('DataType 'Int64)
          ('Shape '[ 'Dim ('Name "*") ('Size 1), 'Dim ('Name "*") seqSize, 'Dim ('Name "*") seqSize])
  )
-- | Tensor of bidirectional relative-position bucket ids for the input's
-- sequence dimension (index 1), computed with 'mkRelPos''.
--
-- The number of buckets is the size of @relPosEncBucketDim@. The 'Int'
-- argument is the maximum distance passed on to 'mkRelPos''.
--
-- The result is a dense, gradient-free @Int64@ tensor of shape
-- @[1, seqSize, seqSize]@ on the input's device. Looking up the dimension,
-- the tensor conversion and the final shape check all run in @m@ and may
-- throw.
mkRelPos ::
  forall m gradient layout device dataType shape relPosEncBucketDim seqDim seqName seqSize output.
  ( MonadThrow m,
    MkRelPosC device shape seqDim seqName seqSize output
  ) =>
  SDim relPosEncBucketDim ->
  Int ->
  Tensor gradient layout device dataType shape ->
  m output
mkRelPos relPosEncBucketDim maxDistance input = do
  seqDim <- sGetDimFromShape (SSelectDim (SByIndex @1)) (sGetShape input)
  let seqSizeSing = sDimSize seqDim
      len = fromInteger (forgetIsChecked (fromSing seqSizeSing))
      buckets = mkRelPos' numBuckets maxDistance len len
      expectedShape =
        SShape $ SName @"*" :&: SSize @1 :|: SName @"*" :&: seqSizeSing :|: SName @"*" :&: seqSizeSing :|: SNil
  unchecked <- sToTensor (SGradient SWithoutGradient) (SLayout SDense) (sGetDevice input) [buckets]
  sCheckedShape expectedShape unchecked
  where
    -- bucket count read from the bucket dimension singleton
    numBuckets = fromInteger . forgetIsChecked . dimSize . fromSing $ relPosEncBucketDim
-- | Unidirectional relative-position bucket ids for a @querySize × keySize@
-- grid. This is the decoder counterpart of 'mkRelPos''.
--
-- Arguments: number of buckets, maximum distance, query length, key length.
-- The result has one row per query position and one column per key
-- position.
--
-- Only keys at or before the query are distinguished. A positive offset
-- (key after query) is clamped to distance 0, which is bucket 0 whenever
-- @numBuckets >= 2@. A distance @d@ below @numBuckets `div` 2@ gets bucket
-- @d@. Larger distances are placed on a logarithmic scale relative to
-- @maxDistance@, capped at @numBuckets - 1@.
mkDecoderRelPos' :: Int -> Int -> Int -> Int -> [[Int]]
mkDecoderRelPos' numBuckets maxDistance querySize keySize =
  map
    (\q -> map (bucket . subtract q) [0 .. keySize - 1])
    [0 .. querySize - 1]
  where
    maxExact = numBuckets `div` 2
    -- bucket id for offset = keyPos - queryPos
    bucket offset
      | distance < maxExact = distance
      | otherwise = min logScaled (numBuckets - 1)
      where
        -- how far the key lies before the query; 0 for keys after it
        distance = negate (min 0 offset)
        logScaled =
          maxExact
            + double2Int
              ( logBase
                  (fromIntegral maxDistance / fromIntegral maxExact)
                  (fromIntegral distance / fromIntegral maxExact)
                  * fromIntegral (numBuckets - maxExact)
              )
-- | Build the relative-position bucket tensor for a decoder.
--
-- The sequence length @seqSize@ is read from dimension 1 of @input@'s shape.
-- The bucket table is produced by @mkDecoderRelPos'@ (defined earlier in this
-- module) for a square @seqSize × seqSize@ grid. It uses the size of the
-- bucket dimension as the number of buckets and @maxDistance@ as the maximum
-- distance.
--
-- The table is wrapped in a singleton list to add a leading dimension. It is
-- turned into a gradient-free, dense tensor on @input@'s device and then
-- checked against the shape @[1, seqSize, seqSize]@, with every dimension
-- named @"*"@.
--
-- Failures go through 'MonadThrow': reading dimension 1, converting the list,
-- or the final shape check.
mkDecoderRelPos ::
  forall m gradient layout device dataType shape relPosEncBucketDim seqDim seqName seqSize output.
  ( MonadThrow m,
    MkRelPosC device shape seqDim seqName seqSize output
  ) =>
  SDim relPosEncBucketDim ->
  Int ->
  Tensor gradient layout device dataType shape ->
  m output
mkDecoderRelPos bucketDim maxDistance input = do
  seqDim <- sGetDimFromShape (SSelectDim $ SByIndex @1) (sGetShape input)
  let seqLength :: Int
      seqLength = fromInteger . forgetIsChecked . fromSing $ sDimSize seqDim
      bucketCount :: Int
      bucketCount = fromInteger . forgetIsChecked . dimSize . fromSing $ bucketDim
      -- the outer list supplies the leading singleton dimension
      bucketTable = [mkDecoderRelPos' bucketCount maxDistance seqLength seqLength]
      expectedShape =
        SShape $
          SName @"*" :&: SSize @1
            :|: SName @"*" :&: sDimSize seqDim
            :|: SName @"*" :&: sDimSize seqDim
            :|: SNil
  unchecked <-
    sToTensor
      (SGradient SWithoutGradient)
      (SLayout SDense)
      (sGetDevice input)
      bucketTable
  sCheckedShape expectedShape unchecked
-- | Specification of a relative-position module.
--
-- Both constructors carry a bucket dimension and a maximum distance. The
-- 'HasForward' instance in this module sends 'MkRelPos' to @mkRelPos@ and
-- 'MkDecoderRelPos' to @mkDecoderRelPos@.
--
-- NOTE(review): each constructor defines its own field names, so every record
-- selector is partial (e.g. 'relPosEncBucketDim' fails on 'MkDecoderRelPos').
data MkRelPos (relPosEncBucketDim :: Dim (Name Symbol) (Size Nat))
  = MkRelPos
      { -- | bucket dimension; its size is the number of buckets
        relPosEncBucketDim :: SDim relPosEncBucketDim,
        -- | maximum relative distance
        relPosMaxDistance :: Int
      }
  | MkDecoderRelPos
      { -- | bucket dimension; its size is the number of buckets
        decoderRelPosEncBucketDim :: SDim relPosEncBucketDim,
        -- | maximum relative distance
        decoderRelPosMaxDistance :: Int
      }
  deriving stock (Show, Generic)
-- A 'MkRelPos' value serves as its own model specification.
type instance
  ModelSpec (MkRelPos relPosEncBucketDim) =
    MkRelPos relPosEncBucketDim
-- | Initialization is the identity. The spec is returned as the model, and
-- the generator is passed through untouched.
instance
  HasInitialize
    (MkRelPos relPosEncBucketDim)
    generatorDevice
    (MkRelPos relPosEncBucketDim)
    generatorDevice
  where
  initialize spec = pure . (spec,)
-- | 'MkRelPos' has nothing to persist. Loading ignores the key and returns the
-- spec, and saving writes nothing.
instance HasStateDict (MkRelPos relPosEncBucketDim) where
  fromStateDict spec = const (pure spec)
  toStateDict _ = const (pure ())
-- | Compute the relative-position tensor for @input@.
--
-- 'MkRelPos' delegates to @mkRelPos@ and 'MkDecoderRelPos' to
-- @mkDecoderRelPos@. The result is a gradient-free, dense @Int64@ tensor of
-- shape @[1, seqSize, seqSize]@, and the generator is returned unchanged.
instance
  MkRelPosC device shape seqDim seqName seqSize output =>
  HasForward
    (MkRelPos relPosEncBucketDim)
    (Tensor gradient layout device dataType shape)
    generatorDevice
    ( Tensor
        ('Gradient 'WithoutGradient)
        ('Layout 'Dense)
        device
        ('DataType 'Int64)
        ('Shape '[ 'Dim ('Name "*") ('Size 1), 'Dim ('Name "*") seqSize, 'Dim ('Name "*") seqSize])
    )
    generatorDevice
  where
  forward (MkRelPos bucketDim maxDistance) input g =
    (,g) <$> mkRelPos bucketDim maxDistance input
  forward (MkDecoderRelPos bucketDim maxDistance) input g =
    (,g) <$> mkDecoderRelPos bucketDim maxDistance input
-- | Constraints for 'mkTransformerPaddingMask'.
--
-- @output@ is a gradient-free boolean tensor. Its layout is @layout@ unified
-- with @'Layout 'Dense@. Its shape is @shape@ broadcast against the scalar
-- shape @'Shape '[]@.
--
-- The input data type must unify with @'DataType 'Int64@, because the input is
-- compared against an @Int64@ scalar.
type MkTransformerPaddingMaskC layout device dataType shape output =
  ( output
      ~ Tensor
          ('Gradient 'WithoutGradient)
          (layout <+> 'Layout 'Dense)
          device
          ('DataType 'Bool)
          (BroadcastShapesF shape ('Shape '[])),
    SGetDevice device,
    Catch (dataType <+> 'DataType 'Int64),
    Catch (BroadcastShapesF shape ('Shape '[]))
  )
-- | Build a padding mask that is @True@ wherever an element of the input
-- equals the pad token id.
--
-- The pad token id becomes a scalar, gradient-free, dense tensor on the
-- input's device. The input is then compared element-wise against it with
-- @(==.)@.
mkTransformerPaddingMask ::
  forall m gradient layout device dataType shape output.
  ( MonadThrow m,
    MkTransformerPaddingMaskC layout device dataType shape output
  ) =>
  Int ->
  Tensor gradient layout device dataType shape ->
  m output
mkTransformerPaddingMask padId input =
  sToTensor (SGradient SWithoutGradient) (SLayout SDense) (sGetDevice input) padId
    >>= (input ==.)
-- | Specification of the padding-mask module: the token id that marks padding.
newtype MkTransformerPaddingMask where
  MkTransformerPaddingMask :: {padTokenId :: Int} -> MkTransformerPaddingMask
  deriving stock (Show, Generic)
-- A 'MkTransformerPaddingMask' value serves as its own model specification.
type instance ModelSpec MkTransformerPaddingMask = MkTransformerPaddingMask
-- | Initialization is the identity. The spec is returned as the model, and
-- the generator is passed through untouched.
instance
  HasInitialize
    MkTransformerPaddingMask
    generatorDevice
    MkTransformerPaddingMask
    generatorDevice
  where
  initialize spec = pure . (spec,)
-- | 'MkTransformerPaddingMask' has nothing to persist. Loading ignores the key
-- and returns the spec, and saving writes nothing.
instance HasStateDict MkTransformerPaddingMask where
  fromStateDict spec = const (pure spec)
  toStateDict _ = const (pure ())
-- | Compute the padding mask of the input with 'mkTransformerPaddingMask'.
-- The generator is returned unchanged.
instance
  MkTransformerPaddingMaskC layout device dataType shape output =>
  HasForward
    MkTransformerPaddingMask
    (Tensor gradient layout device dataType shape)
    generatorDevice
    output
    generatorDevice
  where
  forward (MkTransformerPaddingMask tokenId) input g =
    (,g) <$> mkTransformerPaddingMask tokenId input
-- | Constraints for 'mkTransformerAttentionMask'.
--
-- @seqDim@ is dimension 1 of the padding-mask shape. The padding mask's
-- gradient and data type must unify with @'Gradient 'WithoutGradient@ and
-- @'DataType 'Bool@.
--
-- @output@ is a gradient-free tensor of @transformerDataType@. Its shape is
-- the padding-mask shape, unsqueezed at index 1, broadcast with
-- @[1, seqDim, seqDim]@.
type MkTransformerAttentionMaskC transformerDataType gradient layout device dataType shape seqDim output =
  ( output
      ~ Tensor
          ('Gradient 'WithoutGradient)
          (layout <+> 'Layout 'Dense)
          device
          transformerDataType
          ( BroadcastShapesF
              (UnsqueezeF ('SelectDim ('ByIndex 1)) shape)
              ('Shape '[ 'Dim ('Name "*") ('Size 1), seqDim, seqDim])
          ),
    seqDim ~ (shape ! 1),
    SGetLayout layout,
    SGetDevice device,
    SGetShape shape,
    Catch (gradient <+> 'Gradient 'WithoutGradient),
    Catch (dataType <+> 'DataType 'Bool),
    Catch (UnsqueezeF ('SelectDim ('ByIndex 1)) shape),
    Catch
      ( BroadcastShapesF
          (UnsqueezeF ('SelectDim ('ByIndex 1)) shape)
          ('Shape '[ 'Dim ('Name "*") ('Size 1), seqDim, seqDim])
      )
  )
-- | Turn a padding mask into an attention mask.
--
-- The mask starts as a gradient-free zero tensor of the requested data type,
-- with shape @[1, seqDim, seqDim]@, where @seqDim@ is dimension 1 of the
-- padding mask. It takes the padding mask's layout and device. The padding
-- mask gets a new dimension at index 1, and 'maskedFill' then writes the bias
-- value at the positions that mask selects.
--
-- Failures go through 'MonadThrow': reading dimension 1, allocating the zeros,
-- unsqueezing, or the masked fill.
mkTransformerAttentionMask ::
  forall m transformerDataType gradient layout device dataType shape seqDim output.
  ( MonadThrow m,
    MkTransformerAttentionMaskC transformerDataType gradient layout device dataType shape seqDim output
  ) =>
  SDataType transformerDataType ->
  Double ->
  Tensor gradient layout device dataType shape ->
  m output
mkTransformerAttentionMask maskDataType bias paddingMask = do
  seqDim <- sGetDimFromShape (SSelectDim $ SByIndex @1) (sGetShape paddingMask)
  let zeroSpec =
        TensorSpec
          (SGradient SWithoutGradient)
          (sGetLayout paddingMask)
          (sGetDevice paddingMask)
          maskDataType
          (SShape $ SName @"*" :&: SSize @1 :|: seqDim :|: seqDim :|: SNil)
  emptyMask <- sZeros zeroSpec
  unsqueezedPaddingMask <- unsqueeze @('SelectDim ('ByIndex 1)) paddingMask
  maskedFill unsqueezedPaddingMask bias emptyMask
-- | Specification of the attention-mask module.
data MkTransformerAttentionMask (dataType :: DataType DType) = MkTransformerAttentionMask
  { -- | data type of the produced mask
    attentionMaskDataType :: SDataType dataType,
    -- | value written at masked positions
    attentionMaskBias :: Double
  }
  deriving stock (Show, Generic)
-- A 'MkTransformerAttentionMask' value serves as its own model specification.
type instance ModelSpec (MkTransformerAttentionMask dataType) = MkTransformerAttentionMask dataType
-- | Initialization is the identity. The spec is returned as the model, and
-- the generator is passed through untouched.
instance
  HasInitialize
    (MkTransformerAttentionMask dataType)
    generatorDevice
    (MkTransformerAttentionMask dataType)
    generatorDevice
  where
  initialize spec = pure . (spec,)
-- | 'MkTransformerAttentionMask' has nothing to persist. Loading ignores the
-- key and returns the spec, and saving writes nothing.
instance HasStateDict (MkTransformerAttentionMask dataType) where
  fromStateDict spec = const (pure spec)
  toStateDict _ = const (pure ())
-- | Compute the attention mask of the input padding mask with
-- 'mkTransformerAttentionMask'. The generator is returned unchanged.
instance
  MkTransformerAttentionMaskC dataType inputGradient inputLayout inputDevice inputDataType inputShape seqDim output =>
  HasForward
    (MkTransformerAttentionMask dataType)
    (Tensor inputGradient inputLayout inputDevice inputDataType inputShape)
    generatorDevice
    output
    generatorDevice
  where
  forward (MkTransformerAttentionMask maskDataType maskBias) input g =
    (,g) <$> mkTransformerAttentionMask maskDataType maskBias input
-- | Constraints for 'mkTransformerDecoderAttentionMask'.
--
-- @seqDim@ is dimension 1 of the padding-mask shape. @output@ is a
-- gradient-free tensor of @transformerDataType@. Its shape is built in two
-- broadcasts:
--
-- 1. @[1, seqDim, seqDim]@ broadcast with the padding-mask shape unsqueezed
--    at index 1.
-- 2. That result broadcast again with @[1, seqDim, seqDim]@.
type MkTransformerDecoderAttentionMaskC transformerDataType layout device shape seqDim output =
  ( output
      ~ Tensor
          ('Gradient 'WithoutGradient)
          (layout <+> 'Layout 'Dense)
          device
          transformerDataType
          ( BroadcastShapesF
              ( BroadcastShapesF
                  ('Shape '[ 'Dim ('Name "*") ('Size 1), seqDim, seqDim])
                  (UnsqueezeF ('SelectDim ('ByIndex 1)) shape)
              )
              ('Shape '[ 'Dim ('Name "*") ('Size 1), seqDim, seqDim])
          ),
    seqDim ~ (shape ! 1),
    SGetLayout layout,
    SGetDevice device,
    SGetShape shape,
    Catch seqDim,
    Catch (UnsqueezeF ('SelectDim ('ByIndex 1)) shape),
    Catch
      ( BroadcastShapesF
          ('Shape '[ 'Dim ('Name "*") ('Size 1), seqDim, seqDim])
          (UnsqueezeF ('SelectDim ('ByIndex 1)) shape)
      ),
    Catch
      ( BroadcastShapesF
          ( BroadcastShapesF
              ('Shape '[ 'Dim ('Name "*") ('Size 1), seqDim, seqDim])
              (UnsqueezeF ('SelectDim ('ByIndex 1)) shape)
          )
          ('Shape '[ 'Dim ('Name "*") ('Size 1), seqDim, seqDim])
      )
  )
-- | Builds a decoder self-attention mask by combining a causal mask with a
-- padding mask.
--
-- The causal mask is @triu 1@ of a @[seqDim, seqDim]@ ones tensor, converted
-- to 'Bool' and unsqueezed to @[1, seqDim, seqDim]@. The padding mask is
-- unsqueezed at dimension 1. The two are combined with 'logicalOr', and every
-- position selected by the combined mask is set to @attentionMaskBias@ in an
-- otherwise zero tensor of data type @transformerDataType@.
--
-- NOTE(review): the padding mask presumably has shape @[batch, seq]@, with
-- 'True' marking padded positions. Confirm this against the callers.
mkTransformerDecoderAttentionMask ::
  forall m transformerDataType gradient layout device dataType shape seqDim output.
  ( MonadThrow m,
    MkTransformerDecoderAttentionMaskC transformerDataType layout device shape seqDim output
  ) =>
  -- | data type of the produced mask
  SDataType transformerDataType ->
  -- | value written to masked positions
  Double ->
  -- | padding mask
  Tensor gradient layout device dataType shape ->
  -- | attention mask
  m output
mkTransformerDecoderAttentionMask transformerDataType attentionMaskBias paddingMask = do
  let pmLayout = sGetLayout paddingMask
      pmDevice = sGetDevice paddingMask
      pmShape = sGetShape paddingMask
  pmSeqDim <- sGetDimFromShape (SSelectDim $ SByIndex @1) pmShape
  -- causal part: upper triangle (diagonal offset 1) of a ones matrix, as Bool
  causalMask <-
    bool . triu 1
      =<< sOnes
        ( TensorSpec
            (SGradient SWithoutGradient)
            pmLayout
            pmDevice
            transformerDataType
            (SShape $ pmSeqDim :|: pmSeqDim :|: SNil)
        )
  -- add a leading broadcastable dimension of size 1
  causalMask' <- unsqueeze @('SelectDim ('ByIndex 0)) causalMask
  emptyMask <-
    sZeros $
      TensorSpec
        (SGradient SWithoutGradient)
        pmLayout
        pmDevice
        transformerDataType
        (SShape $ SName @"*" :&: SSize @1 :|: pmSeqDim :|: pmSeqDim :|: SNil)
  -- insert a query-position axis so the padding mask broadcasts over queries
  paddingMask' <- unsqueeze @('SelectDim ('ByIndex 1)) paddingMask
  booleanMask <- causalMask' `logicalOr` paddingMask'
  maskedFill booleanMask attentionMaskBias emptyMask
-- | Model that turns its input tensor into a decoder self-attention mask by
-- passing it as the padding mask to 'mkTransformerDecoderAttentionMask'.
data MkTransformerDecoderAttentionMask (dataType :: DataType DType) where
  MkTransformerDecoderAttentionMask ::
    forall dataType.
    { -- data type of the produced attention mask
      decoderAttentionMaskDataType :: SDataType dataType,
      -- value written to masked positions
      decoderAttentionMaskBias :: Double
    } ->
    MkTransformerDecoderAttentionMask dataType
  deriving stock (Show, Generic)
-- The decoder mask builder serves as its own model specification.
type instance ModelSpec (MkTransformerDecoderAttentionMask dtype) = MkTransformerDecoderAttentionMask dtype
-- Initialisation returns the spec as the model and passes the generator
-- through unchanged.
instance
  HasInitialize
    (MkTransformerDecoderAttentionMask dataType)
    generatorDevice
    (MkTransformerDecoderAttentionMask dataType)
    generatorDevice
  where
  initialize spec g = pure (spec, g)
-- The mask builder has no state: loading ignores the key and returns the
-- spec, and saving writes nothing.
instance HasStateDict (MkTransformerDecoderAttentionMask dataType) where
  fromStateDict spec _ = pure spec
  toStateDict _ _ = pure ()
-- Running the model builds a decoder self-attention mask, using the input
-- tensor as the padding mask. The generator is returned unchanged.
instance
  MkTransformerDecoderAttentionMaskC dataType decoderInputLayout decoderInputDevice decoderInputShape seqDim output =>
  HasForward
    (MkTransformerDecoderAttentionMask dataType)
    (Tensor decoderInputGradient decoderInputLayout decoderInputDevice decoderInputDataType decoderInputShape)
    generatorDevice
    output
    generatorDevice
  where
  forward MkTransformerDecoderAttentionMask {..} input g = do
    decoderAttentionMask <-
      mkTransformerDecoderAttentionMask
        decoderAttentionMaskDataType
        decoderAttentionMaskBias
        input
    pure (decoderAttentionMask, g)
-- | Constraints needed by 'mkTransformerCrossAttentionMask'.
--
-- @gradient@, @layout@, @device@, @dataType@ and @shape@ describe the
-- padding-mask tensor, whose dimension 1 is @seqDim@. Dimension 1 of
-- @decoderInputShape@ is @decoderInputSeqDim@. @output@ is the produced
-- mask: it has no gradient, its data type is @transformerDataType@, and its
-- shape is the padding mask unsqueezed at dimension 1, broadcast against
-- @[1, decoderInputSeqDim, seqDim]@.
type MkTransformerCrossAttentionMaskC transformerDataType decoderInputShape decoderInputSeqDim gradient layout device dataType shape seqDim output =
  ( -- the produced attention mask
    output
      ~ Tensor
          ('Gradient 'WithoutGradient)
          (layout <+> 'Layout 'Dense)
          device
          transformerDataType
          ( BroadcastShapesF
              (UnsqueezeF ('SelectDim ('ByIndex 1)) shape)
              ('Shape '[ 'Dim ('Name "*") ('Size 1), decoderInputSeqDim, seqDim])
          ),
    -- sequence dimension of the decoder input
    SGetShape decoderInputShape,
    decoderInputSeqDim ~ (decoderInputShape ! 1),
    -- singletons and sequence dimension of the padding mask
    SGetLayout layout,
    SGetDevice device,
    SGetShape shape,
    seqDim ~ (shape ! 1),
    -- needed to use the padding mask as the boolean mask of maskedFill
    Catch (gradient <+> 'Gradient 'WithoutGradient),
    Catch (dataType <+> 'DataType 'Bool),
    -- shape checks for the unsqueezed and broadcast masks
    Catch (UnsqueezeF ('SelectDim ('ByIndex 1)) shape),
    Catch
      ( BroadcastShapesF
          (UnsqueezeF ('SelectDim ('ByIndex 1)) shape)
          ( 'Shape
              '[ 'Dim ('Name "*") ('Size 1), decoderInputSeqDim, seqDim]
          )
      )
  )
-- | Builds a cross-attention mask from a padding mask.
--
-- The padding mask is unsqueezed at dimension 1 and used as the boolean mask
-- of 'maskedFill' over a zero tensor of shape
-- @[1, decoderInputSeqDim, seqDim]@. Here @decoderInputSeqDim@ is dimension 1
-- of the decoder-input shape and @seqDim@ is dimension 1 of the padding mask.
-- Selected positions receive @attentionMaskBias@.
--
-- NOTE(review): the padding mask presumably has shape @[batch, seq]@, with
-- 'True' marking padded positions. Confirm this against the callers.
mkTransformerCrossAttentionMask ::
  forall m transformerDataType decoderInputShape decoderInputSeqDim gradient layout device dataType shape seqDim output.
  ( MonadThrow m,
    MkTransformerCrossAttentionMaskC transformerDataType decoderInputShape decoderInputSeqDim gradient layout device dataType shape seqDim output
  ) =>
  -- | data type of the produced mask
  SDataType transformerDataType ->
  -- | shape of the decoder input (only dimension 1 is used)
  SShape decoderInputShape ->
  -- | value written to masked positions
  Double ->
  -- | padding mask
  Tensor gradient layout device dataType shape ->
  -- | attention mask
  m output
mkTransformerCrossAttentionMask transformerDataType decoderInputShape attentionMaskBias paddingMask = do
  decoderInputSeqDim <- sGetDimFromShape (SSelectDim $ SByIndex @1) decoderInputShape
  let pmLayout = sGetLayout paddingMask
      pmDevice = sGetDevice paddingMask
      pmShape = sGetShape paddingMask
  pmSeqDim <- sGetDimFromShape (SSelectDim $ SByIndex @1) pmShape
  emptyMask <-
    sZeros $
      TensorSpec
        (SGradient SWithoutGradient)
        pmLayout
        pmDevice
        transformerDataType
        (SShape $ SName @"*" :&: SSize @1 :|: decoderInputSeqDim :|: pmSeqDim :|: SNil)
  -- insert a query-position axis so the padding mask broadcasts over decoder positions
  paddingMask' <- unsqueeze @('SelectDim ('ByIndex 1)) paddingMask
  maskedFill paddingMask' attentionMaskBias emptyMask
-- | Model that builds a cross-attention mask from a decoder input and a
-- padding mask via 'mkTransformerCrossAttentionMask'.
data MkTransformerCrossAttentionMask (dataType :: DataType DType) where
  MkTransformerCrossAttentionMask ::
    forall dataType.
    { -- data type of the produced attention mask
      crossAttentionMaskDataType :: SDataType dataType,
      -- value written to masked positions
      crossAttentionMaskBias :: Double
    } ->
    MkTransformerCrossAttentionMask dataType
  deriving stock (Show, Generic)
-- The cross-attention mask builder serves as its own model specification.
type instance ModelSpec (MkTransformerCrossAttentionMask dtype) = MkTransformerCrossAttentionMask dtype
-- Initialisation returns the spec as the model and passes the generator
-- through unchanged.
instance
  HasInitialize
    (MkTransformerCrossAttentionMask dataType)
    generatorDevice
    (MkTransformerCrossAttentionMask dataType)
    generatorDevice
  where
  initialize spec g = pure (spec, g)
-- The mask builder has no state: loading ignores the key and returns the
-- spec, and saving writes nothing.
instance HasStateDict (MkTransformerCrossAttentionMask dataType) where
  fromStateDict spec _ = pure spec
  toStateDict _ _ = pure ()
-- Running the model reads the decoder input's shape and builds the
-- cross-attention mask from the given padding mask. The generator is
-- returned unchanged.
instance
  MkTransformerCrossAttentionMaskC dataType decoderInputShape decoderInputSeqDim inputPaddingMaskGradient inputPaddingMaskLayout inputPaddingMaskDevice inputPaddingMaskDataType inputPaddingMaskShape seqDim output =>
  HasForward
    (MkTransformerCrossAttentionMask dataType)
    ( Tensor decoderInputGradient decoderInputLayout decoderInputDevice decoderInputDataType decoderInputShape,
      Tensor inputPaddingMaskGradient inputPaddingMaskLayout inputPaddingMaskDevice inputPaddingMaskDataType inputPaddingMaskShape
    )
    generatorDevice
    output
    generatorDevice
  where
  forward MkTransformerCrossAttentionMask {..} (decoderInput, inputPaddingMask) g = do
    -- only the decoder input's shape is needed, not its values
    let decoderInputShape = sGetShape decoderInput
    crossAttentionMask <-
      mkTransformerCrossAttentionMask
        crossAttentionMaskDataType
        decoderInputShape
        crossAttentionMaskBias
        inputPaddingMask
    pure (crossAttentionMask, g)
-- | Model that shifts its input to the right, using the stored value to fill
-- the opened position.
--
-- NOTE(review): the exact shape semantics are defined by the 'HasForward'
-- instance. From its output type, dimension 1 appears to grow by one.
data ShiftRight fillValue where
  ShiftRight ::
    forall fillValue.
    fillValue ->
    ShiftRight fillValue
  deriving stock (Eq, Ord, Show, Generic)
-- | A 'ShiftRight' layer is its own model specification: the spec
-- already carries the fill value and there is nothing else to configure.
type instance ModelSpec (ShiftRight fv) = ShiftRight fv
-- | Initializing a 'ShiftRight' layer just returns its specification.
--
-- The layer has no trainable parameters. The random generator is
-- therefore not consumed and is handed back unchanged, on the same
-- device it came in on.
instance
  HasInitialize
    (ShiftRight fillValue)
    generatorDevice
    (ShiftRight fillValue)
    generatorDevice
  where
  initialize spec = pure . (spec,)
-- | State-dict support for 'ShiftRight'.
--
-- The layer holds no tensors:
--
-- * 'fromStateDict' ignores the state-dict key and returns the spec as the layer.
-- * 'toStateDict' writes nothing.
instance HasStateDict (ShiftRight fillValue) where
  fromStateDict spec _ = pure spec
  toStateDict _ _ = pure ()
-- | Shift the input right along dimension 1 by prepending a column of
-- @fillValue@.
--
-- The procedure:
--
-- 1. Read the input's layout, device, data type and shape.
-- 2. Build a filler tensor of shape @[inputBatchDim, 1]@ (without gradient)
--    with 'sFull'.
-- 3. Concatenate @filler@ and @input@ along dimension index 1, filler first.
--
-- At the type level, dimension 1 of the result is @inputSeqDim@ grown by
-- one ('AddDimF'). The gradient is unified with @'Gradient 'WithoutGradient@.
-- The generator is passed through untouched.
--
-- NOTE(review): presumably this throws (via 'MonadThrow') when dimension 0
-- cannot be selected from the input shape — confirm against
-- 'sGetDimFromShape'.
instance
  ( input
      ~ Tensor
          inputGradient
          inputLayout
          inputDevice
          inputDataType
          inputShape,
    SGetLayout inputLayout,
    SGetDevice inputDevice,
    SGetDataType inputDataType,
    SGetShape inputShape,
    inputBatchDim ~ (inputShape ! 0),
    inputSeqDim ~ (inputShape ! 1),
    Scalar fillValue,
    rightShiftedInput
      ~ Tensor
          (inputGradient <|> 'Gradient 'WithoutGradient)
          inputLayout
          inputDevice
          inputDataType
          ( ReplaceDimF
              ('SelectDim ('ByIndex 1))
              (inputShape <+> 'Shape '[inputBatchDim, inputSeqDim])
              (AddDimF inputSeqDim ('Dim ('Name "*") ('Size 1)))
          )
  ) =>
  HasForward (ShiftRight fillValue) input generator rightShiftedInput generator
  where
  forward (ShiftRight fillValue) input g = do
    let inputLayout = sGetLayout input
        inputDevice = sGetDevice input
        inputDataType = sGetDataType input
        inputShape = sGetShape input
    -- Only the batch dimension is needed at the value level; the sequence
    -- dimension appears only in the type of the result.
    inputBatchDim <- sGetDimFromShape (SSelectDim $ SByIndex @0) inputShape
    -- A [batch, 1] column filled with the fill value, matching the input's
    -- layout, device and data type so that it can be concatenated.
    filler <-
      sFull
        ( TensorSpec
            (SGradient SWithoutGradient)
            inputLayout
            inputDevice
            inputDataType
            (SShape $ inputBatchDim :|: SName @"*" :&: SSize @1 :|: SNil)
        )
        fillValue
    -- Filler first, so the fill value ends up at position 0 of dimension 1.
    shifted <- cat @('SelectDim ('ByIndex 1)) (filler :. input :. HNil)
    pure (shifted, g)