{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE NoStarIsType #-}
{-# OPTIONS_GHC -fconstraint-solver-iterations=0 #-}
module Torch.Typed.NN.Transformer where
import Control.Monad
import Data.Proxy
import GHC.Generics
import GHC.TypeLits
import System.IO.Unsafe (unsafePerformIO)
import qualified Torch.DType as D
import qualified Torch.Device as D
import Torch.HList
import Torch.NN (HasForward (..))
import qualified Torch.NN as A
import Torch.Typed.Auxiliary
import Torch.Typed.Factories
import Torch.Typed.Functional hiding (linear, log)
import Torch.Typed.NN.Dropout
import Torch.Typed.NN.Linear
import Torch.Typed.NN.Normalization
import Torch.Typed.NN.Sparse
import Torch.Typed.Parameter
import Torch.Typed.Tensor
import Prelude hiding (cos, exp, sin)
-- | Residual (skip) connection.
--
-- Runs the sub-layer @f@ on @x@, adds the original input @x@ to the
-- sub-layer's output with 'add', and hands the sum to the continuation @g@.
-- Because 'add' promotes dtypes and broadcasts shapes, the tensor given to
-- @g@ has dtype @DTypePromotion dtype dtype'@ and shape
-- @Broadcast shape shape'@.
residual ::
  ( Monad m,
    BasicArithmeticDTypeIsValid device dtype,
    BasicArithmeticDTypeIsValid device dtype',
    BasicArithmeticDTypeIsValid device (DTypePromotion dtype dtype')
  ) =>
  -- | sub-layer applied to the input
  (Tensor device dtype shape -> m (Tensor device dtype' shape')) ->
  -- | continuation that receives @x + f x@
  (Tensor device (DTypePromotion dtype dtype') (Broadcast shape shape') -> m b) ->
  -- | input @x@
  Tensor device dtype shape ->
  m b
residual f g x = f x >>= \x' -> g (x `add` x')
-- | Specification for building a 'MultiheadAttention' layer.
--
-- All dimensions, the dtype and the device are carried at the type level.
-- The only value-level configuration is the 'DropoutSpec' for the attention
-- dropout. NOTE(review): presumably the 'Linear' projections are initialised
-- from the type-level sizes by a 'Randomizable' instance elsewhere in the
-- file; confirm.
data
  MultiheadAttentionSpec
    (embedDim :: Nat)
    (kEmbedDim :: Nat)
    (vEmbedDim :: Nat)
    (numHeads :: Nat)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  where
  MultiheadAttentionSpec ::
    -- | dropout specification for the attention weights
    DropoutSpec ->
    MultiheadAttentionSpec embedDim kEmbedDim vEmbedDim numHeads dtype device
  deriving (Show, Eq)
-- | Multi-head attention layer with separate input projections for query,
-- key and value.
--
-- Keys (@kEmbedDim@) and values (@vEmbedDim@) may have a different embedding
-- size from the queries. Each is projected to @embedDim@ before attention.
-- @numHeads@ is a phantom parameter: 'multiheadAttention' requires
-- @embedDim ~ headDim * numHeads@ and splits the projections into that many
-- heads.
data
  MultiheadAttention
    (embedDim :: Nat)
    (kEmbedDim :: Nat)
    (vEmbedDim :: Nat)
    (numHeads :: Nat)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  where
  MultiheadAttention ::
    { -- | in-projection for the query (@embedDim -> embedDim@)
      mhaQInProj :: Linear embedDim embedDim dtype device,
      -- | in-projection for the key (@kEmbedDim -> embedDim@)
      mhaKInProj :: Linear kEmbedDim embedDim dtype device,
      -- | in-projection for the value (@vEmbedDim -> embedDim@)
      mhaVInProj :: Linear vEmbedDim embedDim dtype device,
      -- | output projection (@embedDim -> embedDim@); presumably applied to
      -- the merged head outputs — confirm in 'multiheadAttention'
      mhaOutProj :: Linear embedDim embedDim dtype device,
      -- | dropout applied to the softmax-normalised attention weights
      mhaDropout :: Dropout
    } ->
    MultiheadAttention embedDim kEmbedDim vEmbedDim numHeads dtype device
  deriving (Show, Generic, Parameterized)
multiheadAttention ::
forall embedDim kEmbedDim vEmbedDim numHeads seqLen seqLen' batchSize headDim dtype device.
( 1 <= numHeads,
embedDim ~ (headDim * numHeads),
All KnownNat '[embedDim, kEmbedDim, vEmbedDim, numHeads, seqLen, seqLen', batchSize, headDim],
KnownDType dtype,
StandardFloatingPointDTypeValidation device dtype,
MatMulDTypeIsValid device dtype,
BasicArithmeticDTypeIsValid device dtype,
dtype ~ SumDType dtype,
SumDTypeIsValid device dtype,
KnownDevice device
) =>
MultiheadAttention embedDim kEmbedDim vEmbedDim numHeads dtype device ->
Bool ->
Maybe (Tensor device dtype '[batchSize, seqLen', seqLen]) ->
Maybe (Tensor device 'D.Bool '[batchSize, seqLen]) ->
Maybe (Tensor device dtype '[batchSize, seqLen', seqLen, headDim]) ->
Maybe (Tensor device dtype '[batchSize, seqLen', seqLen, headDim]) ->
Tensor device dtype '[batchSize, seqLen', embedDim] ->
Tensor device dtype '[batchSize, seqLen, kEmbedDim] ->
Tensor device dtype '[batchSize, seqLen, vEmbedDim] ->
IO
( Tensor device dtype '[batchSize, seqLen', embedDim],
Tensor device dtype '[batchSize, seqLen', seqLen]
)
multiheadAttention :: forall (embedDim :: Nat) (kEmbedDim :: Nat) (vEmbedDim :: Nat)
(numHeads :: Nat) (seqLen :: Nat) (seqLen' :: Nat)
(batchSize :: Nat) (headDim :: Nat) (dtype :: DType)
(device :: (DeviceType, Nat)).
(1 <= numHeads, embedDim ~ (headDim * numHeads),
All
KnownNat
'[embedDim, kEmbedDim, vEmbedDim, numHeads, seqLen, seqLen',
batchSize, headDim],
KnownDType dtype,
StandardFloatingPointDTypeValidation device dtype,
MatMulDTypeIsValid device dtype,
BasicArithmeticDTypeIsValid device dtype, dtype ~ SumDType dtype,
SumDTypeIsValid device dtype, KnownDevice device) =>
MultiheadAttention
embedDim kEmbedDim vEmbedDim numHeads dtype device
-> Bool
-> Maybe (Tensor device dtype '[batchSize, seqLen', seqLen])
-> Maybe (Tensor device 'Bool '[batchSize, seqLen])
-> Maybe
(Tensor device dtype '[batchSize, seqLen', seqLen, headDim])
-> Maybe
(Tensor device dtype '[batchSize, seqLen', seqLen, headDim])
-> Tensor device dtype '[batchSize, seqLen', embedDim]
-> Tensor device dtype '[batchSize, seqLen, kEmbedDim]
-> Tensor device dtype '[batchSize, seqLen, vEmbedDim]
-> IO
(Tensor device dtype '[batchSize, seqLen', embedDim],
Tensor device dtype '[batchSize, seqLen', seqLen])
multiheadAttention MultiheadAttention {Linear embedDim embedDim dtype device
Linear kEmbedDim embedDim dtype device
Linear vEmbedDim embedDim dtype device
Dropout
mhaDropout :: Dropout
mhaOutProj :: Linear embedDim embedDim dtype device
mhaVInProj :: Linear vEmbedDim embedDim dtype device
mhaKInProj :: Linear kEmbedDim embedDim dtype device
mhaQInProj :: Linear embedDim embedDim dtype device
mhaDropout :: forall (embedDim :: Nat) (dtype :: DType)
(device :: (DeviceType, Nat)) (kEmbedDim :: Nat) (vEmbedDim :: Nat)
(numHeads :: Nat).
MultiheadAttention
embedDim kEmbedDim vEmbedDim numHeads dtype device
-> Dropout
mhaOutProj :: forall (embedDim :: Nat) (dtype :: DType)
(device :: (DeviceType, Nat)) (kEmbedDim :: Nat) (vEmbedDim :: Nat)
(numHeads :: Nat).
MultiheadAttention
embedDim kEmbedDim vEmbedDim numHeads dtype device
-> Linear embedDim embedDim dtype device
mhaVInProj :: forall (embedDim :: Nat) (dtype :: DType)
(device :: (DeviceType, Nat)) (kEmbedDim :: Nat) (vEmbedDim :: Nat)
(numHeads :: Nat).
MultiheadAttention
embedDim kEmbedDim vEmbedDim numHeads dtype device
-> Linear vEmbedDim embedDim dtype device
mhaKInProj :: forall (embedDim :: Nat) (dtype :: DType)
(device :: (DeviceType, Nat)) (kEmbedDim :: Nat) (vEmbedDim :: Nat)
(numHeads :: Nat).
MultiheadAttention
embedDim kEmbedDim vEmbedDim numHeads dtype device
-> Linear kEmbedDim embedDim dtype device
mhaQInProj :: forall (embedDim :: Nat) (dtype :: DType)
(device :: (DeviceType, Nat)) (kEmbedDim :: Nat) (vEmbedDim :: Nat)
(numHeads :: Nat).
MultiheadAttention
embedDim kEmbedDim vEmbedDim numHeads dtype device
-> Linear embedDim embedDim dtype device
..} Bool
train Maybe (Tensor device dtype '[batchSize, seqLen', seqLen])
attentionMask Maybe (Tensor device 'Bool '[batchSize, seqLen])
keyPaddingMask Maybe (Tensor device dtype '[batchSize, seqLen', seqLen, headDim])
keyRelations Maybe (Tensor device dtype '[batchSize, seqLen', seqLen, headDim])
valueRelations Tensor device dtype '[batchSize, seqLen', embedDim]
query Tensor device dtype '[batchSize, seqLen, kEmbedDim]
key Tensor device dtype '[batchSize, seqLen, vEmbedDim]
value = do
Tensor
device
dtype
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
weights <-
forall (shape :: [Nat]) (dtype :: DType)
(device :: (DeviceType, Nat)).
Dropout
-> Bool
-> Tensor device dtype shape
-> IO (Tensor device dtype shape)
dropoutForward Dropout
mhaDropout Bool
train
forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (dim :: Nat) (shape :: [Nat]) (dtype :: DType)
(device :: (DeviceType, Nat)).
(KnownNat dim, DimOutOfBoundCheck shape dim, KnownDType dtype,
StandardFloatingPointDTypeValidation device dtype) =>
Tensor device dtype shape -> Tensor device dtype shape
softmax @3
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Tensor
device
dtype
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
-> Tensor
device
dtype
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
_maskKeyPaddings
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Tensor
device
dtype
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
-> Tensor
device
dtype
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
_maskAttention
forall a b. (a -> b) -> a -> b
$ Tensor
device
dtype
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
_attentionWeights
forall (f :: Type -> Type) a. Applicative f => a -> f a
pure (Tensor
device
dtype
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
-> Tensor device dtype '[batchSize, seqLen', embedDim]
_attention Tensor
device
dtype
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
weights, Tensor
device
dtype
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
-> Tensor device dtype '[batchSize, seqLen', seqLen]
averageOverHeads Tensor
device
dtype
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
weights)
where
_attentionWeights :: Tensor
device
dtype
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
_attentionWeights =
let scaling :: Double
scaling = forall a. Floating a => a -> a
Prelude.sqrt forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall a b. (Integral a, Num b) => a -> b
fromIntegral forall a b. (a -> b) -> a -> b
$ forall (n :: Nat). KnownNat n => Int
natValI @headDim :: Double
q :: Tensor device dtype '[batchSize, numHeads, seqLen', headDim]
q = forall (seqLen'' :: Nat).
KnownNat seqLen'' =>
Tensor device dtype '[batchSize, seqLen'', embedDim]
-> Tensor device dtype '[batchSize, numHeads, seqLen'', headDim]
reshape' forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall a (shape :: [Nat]) (dtype :: DType)
(device :: (DeviceType, Nat)).
Scalar a =>
a -> Tensor device dtype shape -> Tensor device dtype shape
divScalar Double
scaling forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall f a b. HasForward f a b => f -> a -> b
forward Linear embedDim embedDim dtype device
mhaQInProj forall a b. (a -> b) -> a -> b
$ Tensor device dtype '[batchSize, seqLen', embedDim]
query
k :: Tensor device dtype '[batchSize, numHeads, seqLen, headDim]
k = forall (seqLen'' :: Nat).
KnownNat seqLen'' =>
Tensor device dtype '[batchSize, seqLen'', embedDim]
-> Tensor device dtype '[batchSize, numHeads, seqLen'', headDim]
reshape' forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall f a b. HasForward f a b => f -> a -> b
forward Linear kEmbedDim embedDim dtype device
mhaKInProj forall a b. (a -> b) -> a -> b
$ Tensor device dtype '[batchSize, seqLen, kEmbedDim]
key
weights :: Tensor
device
dtype
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
weights = forall (shape'' :: [Nat]) (shape :: [Nat]) (shape' :: [Nat])
(dtype :: DType) (device :: (DeviceType, Nat)).
(shape'' ~ MatMul shape shape', MatMulDTypeIsValid device dtype) =>
Tensor device dtype shape
-> Tensor device dtype shape' -> Tensor device dtype shape''
matmul Tensor device dtype '[batchSize, numHeads, seqLen', headDim]
q (forall (n :: Nat) (m :: Nat) (shape :: [Nat]) (shape' :: [Nat])
(dtype :: DType) (device :: (DeviceType, Nat)).
(KnownNat n, KnownNat m, shape' ~ Transpose shape n m) =>
Tensor device dtype shape -> Tensor device dtype shape'
transpose @2 @3 Tensor device dtype '[batchSize, numHeads, seqLen, headDim]
k)
weights' :: Tensor
device
dtype
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
weights' = case Maybe (Tensor device dtype '[batchSize, seqLen', seqLen, headDim])
keyRelations of
Maybe (Tensor device dtype '[batchSize, seqLen', seqLen, headDim])
Nothing -> Tensor
device
dtype
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
weights
Just Tensor device dtype '[batchSize, seqLen', seqLen, headDim]
kr -> Tensor
device
dtype
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
weights forall (shape'' :: [Nat]) (shape :: [Nat]) (shape' :: [Nat])
(dtype :: DType) (dtype' :: DType) (dtype'' :: DType)
(device :: (DeviceType, Nat)).
(dtype'' ~ DTypePromotion dtype dtype',
shape'' ~ Broadcast shape shape',
BasicArithmeticDTypeIsValid device dtype,
BasicArithmeticDTypeIsValid device dtype',
BasicArithmeticDTypeIsValid device dtype'') =>
Tensor device dtype shape
-> Tensor device dtype' shape' -> Tensor device dtype'' shape''
`add` forall (n :: Nat) (m :: Nat) (shape :: [Nat]) (shape' :: [Nat])
(dtype :: DType) (device :: (DeviceType, Nat)).
(KnownNat n, KnownNat m, shape' ~ Transpose shape n m) =>
Tensor device dtype shape -> Tensor device dtype shape'
transpose @1 @2 ((forall (n :: Nat) (m :: Nat) (shape :: [Nat]) (shape' :: [Nat])
(dtype :: DType) (device :: (DeviceType, Nat)).
(KnownNat n, KnownNat m, shape' ~ Transpose shape n m) =>
Tensor device dtype shape -> Tensor device dtype shape'
transpose @1 @2 Tensor device dtype '[batchSize, numHeads, seqLen', headDim]
q) forall (shape'' :: [Nat]) (shape :: [Nat]) (shape' :: [Nat])
(dtype :: DType) (device :: (DeviceType, Nat)).
(shape'' ~ MatMul shape shape', MatMulDTypeIsValid device dtype) =>
Tensor device dtype shape
-> Tensor device dtype shape' -> Tensor device dtype shape''
`matmul` (forall (n :: Nat) (m :: Nat) (shape :: [Nat]) (shape' :: [Nat])
(dtype :: DType) (device :: (DeviceType, Nat)).
(KnownNat n, KnownNat m, shape' ~ Transpose shape n m) =>
Tensor device dtype shape -> Tensor device dtype shape'
transpose @2 @3 Tensor device dtype '[batchSize, seqLen', seqLen, headDim]
kr))
in Tensor
device
dtype
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
weights'
_maskAttention :: Tensor
device
dtype
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
-> Tensor
device
dtype
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
_maskAttention Tensor
device
dtype
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
attentionWeights =
case Maybe (Tensor device dtype '[batchSize, seqLen', seqLen])
attentionMask of
Maybe (Tensor device dtype '[batchSize, seqLen', seqLen])
Nothing -> Tensor
device
dtype
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
attentionWeights
Just Tensor device dtype '[batchSize, seqLen', seqLen]
am -> Tensor
device
dtype
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
attentionWeights forall (shape'' :: [Nat]) (shape :: [Nat]) (shape' :: [Nat])
(dtype :: DType) (dtype' :: DType) (dtype'' :: DType)
(device :: (DeviceType, Nat)).
(dtype'' ~ DTypePromotion dtype dtype',
shape'' ~ Broadcast shape shape',
BasicArithmeticDTypeIsValid device dtype,
BasicArithmeticDTypeIsValid device dtype',
BasicArithmeticDTypeIsValid device dtype'') =>
Tensor device dtype shape
-> Tensor device dtype' shape' -> Tensor device dtype'' shape''
`add` forall (dim :: Nat) (shape :: [Nat]) (shape' :: [Nat])
(dtype :: DType) (device :: (DeviceType, Nat)).
(KnownNat dim, shape' ~ Unsqueeze shape dim) =>
Tensor device dtype shape -> Tensor device dtype shape'
unsqueeze @1 Tensor device dtype '[batchSize, seqLen', seqLen]
am
_maskKeyPaddings :: Tensor
device
dtype
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
-> Tensor
device
dtype
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
_maskKeyPaddings Tensor
device
dtype
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
attentionWeights =
case Maybe (Tensor device 'Bool '[batchSize, seqLen])
keyPaddingMask of
Maybe (Tensor device 'Bool '[batchSize, seqLen])
Nothing -> Tensor
device
dtype
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
attentionWeights
Just Tensor device 'Bool '[batchSize, seqLen]
kpm ->
let keyPaddingMask' :: Tensor
device
'Bool
(UnsqueezeCheck
'[batchSize, 1, seqLen]
2
(UnsqueezeImpl '[batchSize, 1, seqLen] 2))
keyPaddingMask' = forall (dim :: Nat) (shape :: [Nat]) (shape' :: [Nat])
(dtype :: DType) (device :: (DeviceType, Nat)).
(KnownNat dim, shape' ~ Unsqueeze shape dim) =>
Tensor device dtype shape -> Tensor device dtype shape'
unsqueeze @2 forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (dim :: Nat) (shape :: [Nat]) (shape' :: [Nat])
(dtype :: DType) (device :: (DeviceType, Nat)).
(KnownNat dim, shape' ~ Unsqueeze shape dim) =>
Tensor device dtype shape -> Tensor device dtype shape'
unsqueeze @1 forall a b. (a -> b) -> a -> b
$ Tensor device 'Bool '[batchSize, seqLen]
kpm
in forall a (shape :: [Nat]) (shape' :: [Nat]) (shape'' :: [Nat])
(dtype :: DType) (device :: (DeviceType, Nat)).
(Scalar a, shape'' ~ Broadcast shape shape') =>
Tensor device 'Bool shape'
-> a -> Tensor device dtype shape -> Tensor device dtype shape''
maskedFill Tensor
device
'Bool
(UnsqueezeCheck
'[batchSize, 1, seqLen]
2
(UnsqueezeImpl '[batchSize, 1, seqLen] 2))
keyPaddingMask' (-Double
1 forall a. Fractional a => a -> a -> a
/ Double
0 :: Double) Tensor
device
dtype
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
attentionWeights
_attention :: Tensor
device
dtype
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
-> Tensor device dtype '[batchSize, seqLen', embedDim]
_attention Tensor
device
dtype
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
attentionWeights =
let v :: Tensor device dtype '[batchSize, numHeads, seqLen, headDim]
v = forall (seqLen'' :: Nat).
KnownNat seqLen'' =>
Tensor device dtype '[batchSize, seqLen'', embedDim]
-> Tensor device dtype '[batchSize, numHeads, seqLen'', headDim]
reshape' forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall f a b. HasForward f a b => f -> a -> b
forward Linear vEmbedDim embedDim dtype device
mhaVInProj forall a b. (a -> b) -> a -> b
$ Tensor device dtype '[batchSize, seqLen, vEmbedDim]
value
attention :: Tensor
device
dtype
(SetValue
(SetValue
(CheckMatMul
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[batchSize, numHeads, seqLen, headDim]
(ComputeMatMul
(ReverseImpl
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[])
'[headDim, seqLen, numHeads, batchSize]))
1
(GetValue
(CheckMatMul
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[batchSize, numHeads, seqLen, headDim]
(ComputeMatMul
(ReverseImpl
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[])
'[headDim, seqLen, numHeads, batchSize]))
2))
2
(GetValue
(CheckMatMul
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[batchSize, numHeads, seqLen, headDim]
(ComputeMatMul
(ReverseImpl
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[])
'[headDim, seqLen, numHeads, batchSize]))
1))
attention = forall (n :: Nat) (m :: Nat) (shape :: [Nat]) (shape' :: [Nat])
(dtype :: DType) (device :: (DeviceType, Nat)).
(KnownNat n, KnownNat m, shape' ~ Transpose shape n m) =>
Tensor device dtype shape -> Tensor device dtype shape'
transpose @1 @2 forall a b. (a -> b) -> a -> b
$ forall (shape'' :: [Nat]) (shape :: [Nat]) (shape' :: [Nat])
(dtype :: DType) (device :: (DeviceType, Nat)).
(shape'' ~ MatMul shape shape', MatMulDTypeIsValid device dtype) =>
Tensor device dtype shape
-> Tensor device dtype shape' -> Tensor device dtype shape''
matmul Tensor
device
dtype
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
attentionWeights Tensor device dtype '[batchSize, numHeads, seqLen, headDim]
v
attention' :: Tensor
device
dtype
(SetValue
(SetValue
(CheckMatMul
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[batchSize, numHeads, seqLen, headDim]
(ComputeMatMul
(ReverseImpl
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[])
'[headDim, seqLen, numHeads, batchSize]))
1
(GetValue
(CheckMatMul
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[batchSize, numHeads, seqLen, headDim]
(ComputeMatMul
(ReverseImpl
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[])
'[headDim, seqLen, numHeads, batchSize]))
2))
2
(GetValue
(CheckMatMul
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[batchSize, numHeads, seqLen, headDim]
(ComputeMatMul
(ReverseImpl
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[])
'[headDim, seqLen, numHeads, batchSize]))
1))
attention' = case Maybe (Tensor device dtype '[batchSize, seqLen', seqLen, headDim])
valueRelations of
Maybe (Tensor device dtype '[batchSize, seqLen', seqLen, headDim])
Nothing -> Tensor
device
dtype
(SetValue
(SetValue
(CheckMatMul
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[batchSize, numHeads, seqLen, headDim]
(ComputeMatMul
(ReverseImpl
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[])
'[headDim, seqLen, numHeads, batchSize]))
1
(GetValue
(CheckMatMul
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[batchSize, numHeads, seqLen, headDim]
(ComputeMatMul
(ReverseImpl
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[])
'[headDim, seqLen, numHeads, batchSize]))
2))
2
(GetValue
(CheckMatMul
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[batchSize, numHeads, seqLen, headDim]
(ComputeMatMul
(ReverseImpl
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[])
'[headDim, seqLen, numHeads, batchSize]))
1))
attention
Just Tensor device dtype '[batchSize, seqLen', seqLen, headDim]
vr -> Tensor
device
dtype
(SetValue
(SetValue
(CheckMatMul
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[batchSize, numHeads, seqLen, headDim]
(ComputeMatMul
(ReverseImpl
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[])
'[headDim, seqLen, numHeads, batchSize]))
1
(GetValue
(CheckMatMul
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[batchSize, numHeads, seqLen, headDim]
(ComputeMatMul
(ReverseImpl
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[])
'[headDim, seqLen, numHeads, batchSize]))
2))
2
(GetValue
(CheckMatMul
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[batchSize, numHeads, seqLen, headDim]
(ComputeMatMul
(ReverseImpl
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[])
'[headDim, seqLen, numHeads, batchSize]))
1))
attention forall (shape'' :: [Nat]) (shape :: [Nat]) (shape' :: [Nat])
(dtype :: DType) (dtype' :: DType) (dtype'' :: DType)
(device :: (DeviceType, Nat)).
(dtype'' ~ DTypePromotion dtype dtype',
shape'' ~ Broadcast shape shape',
BasicArithmeticDTypeIsValid device dtype,
BasicArithmeticDTypeIsValid device dtype',
BasicArithmeticDTypeIsValid device dtype'') =>
Tensor device dtype shape
-> Tensor device dtype' shape' -> Tensor device dtype'' shape''
`add` (forall (shape'' :: [Nat]) (shape :: [Nat]) (shape' :: [Nat])
(dtype :: DType) (device :: (DeviceType, Nat)).
(shape'' ~ MatMul shape shape', MatMulDTypeIsValid device dtype) =>
Tensor device dtype shape
-> Tensor device dtype shape' -> Tensor device dtype shape''
matmul (forall (n :: Nat) (m :: Nat) (shape :: [Nat]) (shape' :: [Nat])
(dtype :: DType) (device :: (DeviceType, Nat)).
(KnownNat n, KnownNat m, shape' ~ Transpose shape n m) =>
Tensor device dtype shape -> Tensor device dtype shape'
transpose @1 @2 Tensor
device
dtype
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
attentionWeights) Tensor device dtype '[batchSize, seqLen', seqLen, headDim]
vr)
in forall f a b. HasForward f a b => f -> a -> b
forward Linear embedDim embedDim dtype device
mhaOutProj forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (shape' :: [Nat]) (shape :: [Nat]) (dtype :: DType)
(device :: (DeviceType, Nat)).
(KnownShape shape', Numel shape ~ Numel shape') =>
Tensor device dtype shape -> Tensor device dtype shape'
reshape @'[batchSize, seqLen', embedDim] forall a b. (a -> b) -> a -> b
$ Tensor
device
dtype
(SetValue
(SetValue
(CheckMatMul
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[batchSize, numHeads, seqLen, headDim]
(ComputeMatMul
(ReverseImpl
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[])
'[headDim, seqLen, numHeads, batchSize]))
1
(GetValue
(CheckMatMul
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[batchSize, numHeads, seqLen, headDim]
(ComputeMatMul
(ReverseImpl
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[])
'[headDim, seqLen, numHeads, batchSize]))
2))
2
(GetValue
(CheckMatMul
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[batchSize, numHeads, seqLen, headDim]
(ComputeMatMul
(ReverseImpl
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[])
'[headDim, seqLen, numHeads, batchSize]))
1))
attention'
averageOverHeads :: Tensor
device
dtype
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
-> Tensor device dtype '[batchSize, seqLen', seqLen]
averageOverHeads =
let numHeads' :: Int
numHeads' = forall (n :: Nat). KnownNat n => Int
natValI @numHeads
in forall a (shape :: [Nat]) (dtype :: DType)
(device :: (DeviceType, Nat)).
Scalar a =>
a -> Tensor device dtype shape -> Tensor device dtype shape
divScalar Int
numHeads' forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (d :: Nat) (shape :: [Nat]) (shape' :: [Nat])
(dtype :: DType) (dtype' :: DType) (device :: (DeviceType, Nat)).
(KnownNat d, shape' ~ DropValue shape d,
SumDTypeIsValid device dtype, dtype' ~ SumDType dtype) =>
Tensor device dtype shape -> Tensor device dtype' shape'
sumDim @1
reshape' ::
forall seqLen''.
KnownNat seqLen'' =>
Tensor device dtype '[batchSize, seqLen'', embedDim] ->
Tensor device dtype '[batchSize, numHeads, seqLen'', headDim]
reshape' :: forall (seqLen'' :: Nat).
KnownNat seqLen'' =>
Tensor device dtype '[batchSize, seqLen'', embedDim]
-> Tensor device dtype '[batchSize, numHeads, seqLen'', headDim]
reshape' = forall (n :: Nat) (m :: Nat) (shape :: [Nat]) (shape' :: [Nat])
(dtype :: DType) (device :: (DeviceType, Nat)).
(KnownNat n, KnownNat m, shape' ~ Transpose shape n m) =>
Tensor device dtype shape -> Tensor device dtype shape'
transpose @1 @2 forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (shape' :: [Nat]) (shape :: [Nat]) (dtype :: DType)
(device :: (DeviceType, Nat)).
(KnownShape shape', Numel shape ~ Numel shape') =>
Tensor device dtype shape -> Tensor device dtype shape'
reshape @'[batchSize, seqLen'', numHeads, headDim]
-- | Sample a freshly initialised 'MultiheadAttention' from its spec.
--
-- The four projection layers are each drawn from a default 'LinearSpec'.
-- Their sizes come from the type level: in constructor order, their input
-- widths are @embedDim@, @kEmbedDim@, @vEmbedDim@ and @embedDim@, and every
-- one of them maps to @embedDim@. The dropout layer is sampled from the
-- 'DropoutSpec' stored in the spec.
instance
  ( All KnownNat '[embedDim, kEmbedDim, vEmbedDim, numHeads],
    KnownDType dtype,
    KnownDevice device,
    RandDTypeIsValid device dtype
  ) =>
  A.Randomizable
    (MultiheadAttentionSpec embedDim kEmbedDim vEmbedDim numHeads dtype device)
    (MultiheadAttention embedDim kEmbedDim vEmbedDim numHeads dtype device)
  where
  sample (MultiheadAttentionSpec mhaDropoutSpec) =
    MultiheadAttention
      <$> A.sample LinearSpec -- embedDim  -> embedDim
      <*> A.sample LinearSpec -- kEmbedDim -> embedDim
      <*> A.sample LinearSpec -- vEmbedDim -> embedDim
      <*> A.sample LinearSpec -- embedDim  -> embedDim
      <*> A.sample mhaDropoutSpec
-- | Hyper-parameters needed to sample a 'TransformerMLP'.
--
-- Only the dropout and layer-norm settings are stored here. The widths of
-- the two linear layers are fixed at the type level by @embedDim@ and
-- @ffnDim@.
data
  TransformerMLPSpec
    (embedDim :: Nat)
    (ffnDim :: Nat)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  where
  TransformerMLPSpec ::
    forall embedDim ffnDim dtype device.
    { -- | spec of the dropout applied right after the ReLU activation
      dropout0Spec :: DropoutSpec,
      -- | spec of the dropout applied after the second linear layer
      dropout1Spec :: DropoutSpec,
      -- | epsilon passed to the layer norm's 'LayerNormSpec'
      epsSpec :: Double
    } ->
    TransformerMLPSpec embedDim ffnDim dtype device
  deriving (Show, Eq)
-- | Feed-forward sub-layer of a transformer block.
--
-- 'transformerMLP' runs
-- @linear0 -> relu -> dropout0 -> linear1 -> dropout1@ inside a
-- 'residual' connection and feeds the result to 'ln'.
data
  TransformerMLP
    (embedDim :: Nat)
    (ffnDim :: Nat)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  where
  TransformerMLP ::
    forall embedDim ffnDim dtype device.
    { -- | first fully connected layer, @embedDim -> ffnDim@
      linear0 :: Linear embedDim ffnDim dtype device,
      -- | second fully connected layer, @ffnDim -> embedDim@
      linear1 :: Linear ffnDim embedDim dtype device,
      -- | dropout applied after the ReLU activation
      dropout0 :: Dropout,
      -- | dropout applied after 'linear1'
      dropout1 :: Dropout,
      -- | layer norm with normalised shape @'[embedDim]@
      ln :: LayerNorm '[embedDim] dtype device
    } ->
    TransformerMLP embedDim ffnDim dtype device
  deriving (Show, Generic, Parameterized)
-- | Run the feed-forward sub-layer of a transformer block.
--
-- The inner network is
-- @dropout1 . linear1 . dropout0 . relu . linear0@.
-- Its output is combined with the input through 'residual', and 'ln' is
-- applied to the combined tensor. The @train@ flag is forwarded to both
-- 'dropoutForward' calls (presumably switching the random dropout masks
-- on and off).
transformerMLP ::
  forall embedDim ffnDim seqLen batchSize dtype device.
  ( BasicArithmeticDTypeIsValid device dtype,
    StandardFloatingPointDTypeValidation device dtype,
    KnownNat embedDim,
    IsSuffixOf '[embedDim] '[seqLen, batchSize, embedDim]
  ) =>
  -- | parameters of the feed-forward sub-layer
  TransformerMLP embedDim ffnDim dtype device ->
  -- | training-mode flag handed to the dropout layers
  Bool ->
  -- | input of shape @[seqLen, batchSize, embedDim]@
  Tensor device dtype '[seqLen, batchSize, embedDim] ->
  -- | output, same shape as the input
  IO (Tensor device dtype '[seqLen, batchSize, embedDim])
transformerMLP TransformerMLP {..} train input =
  residual feedForward (pure . forward ln) input
  where
    -- Inner network: expand to ffnDim, ReLU, dropout, project back to
    -- embedDim, dropout. (The former trailing @=<< pure x@ was just
    -- function application and has been dropped.)
    feedForward x = do
      hidden <- dropoutForward dropout0 train . relu . forward linear0 $ x
      dropoutForward dropout1 train . forward linear1 $ hidden
-- | Sample a freshly initialised 'TransformerMLP' from its spec.
--
-- Both linear layers are drawn from a default 'LinearSpec'; their sizes
-- come from the type-level @embedDim@ and @ffnDim@. The two dropouts are
-- sampled from 'dropout0Spec' and 'dropout1Spec'. The layer norm is
-- sampled from a 'LayerNormSpec' built with 'epsSpec'.
instance
  ( All KnownNat '[embedDim, ffnDim],
    KnownDType dtype,
    KnownDevice device,
    RandDTypeIsValid device dtype
  ) =>
  A.Randomizable
    (TransformerMLPSpec embedDim ffnDim dtype device)
    (TransformerMLP embedDim ffnDim dtype device)
  where
  sample TransformerMLPSpec {..} =
    TransformerMLP
      <$> A.sample LinearSpec -- linear0: embedDim -> ffnDim
      <*> A.sample LinearSpec -- linear1: ffnDim -> embedDim
      <*> A.sample dropout0Spec
      <*> A.sample dropout1Spec
      <*> A.sample (LayerNormSpec epsSpec)
-- | Hyper-parameters from which a 'TransformerLayer' is sampled with
-- 'A.sample'.
--
-- The 'A.Randomizable' instance for this type samples each field into the
-- corresponding sub-module of 'TransformerLayer': @mhaSpec@ into the
-- attention module, @attnDropoutSpec@ into the attention dropout,
-- @epsSpec'@ into a 'LayerNormSpec', and @mlpSpec@ into the feed-forward
-- 'TransformerMLP'.
data
  TransformerLayerSpec
    (embedDim :: Nat)
    (kEmbedDim :: Nat)
    (vEmbedDim :: Nat)
    (numHeads :: Nat)
    (ffnDim :: Nat)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  where
  TransformerLayerSpec ::
    forall embedDim kEmbedDim vEmbedDim numHeads ffnDim dtype device.
    { -- spec for the multi-head attention sub-layer
      mhaSpec :: MultiheadAttentionSpec embedDim kEmbedDim vEmbedDim numHeads dtype device,
      -- spec for the dropout applied to the attention output
      attnDropoutSpec :: DropoutSpec,
      -- epsilon handed to the layer norm over '[embedDim]
      epsSpec' :: Double,
      -- spec for the feed-forward block
      mlpSpec :: TransformerMLPSpec embedDim ffnDim dtype device
    } ->
    TransformerLayerSpec embedDim kEmbedDim vEmbedDim numHeads ffnDim dtype device
  deriving (Show, Eq)
-- | Parameters of a single transformer layer, run by 'transformerLayer'.
--
-- The forward pass feeds the query through the multi-head attention
-- module, applies the attention dropout, and combines the result with the
-- query via 'residual' into the layer norm. The 'TransformerMLP' block is
-- applied last.
data
  TransformerLayer
    (embedDim :: Nat)
    (kEmbedDim :: Nat)
    (vEmbedDim :: Nat)
    (numHeads :: Nat)
    (ffnDim :: Nat)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  where
  TransformerLayer ::
    forall embedDim kEmbedDim vEmbedDim numHeads ffnDim dtype device.
    { -- multi-head attention sub-layer
      transformerLayer_mha :: MultiheadAttention embedDim kEmbedDim vEmbedDim numHeads dtype device,
      -- dropout applied to the attention output
      transformerLayer_attnDropout :: Dropout,
      -- layer norm over the last ('[embedDim]) dimension
      transformerLayer_ln :: LayerNorm '[embedDim] dtype device,
      -- feed-forward block applied after the layer norm
      transformerLayer_mlp :: TransformerMLP embedDim ffnDim dtype device
    } ->
    TransformerLayer embedDim kEmbedDim vEmbedDim numHeads ffnDim dtype device
  deriving (Show, Generic, Parameterized)
-- | Forward pass of a 'TransformerLayer'.
--
-- @query@ is attended over @key@ / @value@ by 'multiheadAttention'. Only
-- the first component of its result (the @'[batchSize, seqLen', embedDim]@
-- tensor) is kept, and it goes through the attention dropout. 'residual'
-- then combines that output with the original @query@ and hands the result
-- to the layer norm. The normalised tensor is passed through the
-- 'TransformerMLP' block.
--
-- NOTE(review): the promoted/broadcast result type of 'residual' indicates
-- an element-wise add of @query@ and the attention output; confirm against
-- its definition.
transformerLayer ::
  forall (numHeads :: Nat) (ffnDim :: Nat) (embedDim :: Nat) (kEmbedDim :: Nat) (vEmbedDim :: Nat) (headDim :: Nat) (seqLen :: Nat) (seqLen' :: Nat) (batchSize :: Nat) dtype device.
  ( 1 <= numHeads,
    embedDim ~ (headDim * numHeads),
    All KnownNat '[embedDim, kEmbedDim, vEmbedDim, numHeads, seqLen, seqLen', batchSize, headDim],
    IsSuffixOf '[embedDim] '[batchSize, seqLen', embedDim],
    KnownDType dtype,
    dtype ~ SumDType dtype,
    StandardFloatingPointDTypeValidation device dtype,
    MatMulDTypeIsValid device dtype,
    BasicArithmeticDTypeIsValid device dtype,
    SumDTypeIsValid device dtype,
    KnownDevice device
  ) =>
  -- | layer parameters
  TransformerLayer embedDim kEmbedDim vEmbedDim numHeads ffnDim dtype device ->
  -- | training flag, forwarded to the attention, the dropout and the MLP
  Bool ->
  -- | optional attention mask, passed through to 'multiheadAttention'
  Maybe (Tensor device dtype '[batchSize, seqLen', seqLen]) ->
  -- | optional key padding mask, passed through to 'multiheadAttention'
  Maybe (Tensor device 'D.Bool '[batchSize, seqLen]) ->
  -- | optional key relations, passed through to 'multiheadAttention'
  Maybe (Tensor device dtype '[batchSize, seqLen', seqLen, headDim]) ->
  -- | optional value relations, passed through to 'multiheadAttention'
  Maybe (Tensor device dtype '[batchSize, seqLen', seqLen, headDim]) ->
  -- | query (also the input of the residual connection)
  Tensor device dtype '[batchSize, seqLen', embedDim] ->
  -- | key
  Tensor device dtype '[batchSize, seqLen, kEmbedDim] ->
  -- | value
  Tensor device dtype '[batchSize, seqLen, vEmbedDim] ->
  IO (Tensor device dtype '[batchSize, seqLen', embedDim])
transformerLayer TransformerLayer {..} train attentionMask keyPaddingMask keyRelations valueRelations query key value =
  let -- Attention sub-layer: keep only the output tensor (the first
      -- component of 'multiheadAttention''s result) and apply dropout to it.
      f query' =
        multiheadAttention transformerLayer_mha train attentionMask keyPaddingMask keyRelations valueRelations query' key value
          >>= dropoutForward transformerLayer_attnDropout train . fst
   in residual f (pure . forward transformerLayer_ln) query
        >>= transformerMLP transformerLayer_mlp train
-- | Samples a 'TransformerLayer' by sampling each sub-module from the
-- matching field of the 'TransformerLayerSpec'. The layer norm is built
-- from @LayerNormSpec epsSpec'@.
instance
  ( All KnownNat '[embedDim, kEmbedDim, vEmbedDim, numHeads, ffnDim],
    KnownDType dtype,
    KnownDevice device,
    RandDTypeIsValid device dtype
  ) =>
  A.Randomizable
    (TransformerLayerSpec embedDim kEmbedDim vEmbedDim numHeads ffnDim dtype device)
    (TransformerLayer embedDim kEmbedDim vEmbedDim numHeads ffnDim dtype device)
  where
  -- No method signature here: InstanceSigs is not enabled in this module.
  sample TransformerLayerSpec {..} =
    TransformerLayer
      <$> A.sample mhaSpec
      <*> A.sample attnDropoutSpec
      <*> A.sample (LayerNormSpec epsSpec')
      <*> A.sample mlpSpec
-- | Hyper-parameters from which a 'TransformerLM' is sampled.
--
-- The layer spec fixes the query, key and value embedding dimensions of the
-- attention layers to @embedDim@.
data
  TransformerLMSpec
    (numAttnLayers :: Nat)
    (numHeads :: Nat)
    (ffnDim :: Nat)
    (paddingIdx :: Nat)
    (numEmbeds :: Nat)
    (embedDim :: Nat)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  where
  TransformerLMSpec ::
    forall numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim dtype device.
    { -- dropout spec; presumably sampled into the model's tDropout (confirm
      -- against the TransformerLM Randomizable instance)
      lmDropoutSpec :: DropoutSpec,
      -- a single layer spec; presumably reused for each of the numAttnLayers
      -- layers (confirm against the TransformerLM Randomizable instance)
      lmLayerSpec :: TransformerLayerSpec embedDim embedDim embedDim numHeads ffnDim dtype device
    } ->
    TransformerLMSpec numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim dtype device
  deriving (Show, Eq)
-- | Parameters of a transformer language model.
--
-- The model holds:
--
-- * a learned token embedding with padding index @paddingIdx@;
-- * a constant embedding table with 2048 rows (positional, by its name);
-- * a dropout;
-- * @numAttnLayers@ 'TransformerLayer's, held in an 'HList', whose query,
--   key and value dimensions are all @embedDim@;
-- * a linear projection from @embedDim@ to @numEmbeds@.
--
-- 'Show' is provided by a standalone deriving instance, which states the
-- required 'Show' constraint on the layer 'HList'.
data
  TransformerLM
    (numAttnLayers :: Nat)
    (numHeads :: Nat)
    (ffnDim :: Nat)
    (paddingIdx :: Nat)
    (numEmbeds :: Nat)
    (embedDim :: Nat)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  where
  TransformerLM ::
    forall numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim dtype device.
    { -- learned token embedding, numEmbeds rows of width embedDim
      tEmbedding :: Embedding ('Just paddingIdx) numEmbeds embedDim 'Learned dtype device,
      -- constant (non-learned) table of 2048 rows, no padding index
      tPosEmbedding :: Embedding 'Nothing 2048 embedDim 'Constant dtype device,
      tDropout :: Dropout,
      -- numAttnLayers copies of the same layer type
      tLayers :: HList (HReplicateR numAttnLayers (TransformerLayer embedDim embedDim embedDim numHeads ffnDim dtype device)),
      -- output projection embedDim -> numEmbeds
      tProj :: Linear embedDim numEmbeds dtype device
    } ->
    TransformerLM numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim dtype device
  deriving (Generic)
-- Standalone 'Show' instance for 'TransformerLM'. The only requirement
-- stated here is that the @numAttnLayers@-long 'HList' of
-- 'TransformerLayer's can itself be shown.
deriving instance
  Show (HList (HReplicateR numAttnLayers (TransformerLayer embedDim embedDim embedDim numHeads ffnDim dtype device))) =>
  Show (TransformerLM numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim dtype device)
-- Parameter collection for 'TransformerLM'. The instance body is empty, so
-- the class's default methods apply. The context requires that the layer
-- stack is 'Parameterized' and that its parameter list can be extended with
-- one @[numEmbeds, embedDim]@ parameter and one @[numEmbeds]@ parameter.
-- NOTE(review): these two shapes presumably belong to the output projection
-- (@Linear embedDim numEmbeds@) — confirm against the default methods.
instance
  ( layers ~ HReplicateR numAttnLayers (TransformerLayer embedDim embedDim embedDim numHeads ffnDim dtype device)
  , Parameterized (HList layers)
  , HAppendFD
      (Parameters (HList layers))
      '[Parameter device dtype '[numEmbeds, embedDim], Parameter device dtype '[numEmbeds]]
      (Parameters (HList layers) ++ '[Parameter device dtype '[numEmbeds, embedDim], Parameter device dtype '[numEmbeds]])
  ) =>
  Parameterized (TransformerLM numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim dtype device)
-- | Folding function passed to 'hfoldrM' to run a stack of
-- 'TransformerLayer's. It bundles the per-call settings that are handed,
-- unchanged, to every layer.
data FoldLayers (batchSize :: Nat) (seqLen :: Nat) (dtype :: D.DType) (device :: (D.DeviceType, Nat))
  = FoldLayers
      { flTrain :: Bool
        -- ^ Flag passed as the 'Bool' argument of 'transformerLayer';
        -- 'transformerLM' sets it to its own @train@ argument.
      , flAttentionMask :: Maybe (Tensor device dtype '[batchSize, seqLen, seqLen])
        -- ^ Optional @[batchSize, seqLen, seqLen]@ attention mask given to
        -- every layer.
      , flKeyPaddingMask :: Maybe (Tensor device 'D.Bool '[batchSize, seqLen])
        -- ^ Optional boolean @[batchSize, seqLen]@ key-padding mask given to
        -- every layer.
      }
-- One step of the layer fold. It runs the accumulated 'IO' action to get the
-- current hidden states, then passes them through @layer@. The same tensor
-- is used for the last three arguments of 'transformerLayer' (self-attention;
-- presumably query, key and value). Both optional
-- @[batchSize, seqLen, seqLen, headDim]@ arguments are 'Nothing'.
instance
  ( 1 <= numHeads
  , embedDim ~ (headDim * numHeads)
  , All KnownNat '[embedDim, numHeads, seqLen, batchSize, headDim]
  , IsSuffixOf '[embedDim] '[batchSize, seqLen, embedDim]
  , KnownDType dtype
  , StandardFloatingPointDTypeValidation device dtype
  , MatMulDTypeIsValid device dtype
  , BasicArithmeticDTypeIsValid device dtype
  , dtype ~ SumDType dtype
  , SumDTypeIsValid device dtype
  , KnownDevice device
  ) =>
  Apply'
    (FoldLayers batchSize seqLen dtype device)
    ( TransformerLayer embedDim embedDim embedDim numHeads ffnDim dtype device
    , IO (Tensor device dtype '[batchSize, seqLen, embedDim])
    )
    (IO (Tensor device dtype '[batchSize, seqLen, embedDim]))
  where
    apply' (FoldLayers isTraining attnMask padMask) (layer, pending) = do
      hidden <- pending
      transformerLayer layer isTraining attnMask padMask Nothing Nothing hidden hidden hidden
-- | Run a 'TransformerLM' over a batch of token ids.
--
-- 1. Embed the tokens with 'tEmbedding' and add position embeddings for
--    positions @0 .. seqLen - 1@, looked up in 'tPosEmbedding'.
-- 2. Apply 'dropoutForward' with the @train@ flag.
-- 3. Fold the result through 'tLayers' with 'hfoldrM'. Every layer receives
--    a causal attention mask and a key-padding mask that marks tokens equal
--    to @paddingIdx@.
-- 4. Project every position to @numEmbeds@ outputs with 'tProj'.
--
-- NOTE(review): 'tPosEmbedding' has 2048 rows, but nothing in this context
-- bounds @seqLen@. Presumably @seqLen > 2048@ fails at run time — confirm.
transformerLM ::
  forall
    numAttnLayers
    numHeads
    ffnDim
    paddingIdx
    numEmbeds
    embedDim
    seqLen
    batchSize
    dtype
    device.
  ( All KnownNat '[paddingIdx, embedDim, seqLen, batchSize],
    paddingIdx + 1 <= numEmbeds,
    1 <= seqLen,
    HFoldrM
      IO
      (FoldLayers batchSize seqLen dtype device)
      (Tensor device dtype '[batchSize, seqLen, embedDim])
      (HReplicateR numAttnLayers (TransformerLayer embedDim embedDim embedDim numHeads ffnDim dtype device))
      (Tensor device dtype '[batchSize, seqLen, embedDim]),
    BasicArithmeticDTypeIsValid device dtype,
    ComparisonDTypeIsValid device dtype,
    ComparisonDTypeIsValid device 'D.Int64,
    KnownDType dtype,
    KnownDevice device
  ) =>
  TransformerLM numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim dtype device ->
  Bool ->
  Tensor device 'D.Int64 '[batchSize, seqLen] ->
  IO (Tensor device dtype '[batchSize, seqLen, numEmbeds])
transformerLM TransformerLM {..} train xTokens = do
  embedded <- dropoutForward tDropout train (tokenEmbeddings `add` positionEmbeddings)
  hidden <- hfoldrM (FoldLayers train causalMask paddingMask) embedded tLayers
  pure (forward tProj hidden)
  where
    tokenEmbeddings :: Tensor device dtype '[batchSize, seqLen, embedDim]
    tokenEmbeddings = embed tEmbedding xTokens

    -- Position ids 0 .. seqLen - 1 are built as a float linspace and cast to
    -- Int64. They are then embedded and expanded over the batch dimension.
    positionEmbeddings :: Tensor device dtype '[batchSize, seqLen, embedDim]
    positionEmbeddings =
      expand @'[batchSize, seqLen, embedDim] True
        . embed tPosEmbedding
        . Torch.Typed.Tensor.toDType @D.Int64
        . linspace @seqLen (0 :: Int)
        $ natValI @(seqLen - 1)

    -- Boolean mask that keeps the entries above the main diagonal
    -- (triu with offset 1), with a leading singleton batch dimension.
    futureMask :: Tensor device 'D.Bool '[1, seqLen, seqLen]
    futureMask =
      unsqueeze @0
        . Torch.Typed.Tensor.toDType @D.Bool
        . triu 1
        $ ones @'[seqLen, seqLen] @D.Int8 @device

    -- Zeros everywhere except -Infinity where 'futureMask' is set.
    -- NOTE(review): presumably added to the attention scores inside
    -- 'transformerLayer' — confirm.
    causalMask :: Maybe (Tensor device dtype '[batchSize, seqLen, seqLen])
    causalMask =
      Just (maskedFill futureMask (-1 / 0 :: Double) (zeros @'[batchSize, seqLen, seqLen] @dtype @device))

    -- Scalar tensor holding the padding token id.
    paddingIdxTensor :: Tensor device 'D.Int64 '[]
    paddingIdxTensor = fromInteger (natVal (Proxy @paddingIdx))

    -- True wherever a token id equals @paddingIdx@.
    paddingMask :: Maybe (Tensor device 'D.Bool '[batchSize, seqLen])
    paddingMask = Just (xTokens ==. paddingIdxTensor)
-- 'HasForward' for 'TransformerLM', delegating to 'transformerLM'.
-- The context repeats the constraints of 'transformerLM'.
instance
  ( All KnownNat '[paddingIdx, embedDim, seqLen, batchSize]
  , paddingIdx + 1 <= numEmbeds
  , 1 <= seqLen
  , HFoldrM
      IO
      (FoldLayers batchSize seqLen dtype device)
      (Tensor device dtype '[batchSize, seqLen, embedDim])
      (HReplicateR numAttnLayers (TransformerLayer embedDim embedDim embedDim numHeads ffnDim dtype device))
      (Tensor device dtype '[batchSize, seqLen, embedDim])
  , BasicArithmeticDTypeIsValid device dtype
  , ComparisonDTypeIsValid device dtype
  , ComparisonDTypeIsValid device 'D.Int64
  , KnownDType dtype
  , KnownDevice device
  ) =>
  HasForward
    (TransformerLM numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim dtype device)
    (Tensor device 'D.Int64 '[batchSize, seqLen])
    (Tensor device dtype '[batchSize, seqLen, numEmbeds])
  where
    -- Pure pass with the train flag set to False. 'unsafePerformIO' exposes
    -- the 'IO' result as a pure value. NOTE(review): this assumes the
    -- train = False path (dropout included) has no observable effects —
    -- confirm.
    forward lm tokens = unsafePerformIO (transformerLM lm False tokens)

    -- Pass with the train flag set to True; the result stays in 'IO'.
    forwardStoch lm tokens = transformerLM lm True tokens
-- | Fixed sinusoidal table with @numEmbeds@ rows (positions) and @embedDim@
-- columns.
--
-- With @h = embedDim / 2@ and frequency index @i@ in @0 .. h - 1@:
--
-- > table[p, 2i]     = sin (p * exp (-(log 10000 / h) * i))
-- > table[p, 2i + 1] = cos (p * exp (-(log 10000 / h) * i))
--
-- The sine and cosine values are interleaved by stacking them along a new
-- last axis and reshaping to @[numEmbeds, embedDim]@.
sinusoidal ::
  forall numEmbeds embedDim device.
  ( All KnownNat '[numEmbeds, embedDim],
    1 <= numEmbeds,
    1 <= Div embedDim 2,
    (Div embedDim 2 * 2) ~ embedDim,
    StandardFloatingPointDTypeValidation device 'D.Float,
    BasicArithmeticDTypeIsValid device 'D.Float,
    KnownDevice device
  ) =>
  Tensor device 'D.Float '[numEmbeds, embedDim]
sinusoidal = reshape interleaved
  where
    -- h = embedDim / 2, as a Double for the frequency exponent.
    halfDim :: Double
    halfDim = fromInteger (natVal (Proxy @(Div embedDim 2)))

    -- Column vector of positions 0 .. numEmbeds - 1.
    positionCol :: Tensor device 'D.Float '[numEmbeds, 1]
    positionCol = unsqueeze @1 (linspace @numEmbeds (0 :: Int) (natValI @(numEmbeds - 1)))

    -- Frequencies exp (-(log 10000 / h) * i) for i = 0 .. h - 1.
    frequencies :: Tensor device 'D.Float '[Div embedDim 2]
    frequencies =
      exp . mulScalar (negate (log 10000 / halfDim)) $
        linspace @(Div embedDim 2) (0 :: Int) (natValI @(Div embedDim 2 - 1))

    -- Broadcast product: angles[p, i] = p * frequencies[i].
    angles :: Tensor device 'D.Float '[numEmbeds, Div embedDim 2]
    angles = mul positionCol frequencies

    -- [..., 0] holds the sines and [..., 1] the cosines.
    interleaved :: Tensor device 'D.Float '[numEmbeds, Div embedDim 2, 2]
    interleaved = stack @2 (sin angles :. cos angles :. HNil)
-- Random initialisation of a 'TransformerLM' from a 'TransformerLMSpec'.
--
-- 'sample' builds the model applicatively, in constructor-argument order:
--
--   * token embedding: sampled from 'LearnedEmbeddingWithRandomInitSpec'
--     with padding index @'Just paddingIdx@;
--   * position embedding: a constant ('ConstEmbeddingSpec') table of type
--     @Embedding 'Nothing 2048 embedDim 'Constant dtype device@, filled
--     from 'sinusoidal' and converted to @dtype@ with 'toDType';
--   * dropout: sampled from the spec's 'lmDropoutSpec';
--   * layers: @numAttnLayers@ copies of 'lmLayerSpec' (via 'hreplicate'),
--     sampled through the 'A.Randomizable' instance for the 'HList' that
--     the context requires;
--   * output projection: sampled from 'LinearSpec'.
--
-- NOTE(review): the text below has rendered type information interleaved
-- with the code and is not compilable Haskell as it stands; restore it
-- from the upstream source.
instance
( paddingIdx <= numEmbeds,
1 <= numEmbeds - paddingIdx,
1 <= Div embedDim 2,
(((numEmbeds - paddingIdx) - 1) + (1 + paddingIdx)) ~ numEmbeds,
(Div embedDim 2 * 2) ~ embedDim,
All KnownNat '[ffnDim, paddingIdx, numEmbeds, embedDim],
HReplicate numAttnLayers (TransformerLayerSpec embedDim embedDim embedDim numHeads ffnDim dtype device),
A.Randomizable
(HList (HReplicateR numAttnLayers (TransformerLayerSpec embedDim embedDim embedDim numHeads ffnDim dtype device)))
(HList (HReplicateR numAttnLayers (TransformerLayer embedDim embedDim embedDim numHeads ffnDim dtype device))),
KnownDType dtype,
RandDTypeIsValid device dtype,
StandardFloatingPointDTypeValidation device 'D.Float,
BasicArithmeticDTypeIsValid device 'D.Float,
KnownDevice device
) =>
A.Randomizable
(TransformerLMSpec numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim dtype device)
(TransformerLM numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim dtype device)
where
sample :: TransformerLMSpec
numAttnLayers
numHeads
ffnDim
paddingIdx
numEmbeds
embedDim
dtype
device
-> IO
(TransformerLM
numAttnLayers
numHeads
ffnDim
paddingIdx
numEmbeds
embedDim
dtype
device)
sample TransformerLMSpec {DropoutSpec
TransformerLayerSpec
embedDim embedDim embedDim numHeads ffnDim dtype device
lmLayerSpec :: TransformerLayerSpec
embedDim embedDim embedDim numHeads ffnDim dtype device
lmDropoutSpec :: DropoutSpec
lmLayerSpec :: forall (numAttnLayers :: Nat) (numHeads :: Nat) (ffnDim :: Nat)
(paddingIdx :: Nat) (numEmbeds :: Nat) (embedDim :: Nat)
(dtype :: DType) (device :: (DeviceType, Nat)).
TransformerLMSpec
numAttnLayers
numHeads
ffnDim
paddingIdx
numEmbeds
embedDim
dtype
device
-> TransformerLayerSpec
embedDim embedDim embedDim numHeads ffnDim dtype device
lmDropoutSpec :: forall (numAttnLayers :: Nat) (numHeads :: Nat) (ffnDim :: Nat)
(paddingIdx :: Nat) (numEmbeds :: Nat) (embedDim :: Nat)
(dtype :: DType) (device :: (DeviceType, Nat)).
TransformerLMSpec
numAttnLayers
numHeads
ffnDim
paddingIdx
numEmbeds
embedDim
dtype
device
-> DropoutSpec
..} =
forall (numAttnLayers :: Nat) (numHeads :: Nat) (ffnDim :: Nat)
(paddingIdx :: Nat) (numEmbeds :: Nat) (embedDim :: Nat)
(dtype :: DType) (device :: (DeviceType, Nat)).
Embedding
('Just paddingIdx) numEmbeds embedDim 'Learned dtype device
-> Embedding 'Nothing 2048 embedDim 'Constant dtype device
-> Dropout
-> HList
(HReplicateR
numAttnLayers
(TransformerLayer
embedDim embedDim embedDim numHeads ffnDim dtype device))
-> Linear embedDim numEmbeds dtype device
-> TransformerLM
numAttnLayers
numHeads
ffnDim
paddingIdx
numEmbeds
embedDim
dtype
device
TransformerLM
forall (f :: Type -> Type) a b. Functor f => (a -> b) -> f a -> f b
<$> forall spec f. Randomizable spec f => spec -> IO f
A.sample (forall (paddingIdx :: Maybe Nat) (numEmbeds :: Nat)
(embedSize :: Nat) (dtype :: DType) (device :: (DeviceType, Nat)).
EmbeddingSpec paddingIdx numEmbeds embedSize 'Learned dtype device
LearnedEmbeddingWithRandomInitSpec @('Just paddingIdx))
forall (f :: Type -> Type) a b.
Applicative f =>
f (a -> b) -> f a -> f b
<*> forall spec f. Randomizable spec f => spec -> IO f
A.sample (forall (paddingIdx :: Maybe Nat) (numEmbeds :: Nat)
(embedSize :: Nat) (dtype :: DType) (device :: (DeviceType, Nat)).
Tensor device dtype '[numEmbeds, embedSize]
-> EmbeddingSpec
paddingIdx numEmbeds embedSize 'Constant dtype device
ConstEmbeddingSpec @'Nothing (forall (dtype' :: DType) (dtype :: DType)
(device :: (DeviceType, Nat)) (shape :: [Nat]) t t'.
(KnownDType dtype', IsUnnamed t device dtype shape, Unnamed t',
t' ~ ReplaceDType'' t dtype') =>
t -> t'
Torch.Typed.Tensor.toDType forall (numEmbeds :: Nat) (embedDim :: Nat)
(device :: (DeviceType, Nat)).
(All KnownNat '[numEmbeds, embedDim], 1 <= numEmbeds,
1 <= Div embedDim 2, (Div embedDim 2 * 2) ~ embedDim,
StandardFloatingPointDTypeValidation device 'Float,
BasicArithmeticDTypeIsValid device 'Float, KnownDevice device) =>
Tensor device 'Float '[numEmbeds, embedDim]
sinusoidal))
forall (f :: Type -> Type) a b.
Applicative f =>
f (a -> b) -> f a -> f b
<*> forall spec f. Randomizable spec f => spec -> IO f
A.sample DropoutSpec
lmDropoutSpec
forall (f :: Type -> Type) a b.
Applicative f =>
f (a -> b) -> f a -> f b
<*> forall spec f. Randomizable spec f => spec -> IO f
A.sample (forall (n :: Nat) e. HReplicate n e => e -> HList (HReplicateR n e)
hreplicate @numAttnLayers TransformerLayerSpec
embedDim embedDim embedDim numHeads ffnDim dtype device
lmLayerSpec)
forall (f :: Type -> Type) a b.
Applicative f =>
f (a -> b) -> f a -> f b
<*> forall spec f. Randomizable spec f => spec -> IO f
A.sample forall (inputFeatures :: Nat) (outputFeatures :: Nat)
(dtype :: DType) (device :: (DeviceType, Nat)).
LinearSpec inputFeatures outputFeatures dtype device
LinearSpec