{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE NoStarIsType #-}
{-# OPTIONS_GHC -fconstraint-solver-iterations=0 #-}

module Torch.Typed.NN.Transformer where

import Control.Monad
import Data.Proxy
import GHC.Generics
import GHC.TypeLits
import System.IO.Unsafe (unsafePerformIO)
import qualified Torch.DType as D
import qualified Torch.Device as D
import Torch.HList
import Torch.NN (HasForward (..))
import qualified Torch.NN as A
import Torch.Typed.Auxiliary
import Torch.Typed.Factories
import Torch.Typed.Functional hiding (linear, log)
import Torch.Typed.NN.Dropout
import Torch.Typed.NN.Linear
import Torch.Typed.NN.Normalization
import Torch.Typed.NN.Sparse
import Torch.Typed.Parameter
import Torch.Typed.Tensor
import Prelude hiding (cos, exp, sin)

-- | Residual (skip) connection combinator.
--
-- Runs the monadic sub-layer @f@ on the input @x@, adds the sub-layer's
-- output back onto @x@ with 'add', and passes the sum to the monadic
-- continuation @g@.
--
-- The dtype and shape handed to @g@ are whatever 'add' produces for the
-- input and the sub-layer output, i.e. the promoted dtype and the broadcast
-- shape of the two operands.
residual ::
  forall m device dtype dtype' dtype'' shape shape' shape'' b.
  ( Monad m,
    dtype'' ~ DTypePromotion dtype dtype',
    shape'' ~ Broadcast shape shape',
    BasicArithmeticDTypeIsValid device dtype,
    BasicArithmeticDTypeIsValid device dtype',
    BasicArithmeticDTypeIsValid device dtype''
  ) =>
  -- | sub-layer applied to the input
  (Tensor device dtype shape -> m (Tensor device dtype' shape')) ->
  -- | continuation applied to @input + sub-layer output@
  (Tensor device dtype'' shape'' -> m b) ->
  -- | input
  Tensor device dtype shape ->
  m b
residual f g x = do
  x' <- f x
  g (x `add` x')

--------------------------------------------------------------------------------
-- Relation-Aware Multi-Headed Attention Layer
--------------------------------------------------------------------------------

-- | Specification of a 'MultiheadAttention' layer.
--
-- The type-level parameters (embedding dimensions, number of heads, dtype
-- and device) appear only in the result type; the sole run-time payload is
-- the 'DropoutSpec'.
data
  MultiheadAttentionSpec
    (embedDim :: Nat)
    (kEmbedDim :: Nat)
    (vEmbedDim :: Nat)
    (numHeads :: Nat)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  where
  MultiheadAttentionSpec ::
    -- | spec for dropout
    DropoutSpec ->
    MultiheadAttentionSpec embedDim kEmbedDim vEmbedDim numHeads dtype device
  deriving (Show, Eq)

-- | Multi-head attention layer.
--
-- Holds three input projections (query, key, value) that all map into
-- @embedDim@, one output projection from @embedDim@ to @embedDim@, and a
-- dropout layer. @numHeads@ does not occur in any field; it is carried at
-- the type level only.
data
  MultiheadAttention
    (embedDim :: Nat)
    (kEmbedDim :: Nat)
    (vEmbedDim :: Nat)
    (numHeads :: Nat)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  where
  MultiheadAttention ::
    { -- | in-projection for query
      mhaQInProj :: Linear embedDim embedDim dtype device,
      -- | in-projection for key
      mhaKInProj :: Linear kEmbedDim embedDim dtype device,
      -- | in-projection for value
      mhaVInProj :: Linear vEmbedDim embedDim dtype device,
      -- | out-projection
      mhaOutProj :: Linear embedDim embedDim dtype device,
      -- | dropout
      mhaDropout :: Dropout
    } ->
    MultiheadAttention embedDim kEmbedDim vEmbedDim numHeads dtype device
  deriving (Show, Generic, Parameterized)

multiheadAttention ::
  forall embedDim kEmbedDim vEmbedDim numHeads seqLen seqLen' batchSize headDim dtype device.
  ( 1 <= numHeads,
    embedDim ~ (headDim * numHeads),
    All KnownNat '[embedDim, kEmbedDim, vEmbedDim, numHeads, seqLen, seqLen', batchSize, headDim],
    KnownDType dtype,
    StandardFloatingPointDTypeValidation device dtype,
    MatMulDTypeIsValid device dtype,
    BasicArithmeticDTypeIsValid device dtype,
    dtype ~ SumDType dtype,
    SumDTypeIsValid device dtype,
    KnownDevice device
  ) =>
  -- | multi-head attention model ADT
  MultiheadAttention embedDim kEmbedDim vEmbedDim numHeads dtype device ->
  -- | switch between training mode and evaluation mode (turns random dropout on and off)
  Bool ->
  -- | optional attention mask
  Maybe (Tensor device dtype '[batchSize, seqLen', seqLen]) ->
  -- | optional key padding mask
  Maybe (Tensor device 'D.Bool '[batchSize, seqLen]) ->
  -- | optional key relations
  Maybe (Tensor device dtype '[batchSize, seqLen', seqLen, headDim]) ->
  -- | optional value relations
  Maybe (Tensor device dtype '[batchSize, seqLen', seqLen, headDim]) ->
  -- | query representation
  Tensor device dtype '[batchSize, seqLen', embedDim] ->
  -- | key representation
  Tensor device dtype '[batchSize, seqLen, kEmbedDim] ->
  -- | value representation
  Tensor device dtype '[batchSize, seqLen, vEmbedDim] ->
  -- | attention and attention averaged over heads
  IO
    ( Tensor device dtype '[batchSize, seqLen', embedDim],
      Tensor device dtype '[batchSize, seqLen', seqLen]
    )
multiheadAttention :: forall (embedDim :: Nat) (kEmbedDim :: Nat) (vEmbedDim :: Nat)
       (numHeads :: Nat) (seqLen :: Nat) (seqLen' :: Nat)
       (batchSize :: Nat) (headDim :: Nat) (dtype :: DType)
       (device :: (DeviceType, Nat)).
(1 <= numHeads, embedDim ~ (headDim * numHeads),
 All
   KnownNat
   '[embedDim, kEmbedDim, vEmbedDim, numHeads, seqLen, seqLen',
     batchSize, headDim],
 KnownDType dtype,
 StandardFloatingPointDTypeValidation device dtype,
 MatMulDTypeIsValid device dtype,
 BasicArithmeticDTypeIsValid device dtype, dtype ~ SumDType dtype,
 SumDTypeIsValid device dtype, KnownDevice device) =>
MultiheadAttention
  embedDim kEmbedDim vEmbedDim numHeads dtype device
-> Bool
-> Maybe (Tensor device dtype '[batchSize, seqLen', seqLen])
-> Maybe (Tensor device 'Bool '[batchSize, seqLen])
-> Maybe
     (Tensor device dtype '[batchSize, seqLen', seqLen, headDim])
-> Maybe
     (Tensor device dtype '[batchSize, seqLen', seqLen, headDim])
-> Tensor device dtype '[batchSize, seqLen', embedDim]
-> Tensor device dtype '[batchSize, seqLen, kEmbedDim]
-> Tensor device dtype '[batchSize, seqLen, vEmbedDim]
-> IO
     (Tensor device dtype '[batchSize, seqLen', embedDim],
      Tensor device dtype '[batchSize, seqLen', seqLen])
multiheadAttention MultiheadAttention {Linear embedDim embedDim dtype device
Linear kEmbedDim embedDim dtype device
Linear vEmbedDim embedDim dtype device
Dropout
mhaDropout :: Dropout
mhaOutProj :: Linear embedDim embedDim dtype device
mhaVInProj :: Linear vEmbedDim embedDim dtype device
mhaKInProj :: Linear kEmbedDim embedDim dtype device
mhaQInProj :: Linear embedDim embedDim dtype device
mhaDropout :: forall (embedDim :: Nat) (dtype :: DType)
       (device :: (DeviceType, Nat)) (kEmbedDim :: Nat) (vEmbedDim :: Nat)
       (numHeads :: Nat).
MultiheadAttention
  embedDim kEmbedDim vEmbedDim numHeads dtype device
-> Dropout
mhaOutProj :: forall (embedDim :: Nat) (dtype :: DType)
       (device :: (DeviceType, Nat)) (kEmbedDim :: Nat) (vEmbedDim :: Nat)
       (numHeads :: Nat).
MultiheadAttention
  embedDim kEmbedDim vEmbedDim numHeads dtype device
-> Linear embedDim embedDim dtype device
mhaVInProj :: forall (embedDim :: Nat) (dtype :: DType)
       (device :: (DeviceType, Nat)) (kEmbedDim :: Nat) (vEmbedDim :: Nat)
       (numHeads :: Nat).
MultiheadAttention
  embedDim kEmbedDim vEmbedDim numHeads dtype device
-> Linear vEmbedDim embedDim dtype device
mhaKInProj :: forall (embedDim :: Nat) (dtype :: DType)
       (device :: (DeviceType, Nat)) (kEmbedDim :: Nat) (vEmbedDim :: Nat)
       (numHeads :: Nat).
MultiheadAttention
  embedDim kEmbedDim vEmbedDim numHeads dtype device
-> Linear kEmbedDim embedDim dtype device
mhaQInProj :: forall (embedDim :: Nat) (dtype :: DType)
       (device :: (DeviceType, Nat)) (kEmbedDim :: Nat) (vEmbedDim :: Nat)
       (numHeads :: Nat).
MultiheadAttention
  embedDim kEmbedDim vEmbedDim numHeads dtype device
-> Linear embedDim embedDim dtype device
..} Bool
train Maybe (Tensor device dtype '[batchSize, seqLen', seqLen])
attentionMask Maybe (Tensor device 'Bool '[batchSize, seqLen])
keyPaddingMask Maybe (Tensor device dtype '[batchSize, seqLen', seqLen, headDim])
keyRelations Maybe (Tensor device dtype '[batchSize, seqLen', seqLen, headDim])
valueRelations Tensor device dtype '[batchSize, seqLen', embedDim]
query Tensor device dtype '[batchSize, seqLen, kEmbedDim]
key Tensor device dtype '[batchSize, seqLen, vEmbedDim]
value = do
  Tensor
  device
  dtype
  (CheckMatMul
     '[batchSize, numHeads, seqLen', headDim]
     '[batchSize, numHeads, headDim, seqLen]
     (ComputeMatMul
        '[headDim, seqLen', numHeads, batchSize]
        (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
weights <-
    forall (shape :: [Nat]) (dtype :: DType)
       (device :: (DeviceType, Nat)).
Dropout
-> Bool
-> Tensor device dtype shape
-> IO (Tensor device dtype shape)
dropoutForward Dropout
mhaDropout Bool
train
      forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (dim :: Nat) (shape :: [Nat]) (dtype :: DType)
       (device :: (DeviceType, Nat)).
(KnownNat dim, DimOutOfBoundCheck shape dim, KnownDType dtype,
 StandardFloatingPointDTypeValidation device dtype) =>
Tensor device dtype shape -> Tensor device dtype shape
softmax @3
      forall b c a. (b -> c) -> (a -> b) -> a -> c
. Tensor
  device
  dtype
  (CheckMatMul
     '[batchSize, numHeads, seqLen', headDim]
     '[batchSize, numHeads, headDim, seqLen]
     (ComputeMatMul
        '[headDim, seqLen', numHeads, batchSize]
        (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
-> Tensor
     device
     dtype
     (CheckMatMul
        '[batchSize, numHeads, seqLen', headDim]
        '[batchSize, numHeads, headDim, seqLen]
        (ComputeMatMul
           '[headDim, seqLen', numHeads, batchSize]
           (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
_maskKeyPaddings
      forall b c a. (b -> c) -> (a -> b) -> a -> c
. Tensor
  device
  dtype
  (CheckMatMul
     '[batchSize, numHeads, seqLen', headDim]
     '[batchSize, numHeads, headDim, seqLen]
     (ComputeMatMul
        '[headDim, seqLen', numHeads, batchSize]
        (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
-> Tensor
     device
     dtype
     (CheckMatMul
        '[batchSize, numHeads, seqLen', headDim]
        '[batchSize, numHeads, headDim, seqLen]
        (ComputeMatMul
           '[headDim, seqLen', numHeads, batchSize]
           (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
_maskAttention
      forall a b. (a -> b) -> a -> b
$ Tensor
  device
  dtype
  (CheckMatMul
     '[batchSize, numHeads, seqLen', headDim]
     '[batchSize, numHeads, headDim, seqLen]
     (ComputeMatMul
        '[headDim, seqLen', numHeads, batchSize]
        (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
_attentionWeights
  forall (f :: Type -> Type) a. Applicative f => a -> f a
pure (Tensor
  device
  dtype
  (CheckMatMul
     '[batchSize, numHeads, seqLen', headDim]
     '[batchSize, numHeads, headDim, seqLen]
     (ComputeMatMul
        '[headDim, seqLen', numHeads, batchSize]
        (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
-> Tensor device dtype '[batchSize, seqLen', embedDim]
_attention Tensor
  device
  dtype
  (CheckMatMul
     '[batchSize, numHeads, seqLen', headDim]
     '[batchSize, numHeads, headDim, seqLen]
     (ComputeMatMul
        '[headDim, seqLen', numHeads, batchSize]
        (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
weights, Tensor
  device
  dtype
  (CheckMatMul
     '[batchSize, numHeads, seqLen', headDim]
     '[batchSize, numHeads, headDim, seqLen]
     (ComputeMatMul
        '[headDim, seqLen', numHeads, batchSize]
        (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
-> Tensor device dtype '[batchSize, seqLen', seqLen]
averageOverHeads Tensor
  device
  dtype
  (CheckMatMul
     '[batchSize, numHeads, seqLen', headDim]
     '[batchSize, numHeads, headDim, seqLen]
     (ComputeMatMul
        '[headDim, seqLen', numHeads, batchSize]
        (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
weights)
  where
    _attentionWeights :: Tensor
  device
  dtype
  (CheckMatMul
     '[batchSize, numHeads, seqLen', headDim]
     '[batchSize, numHeads, headDim, seqLen]
     (ComputeMatMul
        '[headDim, seqLen', numHeads, batchSize]
        (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
_attentionWeights =
      let scaling :: Double
scaling = forall a. Floating a => a -> a
Prelude.sqrt forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall a b. (Integral a, Num b) => a -> b
fromIntegral forall a b. (a -> b) -> a -> b
$ forall (n :: Nat). KnownNat n => Int
natValI @headDim :: Double
          q :: Tensor device dtype '[batchSize, numHeads, seqLen', headDim]
q = forall (seqLen'' :: Nat).
KnownNat seqLen'' =>
Tensor device dtype '[batchSize, seqLen'', embedDim]
-> Tensor device dtype '[batchSize, numHeads, seqLen'', headDim]
reshape' forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall a (shape :: [Nat]) (dtype :: DType)
       (device :: (DeviceType, Nat)).
Scalar a =>
a -> Tensor device dtype shape -> Tensor device dtype shape
divScalar Double
scaling forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall f a b. HasForward f a b => f -> a -> b
forward Linear embedDim embedDim dtype device
mhaQInProj forall a b. (a -> b) -> a -> b
$ Tensor device dtype '[batchSize, seqLen', embedDim]
query
          k :: Tensor device dtype '[batchSize, numHeads, seqLen, headDim]
k = forall (seqLen'' :: Nat).
KnownNat seqLen'' =>
Tensor device dtype '[batchSize, seqLen'', embedDim]
-> Tensor device dtype '[batchSize, numHeads, seqLen'', headDim]
reshape' forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall f a b. HasForward f a b => f -> a -> b
forward Linear kEmbedDim embedDim dtype device
mhaKInProj forall a b. (a -> b) -> a -> b
$ Tensor device dtype '[batchSize, seqLen, kEmbedDim]
key
          weights :: Tensor
  device
  dtype
  (CheckMatMul
     '[batchSize, numHeads, seqLen', headDim]
     '[batchSize, numHeads, headDim, seqLen]
     (ComputeMatMul
        '[headDim, seqLen', numHeads, batchSize]
        (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
weights = forall (shape'' :: [Nat]) (shape :: [Nat]) (shape' :: [Nat])
       (dtype :: DType) (device :: (DeviceType, Nat)).
(shape'' ~ MatMul shape shape', MatMulDTypeIsValid device dtype) =>
Tensor device dtype shape
-> Tensor device dtype shape' -> Tensor device dtype shape''
matmul Tensor device dtype '[batchSize, numHeads, seqLen', headDim]
q (forall (n :: Nat) (m :: Nat) (shape :: [Nat]) (shape' :: [Nat])
       (dtype :: DType) (device :: (DeviceType, Nat)).
(KnownNat n, KnownNat m, shape' ~ Transpose shape n m) =>
Tensor device dtype shape -> Tensor device dtype shape'
transpose @2 @3 Tensor device dtype '[batchSize, numHeads, seqLen, headDim]
k)
          weights' :: Tensor
  device
  dtype
  (CheckMatMul
     '[batchSize, numHeads, seqLen', headDim]
     '[batchSize, numHeads, headDim, seqLen]
     (ComputeMatMul
        '[headDim, seqLen', numHeads, batchSize]
        (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
weights' = case Maybe (Tensor device dtype '[batchSize, seqLen', seqLen, headDim])
keyRelations of
            Maybe (Tensor device dtype '[batchSize, seqLen', seqLen, headDim])
Nothing -> Tensor
  device
  dtype
  (CheckMatMul
     '[batchSize, numHeads, seqLen', headDim]
     '[batchSize, numHeads, headDim, seqLen]
     (ComputeMatMul
        '[headDim, seqLen', numHeads, batchSize]
        (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
weights
            Just Tensor device dtype '[batchSize, seqLen', seqLen, headDim]
kr -> Tensor
  device
  dtype
  (CheckMatMul
     '[batchSize, numHeads, seqLen', headDim]
     '[batchSize, numHeads, headDim, seqLen]
     (ComputeMatMul
        '[headDim, seqLen', numHeads, batchSize]
        (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
weights forall (shape'' :: [Nat]) (shape :: [Nat]) (shape' :: [Nat])
       (dtype :: DType) (dtype' :: DType) (dtype'' :: DType)
       (device :: (DeviceType, Nat)).
(dtype'' ~ DTypePromotion dtype dtype',
 shape'' ~ Broadcast shape shape',
 BasicArithmeticDTypeIsValid device dtype,
 BasicArithmeticDTypeIsValid device dtype',
 BasicArithmeticDTypeIsValid device dtype'') =>
Tensor device dtype shape
-> Tensor device dtype' shape' -> Tensor device dtype'' shape''
`add` forall (n :: Nat) (m :: Nat) (shape :: [Nat]) (shape' :: [Nat])
       (dtype :: DType) (device :: (DeviceType, Nat)).
(KnownNat n, KnownNat m, shape' ~ Transpose shape n m) =>
Tensor device dtype shape -> Tensor device dtype shape'
transpose @1 @2 ((forall (n :: Nat) (m :: Nat) (shape :: [Nat]) (shape' :: [Nat])
       (dtype :: DType) (device :: (DeviceType, Nat)).
(KnownNat n, KnownNat m, shape' ~ Transpose shape n m) =>
Tensor device dtype shape -> Tensor device dtype shape'
transpose @1 @2 Tensor device dtype '[batchSize, numHeads, seqLen', headDim]
q) forall (shape'' :: [Nat]) (shape :: [Nat]) (shape' :: [Nat])
       (dtype :: DType) (device :: (DeviceType, Nat)).
(shape'' ~ MatMul shape shape', MatMulDTypeIsValid device dtype) =>
Tensor device dtype shape
-> Tensor device dtype shape' -> Tensor device dtype shape''
`matmul` (forall (n :: Nat) (m :: Nat) (shape :: [Nat]) (shape' :: [Nat])
       (dtype :: DType) (device :: (DeviceType, Nat)).
(KnownNat n, KnownNat m, shape' ~ Transpose shape n m) =>
Tensor device dtype shape -> Tensor device dtype shape'
transpose @2 @3 Tensor device dtype '[batchSize, seqLen', seqLen, headDim]
kr))
       in Tensor
  device
  dtype
  (CheckMatMul
     '[batchSize, numHeads, seqLen', headDim]
     '[batchSize, numHeads, headDim, seqLen]
     (ComputeMatMul
        '[headDim, seqLen', numHeads, batchSize]
        (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
weights'
    _maskAttention :: Tensor
  device
  dtype
  (CheckMatMul
     '[batchSize, numHeads, seqLen', headDim]
     '[batchSize, numHeads, headDim, seqLen]
     (ComputeMatMul
        '[headDim, seqLen', numHeads, batchSize]
        (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
-> Tensor
     device
     dtype
     (CheckMatMul
        '[batchSize, numHeads, seqLen', headDim]
        '[batchSize, numHeads, headDim, seqLen]
        (ComputeMatMul
           '[headDim, seqLen', numHeads, batchSize]
           (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
_maskAttention Tensor
  device
  dtype
  (CheckMatMul
     '[batchSize, numHeads, seqLen', headDim]
     '[batchSize, numHeads, headDim, seqLen]
     (ComputeMatMul
        '[headDim, seqLen', numHeads, batchSize]
        (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
attentionWeights =
      case Maybe (Tensor device dtype '[batchSize, seqLen', seqLen])
attentionMask of
        Maybe (Tensor device dtype '[batchSize, seqLen', seqLen])
Nothing -> Tensor
  device
  dtype
  (CheckMatMul
     '[batchSize, numHeads, seqLen', headDim]
     '[batchSize, numHeads, headDim, seqLen]
     (ComputeMatMul
        '[headDim, seqLen', numHeads, batchSize]
        (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
attentionWeights
        Just Tensor device dtype '[batchSize, seqLen', seqLen]
am -> Tensor
  device
  dtype
  (CheckMatMul
     '[batchSize, numHeads, seqLen', headDim]
     '[batchSize, numHeads, headDim, seqLen]
     (ComputeMatMul
        '[headDim, seqLen', numHeads, batchSize]
        (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
attentionWeights forall (shape'' :: [Nat]) (shape :: [Nat]) (shape' :: [Nat])
       (dtype :: DType) (dtype' :: DType) (dtype'' :: DType)
       (device :: (DeviceType, Nat)).
(dtype'' ~ DTypePromotion dtype dtype',
 shape'' ~ Broadcast shape shape',
 BasicArithmeticDTypeIsValid device dtype,
 BasicArithmeticDTypeIsValid device dtype',
 BasicArithmeticDTypeIsValid device dtype'') =>
Tensor device dtype shape
-> Tensor device dtype' shape' -> Tensor device dtype'' shape''
`add` forall (dim :: Nat) (shape :: [Nat]) (shape' :: [Nat])
       (dtype :: DType) (device :: (DeviceType, Nat)).
(KnownNat dim, shape' ~ Unsqueeze shape dim) =>
Tensor device dtype shape -> Tensor device dtype shape'
unsqueeze @1 Tensor device dtype '[batchSize, seqLen', seqLen]
am
    _maskKeyPaddings :: Tensor
  device
  dtype
  (CheckMatMul
     '[batchSize, numHeads, seqLen', headDim]
     '[batchSize, numHeads, headDim, seqLen]
     (ComputeMatMul
        '[headDim, seqLen', numHeads, batchSize]
        (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
-> Tensor
     device
     dtype
     (CheckMatMul
        '[batchSize, numHeads, seqLen', headDim]
        '[batchSize, numHeads, headDim, seqLen]
        (ComputeMatMul
           '[headDim, seqLen', numHeads, batchSize]
           (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
_maskKeyPaddings Tensor
  device
  dtype
  (CheckMatMul
     '[batchSize, numHeads, seqLen', headDim]
     '[batchSize, numHeads, headDim, seqLen]
     (ComputeMatMul
        '[headDim, seqLen', numHeads, batchSize]
        (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
attentionWeights =
      case Maybe (Tensor device 'Bool '[batchSize, seqLen])
keyPaddingMask of
        Maybe (Tensor device 'Bool '[batchSize, seqLen])
Nothing -> Tensor
  device
  dtype
  (CheckMatMul
     '[batchSize, numHeads, seqLen', headDim]
     '[batchSize, numHeads, headDim, seqLen]
     (ComputeMatMul
        '[headDim, seqLen', numHeads, batchSize]
        (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
attentionWeights
        Just Tensor device 'Bool '[batchSize, seqLen]
kpm ->
          let keyPaddingMask' :: Tensor
  device
  'Bool
  (UnsqueezeCheck
     '[batchSize, 1, seqLen]
     2
     (UnsqueezeImpl '[batchSize, 1, seqLen] 2))
keyPaddingMask' = forall (dim :: Nat) (shape :: [Nat]) (shape' :: [Nat])
       (dtype :: DType) (device :: (DeviceType, Nat)).
(KnownNat dim, shape' ~ Unsqueeze shape dim) =>
Tensor device dtype shape -> Tensor device dtype shape'
unsqueeze @2 forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (dim :: Nat) (shape :: [Nat]) (shape' :: [Nat])
       (dtype :: DType) (device :: (DeviceType, Nat)).
(KnownNat dim, shape' ~ Unsqueeze shape dim) =>
Tensor device dtype shape -> Tensor device dtype shape'
unsqueeze @1 forall a b. (a -> b) -> a -> b
$ Tensor device 'Bool '[batchSize, seqLen]
kpm
           in forall a (shape :: [Nat]) (shape' :: [Nat]) (shape'' :: [Nat])
       (dtype :: DType) (device :: (DeviceType, Nat)).
(Scalar a, shape'' ~ Broadcast shape shape') =>
Tensor device 'Bool shape'
-> a -> Tensor device dtype shape -> Tensor device dtype shape''
maskedFill Tensor
  device
  'Bool
  (UnsqueezeCheck
     '[batchSize, 1, seqLen]
     2
     (UnsqueezeImpl '[batchSize, 1, seqLen] 2))
keyPaddingMask' (-Double
1 forall a. Fractional a => a -> a -> a
/ Double
0 :: Double) Tensor
  device
  dtype
  (CheckMatMul
     '[batchSize, numHeads, seqLen', headDim]
     '[batchSize, numHeads, headDim, seqLen]
     (ComputeMatMul
        '[headDim, seqLen', numHeads, batchSize]
        (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
attentionWeights
    _attention :: Tensor
  device
  dtype
  (CheckMatMul
     '[batchSize, numHeads, seqLen', headDim]
     '[batchSize, numHeads, headDim, seqLen]
     (ComputeMatMul
        '[headDim, seqLen', numHeads, batchSize]
        (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
-> Tensor device dtype '[batchSize, seqLen', embedDim]
_attention Tensor
  device
  dtype
  (CheckMatMul
     '[batchSize, numHeads, seqLen', headDim]
     '[batchSize, numHeads, headDim, seqLen]
     (ComputeMatMul
        '[headDim, seqLen', numHeads, batchSize]
        (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
attentionWeights =
      let v :: Tensor device dtype '[batchSize, numHeads, seqLen, headDim]
v = forall (seqLen'' :: Nat).
KnownNat seqLen'' =>
Tensor device dtype '[batchSize, seqLen'', embedDim]
-> Tensor device dtype '[batchSize, numHeads, seqLen'', headDim]
reshape' forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall f a b. HasForward f a b => f -> a -> b
forward Linear vEmbedDim embedDim dtype device
mhaVInProj forall a b. (a -> b) -> a -> b
$ Tensor device dtype '[batchSize, seqLen, vEmbedDim]
value
          attention :: Tensor
  device
  dtype
  (SetValue
     (SetValue
        (CheckMatMul
           (CheckMatMul
              '[batchSize, numHeads, seqLen', headDim]
              '[batchSize, numHeads, headDim, seqLen]
              (ComputeMatMul
                 '[headDim, seqLen', numHeads, batchSize]
                 (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
           '[batchSize, numHeads, seqLen, headDim]
           (ComputeMatMul
              (ReverseImpl
                 (CheckMatMul
                    '[batchSize, numHeads, seqLen', headDim]
                    '[batchSize, numHeads, headDim, seqLen]
                    (ComputeMatMul
                       '[headDim, seqLen', numHeads, batchSize]
                       (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
                 '[])
              '[headDim, seqLen, numHeads, batchSize]))
        1
        (GetValue
           (CheckMatMul
              (CheckMatMul
                 '[batchSize, numHeads, seqLen', headDim]
                 '[batchSize, numHeads, headDim, seqLen]
                 (ComputeMatMul
                    '[headDim, seqLen', numHeads, batchSize]
                    (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
              '[batchSize, numHeads, seqLen, headDim]
              (ComputeMatMul
                 (ReverseImpl
                    (CheckMatMul
                       '[batchSize, numHeads, seqLen', headDim]
                       '[batchSize, numHeads, headDim, seqLen]
                       (ComputeMatMul
                          '[headDim, seqLen', numHeads, batchSize]
                          (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
                    '[])
                 '[headDim, seqLen, numHeads, batchSize]))
           2))
     2
     (GetValue
        (CheckMatMul
           (CheckMatMul
              '[batchSize, numHeads, seqLen', headDim]
              '[batchSize, numHeads, headDim, seqLen]
              (ComputeMatMul
                 '[headDim, seqLen', numHeads, batchSize]
                 (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
           '[batchSize, numHeads, seqLen, headDim]
           (ComputeMatMul
              (ReverseImpl
                 (CheckMatMul
                    '[batchSize, numHeads, seqLen', headDim]
                    '[batchSize, numHeads, headDim, seqLen]
                    (ComputeMatMul
                       '[headDim, seqLen', numHeads, batchSize]
                       (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
                 '[])
              '[headDim, seqLen, numHeads, batchSize]))
        1))
attention = forall (n :: Nat) (m :: Nat) (shape :: [Nat]) (shape' :: [Nat])
       (dtype :: DType) (device :: (DeviceType, Nat)).
(KnownNat n, KnownNat m, shape' ~ Transpose shape n m) =>
Tensor device dtype shape -> Tensor device dtype shape'
transpose @1 @2 forall a b. (a -> b) -> a -> b
$ forall (shape'' :: [Nat]) (shape :: [Nat]) (shape' :: [Nat])
       (dtype :: DType) (device :: (DeviceType, Nat)).
(shape'' ~ MatMul shape shape', MatMulDTypeIsValid device dtype) =>
Tensor device dtype shape
-> Tensor device dtype shape' -> Tensor device dtype shape''
matmul Tensor
  device
  dtype
  (CheckMatMul
     '[batchSize, numHeads, seqLen', headDim]
     '[batchSize, numHeads, headDim, seqLen]
     (ComputeMatMul
        '[headDim, seqLen', numHeads, batchSize]
        (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
attentionWeights Tensor device dtype '[batchSize, numHeads, seqLen, headDim]
v
          attention' :: Tensor
  device
  dtype
  (SetValue
     (SetValue
        (CheckMatMul
           (CheckMatMul
              '[batchSize, numHeads, seqLen', headDim]
              '[batchSize, numHeads, headDim, seqLen]
              (ComputeMatMul
                 '[headDim, seqLen', numHeads, batchSize]
                 (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
           '[batchSize, numHeads, seqLen, headDim]
           (ComputeMatMul
              (ReverseImpl
                 (CheckMatMul
                    '[batchSize, numHeads, seqLen', headDim]
                    '[batchSize, numHeads, headDim, seqLen]
                    (ComputeMatMul
                       '[headDim, seqLen', numHeads, batchSize]
                       (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
                 '[])
              '[headDim, seqLen, numHeads, batchSize]))
        1
        (GetValue
           (CheckMatMul
              (CheckMatMul
                 '[batchSize, numHeads, seqLen', headDim]
                 '[batchSize, numHeads, headDim, seqLen]
                 (ComputeMatMul
                    '[headDim, seqLen', numHeads, batchSize]
                    (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
              '[batchSize, numHeads, seqLen, headDim]
              (ComputeMatMul
                 (ReverseImpl
                    (CheckMatMul
                       '[batchSize, numHeads, seqLen', headDim]
                       '[batchSize, numHeads, headDim, seqLen]
                       (ComputeMatMul
                          '[headDim, seqLen', numHeads, batchSize]
                          (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
                    '[])
                 '[headDim, seqLen, numHeads, batchSize]))
           2))
     2
     (GetValue
        (CheckMatMul
           (CheckMatMul
              '[batchSize, numHeads, seqLen', headDim]
              '[batchSize, numHeads, headDim, seqLen]
              (ComputeMatMul
                 '[headDim, seqLen', numHeads, batchSize]
                 (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
           '[batchSize, numHeads, seqLen, headDim]
           (ComputeMatMul
              (ReverseImpl
                 (CheckMatMul
                    '[batchSize, numHeads, seqLen', headDim]
                    '[batchSize, numHeads, headDim, seqLen]
                    (ComputeMatMul
                       '[headDim, seqLen', numHeads, batchSize]
                       (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
                 '[])
              '[headDim, seqLen, numHeads, batchSize]))
        1))
attention' = case Maybe (Tensor device dtype '[batchSize, seqLen', seqLen, headDim])
valueRelations of
            Maybe (Tensor device dtype '[batchSize, seqLen', seqLen, headDim])
Nothing -> Tensor
  device
  dtype
  (SetValue
     (SetValue
        (CheckMatMul
           (CheckMatMul
              '[batchSize, numHeads, seqLen', headDim]
              '[batchSize, numHeads, headDim, seqLen]
              (ComputeMatMul
                 '[headDim, seqLen', numHeads, batchSize]
                 (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
           '[batchSize, numHeads, seqLen, headDim]
           (ComputeMatMul
              (ReverseImpl
                 (CheckMatMul
                    '[batchSize, numHeads, seqLen', headDim]
                    '[batchSize, numHeads, headDim, seqLen]
                    (ComputeMatMul
                       '[headDim, seqLen', numHeads, batchSize]
                       (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
                 '[])
              '[headDim, seqLen, numHeads, batchSize]))
        1
        (GetValue
           (CheckMatMul
              (CheckMatMul
                 '[batchSize, numHeads, seqLen', headDim]
                 '[batchSize, numHeads, headDim, seqLen]
                 (ComputeMatMul
                    '[headDim, seqLen', numHeads, batchSize]
                    (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
              '[batchSize, numHeads, seqLen, headDim]
              (ComputeMatMul
                 (ReverseImpl
                    (CheckMatMul
                       '[batchSize, numHeads, seqLen', headDim]
                       '[batchSize, numHeads, headDim, seqLen]
                       (ComputeMatMul
                          '[headDim, seqLen', numHeads, batchSize]
                          (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
                    '[])
                 '[headDim, seqLen, numHeads, batchSize]))
           2))
     2
     (GetValue
        (CheckMatMul
           (CheckMatMul
              '[batchSize, numHeads, seqLen', headDim]
              '[batchSize, numHeads, headDim, seqLen]
              (ComputeMatMul
                 '[headDim, seqLen', numHeads, batchSize]
                 (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
           '[batchSize, numHeads, seqLen, headDim]
           (ComputeMatMul
              (ReverseImpl
                 (CheckMatMul
                    '[batchSize, numHeads, seqLen', headDim]
                    '[batchSize, numHeads, headDim, seqLen]
                    (ComputeMatMul
                       '[headDim, seqLen', numHeads, batchSize]
                       (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
                 '[])
              '[headDim, seqLen, numHeads, batchSize]))
        1))
attention
            Just Tensor device dtype '[batchSize, seqLen', seqLen, headDim]
vr -> Tensor
  device
  dtype
  (SetValue
     (SetValue
        (CheckMatMul
           (CheckMatMul
              '[batchSize, numHeads, seqLen', headDim]
              '[batchSize, numHeads, headDim, seqLen]
              (ComputeMatMul
                 '[headDim, seqLen', numHeads, batchSize]
                 (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
           '[batchSize, numHeads, seqLen, headDim]
           (ComputeMatMul
              (ReverseImpl
                 (CheckMatMul
                    '[batchSize, numHeads, seqLen', headDim]
                    '[batchSize, numHeads, headDim, seqLen]
                    (ComputeMatMul
                       '[headDim, seqLen', numHeads, batchSize]
                       (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
                 '[])
              '[headDim, seqLen, numHeads, batchSize]))
        1
        (GetValue
           (CheckMatMul
              (CheckMatMul
                 '[batchSize, numHeads, seqLen', headDim]
                 '[batchSize, numHeads, headDim, seqLen]
                 (ComputeMatMul
                    '[headDim, seqLen', numHeads, batchSize]
                    (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
              '[batchSize, numHeads, seqLen, headDim]
              (ComputeMatMul
                 (ReverseImpl
                    (CheckMatMul
                       '[batchSize, numHeads, seqLen', headDim]
                       '[batchSize, numHeads, headDim, seqLen]
                       (ComputeMatMul
                          '[headDim, seqLen', numHeads, batchSize]
                          (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
                    '[])
                 '[headDim, seqLen, numHeads, batchSize]))
           2))
     2
     (GetValue
        (CheckMatMul
           (CheckMatMul
              '[batchSize, numHeads, seqLen', headDim]
              '[batchSize, numHeads, headDim, seqLen]
              (ComputeMatMul
                 '[headDim, seqLen', numHeads, batchSize]
                 (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
           '[batchSize, numHeads, seqLen, headDim]
           (ComputeMatMul
              (ReverseImpl
                 (CheckMatMul
                    '[batchSize, numHeads, seqLen', headDim]
                    '[batchSize, numHeads, headDim, seqLen]
                    (ComputeMatMul
                       '[headDim, seqLen', numHeads, batchSize]
                       (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
                 '[])
              '[headDim, seqLen, numHeads, batchSize]))
        1))
attention forall (shape'' :: [Nat]) (shape :: [Nat]) (shape' :: [Nat])
       (dtype :: DType) (dtype' :: DType) (dtype'' :: DType)
       (device :: (DeviceType, Nat)).
(dtype'' ~ DTypePromotion dtype dtype',
 shape'' ~ Broadcast shape shape',
 BasicArithmeticDTypeIsValid device dtype,
 BasicArithmeticDTypeIsValid device dtype',
 BasicArithmeticDTypeIsValid device dtype'') =>
Tensor device dtype shape
-> Tensor device dtype' shape' -> Tensor device dtype'' shape''
`add` (forall (shape'' :: [Nat]) (shape :: [Nat]) (shape' :: [Nat])
       (dtype :: DType) (device :: (DeviceType, Nat)).
(shape'' ~ MatMul shape shape', MatMulDTypeIsValid device dtype) =>
Tensor device dtype shape
-> Tensor device dtype shape' -> Tensor device dtype shape''
matmul (forall (n :: Nat) (m :: Nat) (shape :: [Nat]) (shape' :: [Nat])
       (dtype :: DType) (device :: (DeviceType, Nat)).
(KnownNat n, KnownNat m, shape' ~ Transpose shape n m) =>
Tensor device dtype shape -> Tensor device dtype shape'
transpose @1 @2 Tensor
  device
  dtype
  (CheckMatMul
     '[batchSize, numHeads, seqLen', headDim]
     '[batchSize, numHeads, headDim, seqLen]
     (ComputeMatMul
        '[headDim, seqLen', numHeads, batchSize]
        (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
attentionWeights) Tensor device dtype '[batchSize, seqLen', seqLen, headDim]
vr)
       in forall f a b. HasForward f a b => f -> a -> b
forward Linear embedDim embedDim dtype device
mhaOutProj forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (shape' :: [Nat]) (shape :: [Nat]) (dtype :: DType)
       (device :: (DeviceType, Nat)).
(KnownShape shape', Numel shape ~ Numel shape') =>
Tensor device dtype shape -> Tensor device dtype shape'
reshape @'[batchSize, seqLen', embedDim] forall a b. (a -> b) -> a -> b
$ Tensor
  device
  dtype
  (SetValue
     (SetValue
        (CheckMatMul
           (CheckMatMul
              '[batchSize, numHeads, seqLen', headDim]
              '[batchSize, numHeads, headDim, seqLen]
              (ComputeMatMul
                 '[headDim, seqLen', numHeads, batchSize]
                 (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
           '[batchSize, numHeads, seqLen, headDim]
           (ComputeMatMul
              (ReverseImpl
                 (CheckMatMul
                    '[batchSize, numHeads, seqLen', headDim]
                    '[batchSize, numHeads, headDim, seqLen]
                    (ComputeMatMul
                       '[headDim, seqLen', numHeads, batchSize]
                       (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
                 '[])
              '[headDim, seqLen, numHeads, batchSize]))
        1
        (GetValue
           (CheckMatMul
              (CheckMatMul
                 '[batchSize, numHeads, seqLen', headDim]
                 '[batchSize, numHeads, headDim, seqLen]
                 (ComputeMatMul
                    '[headDim, seqLen', numHeads, batchSize]
                    (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
              '[batchSize, numHeads, seqLen, headDim]
              (ComputeMatMul
                 (ReverseImpl
                    (CheckMatMul
                       '[batchSize, numHeads, seqLen', headDim]
                       '[batchSize, numHeads, headDim, seqLen]
                       (ComputeMatMul
                          '[headDim, seqLen', numHeads, batchSize]
                          (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
                    '[])
                 '[headDim, seqLen, numHeads, batchSize]))
           2))
     2
     (GetValue
        (CheckMatMul
           (CheckMatMul
              '[batchSize, numHeads, seqLen', headDim]
              '[batchSize, numHeads, headDim, seqLen]
              (ComputeMatMul
                 '[headDim, seqLen', numHeads, batchSize]
                 (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
           '[batchSize, numHeads, seqLen, headDim]
           (ComputeMatMul
              (ReverseImpl
                 (CheckMatMul
                    '[batchSize, numHeads, seqLen', headDim]
                    '[batchSize, numHeads, headDim, seqLen]
                    (ComputeMatMul
                       '[headDim, seqLen', numHeads, batchSize]
                       (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
                 '[])
              '[headDim, seqLen, numHeads, batchSize]))
        1))
attention'
    averageOverHeads :: Tensor
  device
  dtype
  (CheckMatMul
     '[batchSize, numHeads, seqLen', headDim]
     '[batchSize, numHeads, headDim, seqLen]
     (ComputeMatMul
        '[headDim, seqLen', numHeads, batchSize]
        (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
-> Tensor device dtype '[batchSize, seqLen', seqLen]
averageOverHeads =
      let numHeads' :: Int
numHeads' = forall (n :: Nat). KnownNat n => Int
natValI @numHeads
       in forall a (shape :: [Nat]) (dtype :: DType)
       (device :: (DeviceType, Nat)).
Scalar a =>
a -> Tensor device dtype shape -> Tensor device dtype shape
divScalar Int
numHeads' forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (d :: Nat) (shape :: [Nat]) (shape' :: [Nat])
       (dtype :: DType) (dtype' :: DType) (device :: (DeviceType, Nat)).
(KnownNat d, shape' ~ DropValue shape d,
 SumDTypeIsValid device dtype, dtype' ~ SumDType dtype) =>
Tensor device dtype shape -> Tensor device dtype' shape'
sumDim @1
    reshape' ::
      forall seqLen''.
      KnownNat seqLen'' =>
      Tensor device dtype '[batchSize, seqLen'', embedDim] ->
      Tensor device dtype '[batchSize, numHeads, seqLen'', headDim]
    reshape' :: forall (seqLen'' :: Nat).
KnownNat seqLen'' =>
Tensor device dtype '[batchSize, seqLen'', embedDim]
-> Tensor device dtype '[batchSize, numHeads, seqLen'', headDim]
reshape' = forall (n :: Nat) (m :: Nat) (shape :: [Nat]) (shape' :: [Nat])
       (dtype :: DType) (device :: (DeviceType, Nat)).
(KnownNat n, KnownNat m, shape' ~ Transpose shape n m) =>
Tensor device dtype shape -> Tensor device dtype shape'
transpose @1 @2 forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (shape' :: [Nat]) (shape :: [Nat]) (dtype :: DType)
       (device :: (DeviceType, Nat)).
(KnownShape shape', Numel shape ~ Numel shape') =>
Tensor device dtype shape -> Tensor device dtype shape'
reshape @'[batchSize, seqLen'', numHeads, headDim]

-- | Random initialisation of a 'MultiheadAttention' layer from its spec.
--
-- The four linear projections are each sampled from a default 'LinearSpec';
-- their input/output feature sizes are fixed by the type-level dimensions of
-- the 'MultiheadAttention' constructor (in constructor order:
-- @embedDim -> embedDim@, @kEmbedDim -> embedDim@, @vEmbedDim -> embedDim@
-- and @embedDim -> embedDim@). The dropout layer is sampled from the
-- 'DropoutSpec' carried by the 'MultiheadAttentionSpec'.
instance
  ( All KnownNat '[embedDim, kEmbedDim, vEmbedDim, numHeads],
    KnownDType dtype,
    KnownDevice device,
    RandDTypeIsValid device dtype
  ) =>
  A.Randomizable
    (MultiheadAttentionSpec embedDim kEmbedDim vEmbedDim numHeads dtype device)
    (MultiheadAttention embedDim kEmbedDim vEmbedDim numHeads dtype device)
  where
  -- Fields are sampled left to right, matching the constructor's argument
  -- order; only the dropout takes an explicit spec from the caller.
  sample (MultiheadAttentionSpec mhaDropoutSpec) =
    MultiheadAttention
      <$> A.sample LinearSpec
      <*> A.sample LinearSpec
      <*> A.sample LinearSpec
      <*> A.sample LinearSpec
      <*> A.sample mhaDropoutSpec

--------------------------------------------------------------------------------
-- Transformer MLP Layer
--------------------------------------------------------------------------------

-- | Specification used to initialise a 'TransformerMLP'.
--
-- The type parameters @embedDim@, @ffnDim@, @dtype@ and @device@ do not
-- appear in any field; they only tag the spec at the type level so that it
-- selects the matching 'TransformerMLP' type.
data
  TransformerMLPSpec
    (embedDim :: Nat)
    (ffnDim :: Nat)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  where
  TransformerMLPSpec ::
    forall embedDim ffnDim dtype device.
    { -- | spec for relu dropout
      dropout0Spec :: DropoutSpec,
      -- | spec for other dropout
      dropout1Spec :: DropoutSpec,
      -- | epsilon for layer norm
      epsSpec :: Double
    } ->
    TransformerMLPSpec embedDim ffnDim dtype device
  deriving (Show, Eq)

-- | Feed-forward (MLP) block of a transformer layer.
--
-- Holds two linear layers mapping @embedDim -> ffnDim@ and
-- @ffnDim -> embedDim@, two dropout layers, and a layer norm over the
-- trailing @embedDim@ dimension. The 'Parameterized' instance is derived
-- via 'Generic'.
data
  TransformerMLP
    (embedDim :: Nat)
    (ffnDim :: Nat)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  where
  TransformerMLP ::
    forall embedDim ffnDim dtype device.
    { -- | first fully connected layer
      linear0 :: Linear embedDim ffnDim dtype device,
      -- | second fully connected layer
      linear1 :: Linear ffnDim embedDim dtype device,
      -- | relu dropout
      dropout0 :: Dropout,
      -- | other dropout
      dropout1 :: Dropout,
      -- | layer norm
      ln :: LayerNorm '[embedDim] dtype device
    } ->
    TransformerMLP embedDim ffnDim dtype device
  deriving (Show, Generic, Parameterized)

-- | Feed-forward (MLP) sub-layer of the transformer.
--
-- The input is sent through @linear0@, 'relu', @dropout0@, @linear1@ and
-- @dropout1@ (in that order). 'residual' then combines that result with the
-- original input and hands the combination to the layer norm @ln@.
--
-- NOTE(review): 'residual' is defined earlier in this module; its type
-- (dtype promotion plus shape broadcasting) suggests the combination is an
-- element-wise addition -- confirm against its definition.
--
-- The two leading dimensions are only carried through. They are named
-- @seqLen@ and @batchSize@ here, but 'transformerLayer' calls this function
-- on a @'[batchSize, seqLen', embedDim]@ tensor.
transformerMLP ::
  forall embedDim ffnDim seqLen batchSize dtype device.
  ( BasicArithmeticDTypeIsValid device dtype,
    StandardFloatingPointDTypeValidation device dtype,
    KnownNat embedDim,
    IsSuffixOf '[embedDim] '[seqLen, batchSize, embedDim]
  ) =>
  -- | MLP model ADT for transformer
  TransformerMLP embedDim ffnDim dtype device ->
  -- | switch between training mode and evaluation mode (turns random dropout on and off)
  Bool ->
  -- | input
  Tensor device dtype '[seqLen, batchSize, embedDim] ->
  -- | output
  IO (Tensor device dtype '[seqLen, batchSize, embedDim])
transformerMLP TransformerMLP {..} train input =
  residual feedForward (pure . forward ln) input
  where
    -- Two-layer feed-forward stack; dropout is the only effectful step,
    -- which is why the helper lives in IO.
    feedForward x = do
      hidden <- dropoutForward dropout0 train . relu . forward linear0 $ x
      dropoutForward dropout1 train . forward linear1 $ hidden

-- | Samples a 'TransformerMLP' from its specification.
--
-- Both linear layers are sampled from a bare 'LinearSpec' (their feature
-- sizes come from the type-level @embedDim@ / @ffnDim@), the two dropouts
-- from @dropout0Spec@ and @dropout1Spec@, and the layer norm from
-- @'LayerNormSpec' epsSpec@.
instance
  ( All KnownNat '[embedDim, ffnDim],
    KnownDType dtype,
    KnownDevice device,
    RandDTypeIsValid device dtype
  ) =>
  A.Randomizable
    (TransformerMLPSpec embedDim ffnDim dtype device)
    (TransformerMLP embedDim ffnDim dtype device)
  where
  -- Argument order follows the 'TransformerMLP' constructor:
  -- linear0, linear1, dropout0, dropout1, ln.
  sample TransformerMLPSpec {..} =
    TransformerMLP
      <$> A.sample LinearSpec
      <*> A.sample LinearSpec
      <*> A.sample dropout0Spec
      <*> A.sample dropout1Spec
      <*> A.sample (LayerNormSpec epsSpec)

--------------------------------------------------------------------------------
-- Relation-Aware Transformer Layer
--------------------------------------------------------------------------------

-- | Specification from which a 'TransformerLayer' is sampled.
--
-- The type parameters fix the embedding sizes of query, key and value, the
-- number of attention heads, the feed-forward width, and the dtype/device of
-- the parameters.
data
  TransformerLayerSpec
    (embedDim :: Nat)
    (kEmbedDim :: Nat)
    (vEmbedDim :: Nat)
    (numHeads :: Nat)
    (ffnDim :: Nat)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  where
  TransformerLayerSpec ::
    forall embedDim kEmbedDim vEmbedDim numHeads ffnDim dtype device.
    { -- | specification of the multi-head attention sub-layer
      mhaSpec :: MultiheadAttentionSpec embedDim kEmbedDim vEmbedDim numHeads dtype device,
      -- | dropout specification; presumably for the dropout applied to the
      -- attention output -- confirm against the 'A.Randomizable' instance
      attnDropoutSpec :: DropoutSpec,
      -- | presumably the epsilon handed to the layer norm specification --
      -- confirm against the 'A.Randomizable' instance
      epsSpec' :: Double,
      -- | specification of the feed-forward (MLP) sub-layer
      mlpSpec :: TransformerMLPSpec embedDim ffnDim dtype device
    } ->
    TransformerLayerSpec embedDim kEmbedDim vEmbedDim numHeads ffnDim dtype device
  deriving (Show, Eq)

-- | A single transformer layer: multi-head attention, a dropout and a layer
-- norm, followed by a feed-forward (MLP) block.
--
-- 'transformerLayer' runs the forward pass. The derived 'Parameterized'
-- instance provides @flattenParameters@ / @replaceParameters@ for the layer.
data
  TransformerLayer
    (embedDim :: Nat)
    (kEmbedDim :: Nat)
    (vEmbedDim :: Nat)
    (numHeads :: Nat)
    (ffnDim :: Nat)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  where
  TransformerLayer ::
    forall embedDim kEmbedDim vEmbedDim numHeads ffnDim dtype device.
    { -- | multi-head attention
      transformerLayer_mha :: MultiheadAttention embedDim kEmbedDim vEmbedDim numHeads dtype device,
      -- | dropout
      transformerLayer_attnDropout :: Dropout,
      -- | layer norm
      transformerLayer_ln :: LayerNorm '[embedDim] dtype device,
      -- | MLP
      transformerLayer_mlp :: TransformerMLP embedDim ffnDim dtype device
    } ->
    TransformerLayer embedDim kEmbedDim vEmbedDim numHeads ffnDim dtype device
  deriving (Show, Generic, Parameterized)

-- | Forward pass of one (relation-aware) transformer layer.
--
-- The query is run through 'multiheadAttention' together with the masks,
-- relations, key and value. Only the first component of the attention result
-- is kept; it is passed through the attention dropout. 'residual' combines
-- that with the original query and applies the layer norm, and the result is
-- finally fed to 'transformerMLP'.
--
-- NOTE(review): the discarded second component has shape
-- @'[batchSize, seqLen', seqLen]@ -- presumably the attention weights;
-- confirm against 'multiheadAttention'.
transformerLayer ::
  forall (numHeads :: Nat) (ffnDim :: Nat) (embedDim :: Nat) (kEmbedDim :: Nat) (vEmbedDim :: Nat) (headDim :: Nat) (seqLen :: Nat) (seqLen' :: Nat) (batchSize :: Nat) dtype device.
  ( 1 <= numHeads,
    embedDim ~ (headDim * numHeads),
    All KnownNat '[embedDim, kEmbedDim, vEmbedDim, numHeads, seqLen, seqLen', batchSize, headDim],
    IsSuffixOf '[embedDim] '[batchSize, seqLen', embedDim],
    KnownDType dtype,
    dtype ~ SumDType dtype,
    StandardFloatingPointDTypeValidation device dtype,
    MatMulDTypeIsValid device dtype,
    BasicArithmeticDTypeIsValid device dtype,
    SumDTypeIsValid device dtype,
    KnownDevice device
  ) =>
  -- | transformer layer model ADT
  TransformerLayer embedDim kEmbedDim vEmbedDim numHeads ffnDim dtype device ->
  -- | switch between training mode and evaluation mode (turns random dropout on and off)
  Bool ->
  -- | optional attention mask
  Maybe (Tensor device dtype '[batchSize, seqLen', seqLen]) ->
  -- | optional key padding mask
  Maybe (Tensor device 'D.Bool '[batchSize, seqLen]) ->
  -- | optional key relations
  Maybe (Tensor device dtype '[batchSize, seqLen', seqLen, headDim]) ->
  -- | optional value relations
  Maybe (Tensor device dtype '[batchSize, seqLen', seqLen, headDim]) ->
  -- | query representation
  Tensor device dtype '[batchSize, seqLen', embedDim] ->
  -- | key representation
  Tensor device dtype '[batchSize, seqLen, kEmbedDim] ->
  -- | value representation
  Tensor device dtype '[batchSize, seqLen, vEmbedDim] ->
  -- | transformer layer output representation
  IO (Tensor device dtype '[batchSize, seqLen', embedDim])
transformerLayer TransformerLayer {..} train attentionMask keyPaddingMask keyRelations valueRelations query key value =
  residual attend (pure . forward transformerLayer_ln) query
    >>= transformerMLP transformerLayer_mlp train
  where
    -- Attention sub-layer: attend from the given query, keep only the first
    -- component of the result, then apply the attention dropout.
    attend query' = do
      attention <-
        multiheadAttention
          transformerLayer_mha
          train
          attentionMask
          keyPaddingMask
          keyRelations
          valueRelations
          query'
          key
          value
      dropoutForward transformerLayer_attnDropout train (fst attention)

-- | Random initialisation of a 'TransformerLayer' from its specification.
--
-- Each component of the layer is sampled from the matching field of the
-- 'TransformerLayerSpec': the multi-head attention block, the attention
-- dropout, the layer normalisation (built from the epsilon stored in the
-- spec) and the feed-forward MLP.
instance
  ( All KnownNat '[embedDim, kEmbedDim, vEmbedDim, numHeads, ffnDim],
    KnownDType dtype,
    KnownDevice device,
    RandDTypeIsValid device dtype
  ) =>
  A.Randomizable
    (TransformerLayerSpec embedDim kEmbedDim vEmbedDim numHeads ffnDim dtype device)
    (TransformerLayer embedDim kEmbedDim vEmbedDim numHeads ffnDim dtype device)
  where
  -- The record wildcard brings mhaSpec, attnDropoutSpec, epsSpec' and
  -- mlpSpec into scope; the samples are combined applicatively in the
  -- order of the 'TransformerLayer' constructor's fields.
  sample TransformerLayerSpec {..} =
    TransformerLayer
      <$> A.sample mhaSpec
      <*> A.sample attnDropoutSpec
      <*> A.sample (LayerNormSpec epsSpec')
      <*> A.sample mlpSpec

--------------------------------------------------------------------------------
-- Transformer Language Model (GPT-2)
--------------------------------------------------------------------------------

-- | Specification of a transformer language model.
--
-- The type-level parameters fix the number of attention layers, the number
-- of heads, the feed-forward dimension, the padding token index, the
-- vocabulary size (@numEmbeds@), the embedding dimension, and the dtype and
-- device of the model's tensors.
data
  TransformerLMSpec
    (numAttnLayers :: Nat)
    (numHeads :: Nat)
    (ffnDim :: Nat)
    (paddingIdx :: Nat)
    (numEmbeds :: Nat)
    (embedDim :: Nat)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  where
  TransformerLMSpec ::
    forall numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim dtype device.
    { -- | dropout spec
      lmDropoutSpec :: DropoutSpec,
      -- | spec for each and every transformer layer
      lmLayerSpec :: TransformerLayerSpec embedDim embedDim embedDim numHeads ffnDim dtype device
    } ->
    TransformerLMSpec numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim dtype device
  deriving (Show, Eq)

-- | Transformer language model.
--
-- Holds a learned token embedding, a constant positional embedding with
-- 2048 entries, a dropout layer, a heterogeneous list of @numAttnLayers@
-- identical transformer layers, and a final linear projection from
-- @embedDim@ to @numEmbeds@.
data
  TransformerLM
    (numAttnLayers :: Nat)
    (numHeads :: Nat)
    (ffnDim :: Nat)
    (paddingIdx :: Nat)
    (numEmbeds :: Nat)
    (embedDim :: Nat)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  where
  TransformerLM ::
    forall numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim dtype device.
    { -- | token embedding
      tEmbedding :: Embedding ('Just paddingIdx) numEmbeds embedDim 'Learned dtype device,
      -- | positional embedding
      tPosEmbedding :: Embedding 'Nothing 2048 embedDim 'Constant dtype device,
      -- | transformer dropout
      tDropout :: Dropout,
      -- | transformer layers
      tLayers :: HList (HReplicateR numAttnLayers (TransformerLayer embedDim embedDim embedDim numHeads ffnDim dtype device)),
      -- | final output projection
      tProj :: Linear embedDim numEmbeds dtype device
    } ->
    TransformerLM numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim dtype device
  deriving (Generic)

-- Standalone-derived 'Show' for 'TransformerLM'. The only requirement is that
-- the heterogeneous list of replicated transformer layers is itself showable.
deriving instance
  Show (HList (HReplicateR numAttnLayers (TransformerLayer embedDim embedDim embedDim numHeads ffnDim dtype device))) =>
  Show (TransformerLM numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim dtype device)

-- 'Parameterized' instance for 'TransformerLM'. It declares no methods of its
-- own. The context names the replicated layer list @layers@, requires it to be
-- 'Parameterized', and requires that two further parameters of shapes
-- @[numEmbeds, embedDim]@ and @[numEmbeds]@ can be appended to its parameters.
instance
  ( layers ~ HReplicateR numAttnLayers (TransformerLayer embedDim embedDim embedDim numHeads ffnDim dtype device),
    Parameterized (HList layers),
    HAppendFD
      (Parameters (HList layers))
      '[Parameter device dtype '[numEmbeds, embedDim], Parameter device dtype '[numEmbeds]]
      (Parameters (HList layers) ++ '[Parameter device dtype '[numEmbeds, embedDim], Parameter device dtype '[numEmbeds]])
  ) =>
  Parameterized (TransformerLM numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim dtype device)

-- | Settings shared by every transformer layer when the layer stack is
-- folded over an input tensor: the training flag and the two optional masks.
data
  FoldLayers
    (batchSize :: Nat)
    (seqLen :: Nat)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat)) = FoldLayers
  { -- | switch between training mode and evaluation mode (turns random dropout on and off)
    flTrain :: Bool,
    -- | optional attention mask
    flAttentionMask :: Maybe (Tensor device dtype '[batchSize, seqLen, seqLen]),
    -- | optional key padding mask
    flKeyPaddingMask :: Maybe (Tensor device 'D.Bool '[batchSize, seqLen])
  }

-- | One step of the fold over the transformer layers.
--
-- Given a layer and an 'IO' action producing the current hidden state, run
-- the action and feed its result through the layer, using the training flag
-- and masks stored in the 'FoldLayers' value.
instance
  ( 1 <= numHeads,
    embedDim ~ (headDim * numHeads),
    All KnownNat '[embedDim, numHeads, seqLen, batchSize, headDim],
    IsSuffixOf '[embedDim] '[batchSize, seqLen, embedDim],
    KnownDType dtype,
    StandardFloatingPointDTypeValidation device dtype,
    MatMulDTypeIsValid device dtype,
    BasicArithmeticDTypeIsValid device dtype,
    dtype ~ SumDType dtype,
    SumDTypeIsValid device dtype,
    KnownDevice device
  ) =>
  Apply'
    (FoldLayers batchSize seqLen dtype device)
    ( TransformerLayer embedDim embedDim embedDim numHeads ffnDim dtype device,
      IO (Tensor device dtype '[batchSize, seqLen, embedDim])
    )
    (IO (Tensor device dtype '[batchSize, seqLen, embedDim]))
  where
  -- Self-attention: the same tensor is passed as query, key and value, and
  -- no key/value relation tensors are supplied (both are 'Nothing').
  apply' FoldLayers {..} (layer, mx) =
    mx >>= \x ->
      transformerLayer layer flTrain flAttentionMask flKeyPaddingMask Nothing Nothing x x x

-- | Forward pass of the transformer language model.
--
-- Takes the model, a flag selecting training mode (passed to the dropout
-- layers), and a batch of token indices of shape @[batchSize, seqLen]@.
-- Returns a tensor of shape @[batchSize, seqLen, numEmbeds]@ obtained by
-- applying the final linear projection to the output of the layer stack.
transformerLM ::
  forall
    numAttnLayers
    numHeads
    ffnDim
    paddingIdx
    numEmbeds
    embedDim
    seqLen
    batchSize
    dtype
    device.
  ( All KnownNat '[paddingIdx, embedDim, seqLen, batchSize],
    paddingIdx + 1 <= numEmbeds,
    1 <= seqLen,
    HFoldrM
      IO
      (FoldLayers batchSize seqLen dtype device)
      (Tensor device dtype '[batchSize, seqLen, embedDim])
      (HReplicateR numAttnLayers (TransformerLayer embedDim embedDim embedDim numHeads ffnDim dtype device))
      (Tensor device dtype '[batchSize, seqLen, embedDim]),
    BasicArithmeticDTypeIsValid device dtype,
    ComparisonDTypeIsValid device dtype,
    ComparisonDTypeIsValid device 'D.Int64,
    KnownDType dtype,
    KnownDevice device
  ) =>
  TransformerLM numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim dtype device ->
  Bool ->
  Tensor device 'D.Int64 '[batchSize, seqLen] ->
  IO (Tensor device dtype '[batchSize, seqLen, numEmbeds])
transformerLM TransformerLM {..} train xTokens = do
  -- Token embeddings plus positional embeddings. The position indices are
  -- generated with linspace from 0 to seqLen - 1 in seqLen steps, cast to
  -- Int64, embedded, and expanded to the full batch shape.
  let x = embed tEmbedding xTokens
      positions =
        expand @'[batchSize, seqLen, embedDim] True
          . embed tPosEmbedding
          . Torch.Typed.Tensor.toDType @D.Int64
          . linspace @seqLen (0 :: Int)
          $ natValI @(seqLen - 1)
  x' <- dropoutForward tDropout train (x `add` positions)
  -- Boolean mask taken from the strictly upper triangle (triu 1) of a
  -- seqLen x seqLen matrix of ones, with a leading singleton dimension.
  let attentionMask =
        unsqueeze @0
          . Torch.Typed.Tensor.toDType @D.Bool
          . triu 1
          $ ones @'[seqLen, seqLen] @D.Int8 @device
      -- Additive mask: a zero tensor with the masked entries filled with
      -- -1 / 0, i.e. negative infinity.
      attentionMask' =
        pure . maskedFill attentionMask (-1 / 0 :: Double) $
          zeros @'[batchSize, seqLen, seqLen] @dtype @device
  -- True wherever the input token equals the padding index.
  let keyPaddingMask = pure $ xTokens ==. (fromInteger . natVal $ Proxy @paddingIdx :: Tensor device 'D.Int64 '[])
  -- Thread the hidden state through the transformer layers, then project.
  y <- hfoldrM (FoldLayers train attentionMask' keyPaddingMask) x' tLayers
  return $ forward tProj y

-- | Forward passes of a 'TransformerLM' over a batch of token indices.
--
-- Both methods delegate to 'transformerLM' and differ only in the 'Bool'
-- flag they pass to it: 'forward' passes 'False' and 'forwardStoch' passes
-- 'True'. The input is an 'D.Int64' tensor of shape @[batchSize, seqLen]@;
-- the result has shape @[batchSize, seqLen, numEmbeds]@.
instance
  ( All KnownNat '[paddingIdx, embedDim, seqLen, batchSize],
    paddingIdx + 1 <= numEmbeds,
    1 <= seqLen,
    HFoldrM
      IO
      (FoldLayers batchSize seqLen dtype device)
      (Tensor device dtype '[batchSize, seqLen, embedDim])
      (HReplicateR numAttnLayers (TransformerLayer embedDim embedDim embedDim numHeads ffnDim dtype device))
      (Tensor device dtype '[batchSize, seqLen, embedDim]),
    BasicArithmeticDTypeIsValid device dtype,
    ComparisonDTypeIsValid device dtype,
    ComparisonDTypeIsValid device 'D.Int64,
    KnownDType dtype,
    KnownDevice device
  ) =>
  HasForward (TransformerLM numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim dtype device) (Tensor device 'D.Int64 '[batchSize, seqLen]) (Tensor device dtype '[batchSize, seqLen, numEmbeds])
  where
  -- 'HasForward' requires a pure result here, while 'transformerLM' lives in
  -- 'IO', hence the 'unsafePerformIO'.
  -- NOTE(review): this is only sound if 'transformerLM' performs no
  -- observable effects when the flag is 'False' (presumably the flag is the
  -- training switch that enables dropout) -- confirm against 'transformerLM'.
  forward model input = unsafePerformIO $ transformerLM model False input

  -- Stays in 'IO' and runs 'transformerLM' with the flag set to 'True'.
  forwardStoch model input = transformerLM model True input

-- | Constant table of sinusoidal position encodings.
--
-- Row @p@ (for @p@ in @0 .. numEmbeds - 1@) is built from the angles
-- @p * exp (- ln 10000 * i / (embedDim \`div\` 2))@ for
-- @i@ in @0 .. embedDim \`div\` 2 - 1@. The sine and cosine of every angle
-- are stacked along a new trailing axis and that axis is then merged into
-- the embedding axis, so the sine of an angle is directly followed by its
-- cosine.
--
-- The constraints demand a non-zero, even @embedDim@
-- (@1 <= Div embedDim 2@ and @Div embedDim 2 * 2 ~ embedDim@) and at least
-- one row (@1 <= numEmbeds@).
sinusoidal ::
  forall numEmbeds embedDim device.
  ( All KnownNat '[numEmbeds, embedDim],
    1 <= numEmbeds,
    1 <= Div embedDim 2,
    (Div embedDim 2 * 2) ~ embedDim,
    StandardFloatingPointDTypeValidation device 'D.Float,
    BasicArithmeticDTypeIsValid device 'D.Float,
    KnownDevice device
  ) =>
  Tensor device 'D.Float '[numEmbeds, embedDim]
sinusoidal =
  let -- Column vector of positions 0 .. numEmbeds - 1, shape [numEmbeds, 1].
      positions =
        unsqueeze @1
          . linspace @numEmbeds (0 :: Int)
          $ natValI @(numEmbeds - 1)
      -- One frequency per sine/cosine pair, shape [Div embedDim 2]:
      -- exp (- ln 10000 * i / (embedDim `div` 2)).
      scalingFactors =
        exp
          . mulScalar (- log (10000 :: Double) / (fromInteger . natVal $ Proxy @(Div embedDim 2)))
          . linspace @(Div embedDim 2) (0 :: Int)
          $ natValI @((Div embedDim 2) - 1)
      -- Broadcast product position * frequency, shape [numEmbeds, Div embedDim 2].
      radians = mul positions scalingFactors
      -- Sine and cosine side by side on a new last axis,
      -- shape [numEmbeds, Div embedDim 2, 2].
      weights = stack @2 (sin radians :. cos radians :. HNil)
   in -- Collapse the last two axes into the final [numEmbeds, embedDim] table.
      reshape weights

-- | Samples a 'TransformerLM' from its 'TransformerLMSpec'.
--
-- The model's five components are sampled in constructor order:
--
-- 1. a learned, randomly initialised token embedding whose padding index is
--    @paddingIdx@;
-- 2. a constant position embedding holding the 'sinusoidal' table converted
--    to the model's @dtype@;
-- 3. the dropout layer described by @lmDropoutSpec@;
-- 4. @numAttnLayers@ transformer layers, each sampled from a copy of
--    @lmLayerSpec@;
-- 5. the output projection from @embedDim@ to @numEmbeds@.
instance
  ( paddingIdx <= numEmbeds,
    1 <= numEmbeds - paddingIdx,
    1 <= Div embedDim 2,
    (((numEmbeds - paddingIdx) - 1) + (1 + paddingIdx)) ~ numEmbeds,
    (Div embedDim 2 * 2) ~ embedDim,
    All KnownNat '[ffnDim, paddingIdx, numEmbeds, embedDim],
    HReplicate numAttnLayers (TransformerLayerSpec embedDim embedDim embedDim numHeads ffnDim dtype device),
    A.Randomizable
      (HList (HReplicateR numAttnLayers (TransformerLayerSpec embedDim embedDim embedDim numHeads ffnDim dtype device)))
      (HList (HReplicateR numAttnLayers (TransformerLayer embedDim embedDim embedDim numHeads ffnDim dtype device))),
    KnownDType dtype,
    RandDTypeIsValid device dtype,
    StandardFloatingPointDTypeValidation device 'D.Float,
    BasicArithmeticDTypeIsValid device 'D.Float,
    KnownDevice device
  ) =>
  A.Randomizable
    (TransformerLMSpec numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim dtype device)
    (TransformerLM numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim dtype device)
  where
  -- The record wildcard brings 'lmDropoutSpec' and 'lmLayerSpec' into scope.
  sample TransformerLMSpec {..} =
    TransformerLM
      -- Token embedding.
      <$> A.sample (LearnedEmbeddingWithRandomInitSpec @('Just paddingIdx))
      -- Position embedding; 'sinusoidal' is built in 'D.Float', hence the cast.
      <*> A.sample (ConstEmbeddingSpec @'Nothing (Torch.Typed.Tensor.toDType sinusoidal))
      <*> A.sample lmDropoutSpec
      -- The single layer spec is replicated before sampling.
      <*> A.sample (hreplicate @numAttnLayers lmLayerSpec)
      -- Output projection.
      <*> A.sample LinearSpec