{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE PartialTypeSignatures #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}

module Torch.GraduallyTyped.NN.Transformer.Generation where

import Control.Lens (Lens)
import Control.Monad.Catch (MonadThrow)
import Control.Monad.State (MonadState (..))
import Data.Function (fix)
import Foreign.ForeignPtr (ForeignPtr)
import Torch.GraduallyTyped.DType (DType (..), DataType (..))
import Torch.GraduallyTyped.Index.Type (Index (NegativeIndex), SIndex (..))
import Torch.GraduallyTyped.NN.Transformer.GEncoderDecoder (SimplifiedEncoderDecoderTransformerGenerationInput (..), SimplifiedEncoderDecoderTransformerOutput (..))
import Torch.GraduallyTyped.Prelude (Catch, pattern (:|:))
import Torch.GraduallyTyped.Prelude.List (SList (SNil))
import Torch.GraduallyTyped.RequiresGradient (Gradient (..), RequiresGradient (..), SGradient (..), SRequiresGradient (..))
import Torch.GraduallyTyped.Shape.Class (BroadcastShapesF)
import Torch.GraduallyTyped.Shape.Type (By (..), SBy (..), SSelectDim (..), SelectDim (..), Shape (..))
import Torch.GraduallyTyped.Tensor.Indexing (IndexDims, IndexType (..), Indices (..), SIndexType (..), SIndices (..), (!))
import Torch.GraduallyTyped.Tensor.IndexingSlicingJoining (CatHListF, HasCat (..), SqueezeDimF, UnsqueezeF, sSqueezeDim, sUnsqueeze)
import Torch.GraduallyTyped.Tensor.MathOperations.Comparison ((/=.), (==.))
import Torch.GraduallyTyped.Tensor.MathOperations.Pointwise (mul, mulScalar, sub, subScalar)
import Torch.GraduallyTyped.Tensor.MathOperations.Reduction (ArgmaxF, MaxAllCheckF, argmax, maxAll)
import Torch.GraduallyTyped.Tensor.Type (SGetDataType (..), SGetDevice (..), SGetDim, SGetLayout (..), SGetShape (..), Tensor, TensorLike (..), sCheckedShape, sSetDataType)
import Torch.GraduallyTyped.Unify (type (<+>), type (<|>))
import Torch.HList (HList (HNil), pattern (:.))
import qualified Torch.Internal.Class as ATen
import qualified Torch.Internal.Type as ATen
import Prelude hiding (all)

-- | Iterate a monadic step function until it signals completion.
--
-- The step function receives the current value @x@ and state @s@.
-- When it yields @'Just' (x', s')@ the loop continues with the new pair;
-- when it yields 'Nothing' the loop stops and the pair that was passed
-- to that final step is returned unchanged.
decode ::
  Monad m =>
  (x -> s -> m (Maybe (x, s))) ->
  x ->
  s ->
  m (x, s)
decode step = fix $ \loop x s ->
  -- 'Nothing' ends the iteration with the current pair; 'Just' feeds the
  -- updated pair back into the loop.
  step x s >>= maybe (pure (x, s)) (uncurry loop)

-- | Lens from a 'SimplifiedEncoderDecoderTransformerOutput' to the pair of
-- its decoder output (the logits) and its original decoder input.
--
-- Writing back through the lens takes a monadic action producing a new
-- decoder input and yields a monadic
-- 'SimplifiedEncoderDecoderTransformerGenerationInput' that combines the
-- new decoder input with the encoder output and input padding mask carried
-- over unchanged from the original output record.
sedtOutputToInput ::
  Monad m =>
  Lens
    (SimplifiedEncoderDecoderTransformerOutput logits encoderOutput decoderInput inputPaddingMask)
    (m (SimplifiedEncoderDecoderTransformerGenerationInput decoderInput' encoderOutput inputPaddingMask))
    (logits, decoderInput)
    (m decoderInput')
sedtOutputToInput f SimplifiedEncoderDecoderTransformerOutput {..} =
  rebuild <$> f (sedtDecoderOutput, sedtOriginalDecoderInput)
  where
    -- Assemble the next generation input from the (monadic) new decoder
    -- input, reusing the encoder output and padding mask as they are.
    rebuild decoderInput' =
      SimplifiedEncoderDecoderTransformerGenerationInput
        <$> decoderInput'
        <*> pure sedtEncoderOutput
        <*> pure sedtInputPaddingMask

-- | Lens from a @(logits, decoderInput)@ pair to the logits alone.
--
-- Writing back through the lens takes a monadic action producing the next
-- tokens. The resulting action runs it, inserts a new dimension at index 1
-- into the next-token tensor with 'sUnsqueeze', and concatenates the result
-- onto @decoderInput@ along dimension index 1 with 'sCat'.
prepNext ::
  ( logits ~ Tensor logitsGradient logitsLayout logitsDevice logitsDataType logitsShape,
    ntShape' ~ UnsqueezeF ('SelectDim ('ByIndex 1)) ntShape,
    Catch ntShape',
    tensors ~ '[decoderInput, Tensor ntGradient ntLayout ntDevice ntDataType ntShape'],
    decoderInput' ~ CatHListF ('SelectDim ('ByIndex 1)) tensors,
    ATen.Castable decoderInput' (ForeignPtr ATen.Tensor),
    ATen.Castable (HList tensors) (ForeignPtr ATen.TensorList),
    MonadThrow m
  ) =>
  Lens
    (logits, decoderInput)
    (m decoderInput')
    logits
    (m (Tensor ntGradient ntLayout ntDevice ntDataType ntShape))
prepNext f (logits, decoderInput) =
  appendNext <$> f logits
  where
    -- Both the unsqueeze and the concatenation act on dimension index 1.
    dim1 = SSelectDim $ SByIndex @1
    appendNext nextTokens = do
      nextTokens' <- nextTokens
      nextTokens'' <- sUnsqueeze dim1 nextTokens'
      sCat dim1 (decoderInput :. nextTokens'' :. HNil)

-- | Greedy selection of the next tokens from a logits tensor.
--
-- The first 'Int' is the pad token id and the second the end-of-sequence
-- token id. The monadic state holds the unfinished-sequences tensor.
--
-- Steps, in order:
--
-- 1. Slice the logits with @[all, -1, all]@, i.e. take the last entry along
--    the second dimension.
-- 2. Take the 'argmax' over dimension index 1 and squeeze that dimension.
-- 3. Read the unfinished-sequences tensor from the state and check that the
--    next-token tensor has the same shape.
-- 4. Combine tokens and unfinished-sequences tensor with
--    'applyUnfinishedSequences', using the pad token id.
-- 5. Compute the new unfinished-sequences tensor with
--    'updateUnfinishedSequences', using the end-of-sequence token id, and
--    store it in the state.
--
-- Returns the tokens produced by step 4.
greedyNextTokens ::
  ( nextTokenLogitsShape ~ IndexDims ('Indices '[ 'SliceAll, 'SliceAt ('NegativeIndex 1), 'SliceAll]) logitsShape,
    nextTokensShape ~ ArgmaxF ('SelectDim ('ByIndex 1)) nextTokenLogitsShape,
    Catch nextTokensShape,
    nextTokensShape' ~ SqueezeDimF ('SelectDim ('ByIndex 1)) nextTokensShape,
    ntShape ~ 'Shape '[ntDim],
    Catch (nextTokensShape' <+> ntShape),
    SGetShape nextTokensShape',
    SGetDim ntDim,
    Catch ntDim,
    Catch nextTokensShape',
    MonadThrow m,
    MonadState (Tensor ('Gradient 'WithoutGradient) logitsLayout logitsDevice ('DataType 'Int64) ntShape) m,
    SGetDevice logitsDevice,
    SGetLayout logitsLayout
  ) =>
  Int ->
  Int ->
  Tensor logitsGradient logitsLayout logitsDevice logitsDataType logitsShape ->
  m (Tensor ('Gradient 'WithoutGradient) logitsLayout logitsDevice ('DataType 'Int64) ntShape)
greedyNextTokens padTokenId eosTokenId logits = do
  -- Keep only the last entry along the second dimension of the logits.
  nextTokenLogits <- logits ! SIndices (SSliceAll :|: SSliceAt (SNegativeIndex @1) :|: SSliceAll :|: SNil)
  nextTokens <- argmax (SSelectDim $ SByIndex @1) nextTokenLogits
  nextTokens' <- sSqueezeDim (SSelectDim $ SByIndex @1) nextTokens
  unfinishedSequences <- get
  -- The next tokens must have the same shape as the state tensor.
  let usShape = sGetShape unfinishedSequences
  nextTokens'' <- sCheckedShape usShape nextTokens'
  nextTokens''' <- applyUnfinishedSequences padTokenId unfinishedSequences nextTokens''
  unfinishedSequences' <- updateUnfinishedSequences eosTokenId nextTokens''' unfinishedSequences
  put unfinishedSequences'
  pure nextTokens'''

-- | Test whether the maximum over all elements of the unfinished-sequences
-- tensor equals zero.
--
-- A scalar @0 :: Int@ tensor is created without gradient, on the same
-- layout and device as the input, and compared against 'maxAll' of the
-- input with '==.'. The boolean scalar result is converted with
-- 'fromTensor'.
--
-- NOTE(review): presumably a nonzero entry marks a sequence that is still
-- being generated, so a maximum of zero means every sequence is finished;
-- confirm against the code that maintains this tensor.
allSequencesFinished ::
  ( SGetLayout usLayout,
    SGetDevice usDevice,
    MonadThrow m,
    Catch (usDataType <+> 'DataType 'Int64),
    Catch (BroadcastShapesF usShape ('Shape '[])),
    MaxAllCheckF usShape
  ) =>
  Tensor usGradient usLayout usDevice usDataType usShape ->
  m Bool
allSequencesFinished unfinishedSequences = do
  let gradient = SGradient SWithoutGradient
      layout = sGetLayout unfinishedSequences
      device = sGetDevice unfinishedSequences
  zero <- sToTensor gradient layout device (0 :: Int)
  isZero <- maxAll unfinishedSequences ==. zero
  pure $ fromTensor isZero

-- | Combine the next tokens with the unfinished-sequences tensor and the
-- pad token id.
--
-- With @us@ the unfinished-sequences tensor, @nt@ the next tokens and
-- @pad@ the pad token id, the result is
--
-- > us * nt - (us - 1) * pad
--
-- so an entry with @us == 1@ keeps its token and an entry with @us == 0@
-- becomes @pad@.
--
-- NOTE(review): this assumes @us@ only holds the values 0 and 1; that is
-- not enforced here.
applyUnfinishedSequences ::
  ( MonadThrow m,
    kntDataType ~ (usDataType <+> ntDataType),
    kntShape ~ BroadcastShapesF usShape ntShape,
    Catch kntShape,
    ntGradient' ~ (usGradient <|> ntGradient),
    ntLayout' ~ ((usLayout <+> ntLayout) <+> usLayout),
    ntDevice' ~ ((usDevice <+> ntDevice) <+> usDevice),
    ntDataType' ~ ((usDataType <+> ntDataType) <+> usDataType),
    ntShape' ~ BroadcastShapesF kntShape usShape,
    Catch ntShape'
  ) =>
  Int ->
  Tensor usGradient usLayout usDevice usDataType usShape ->
  Tensor ntGradient ntLayout ntDevice ntDataType ntShape ->
  m (Tensor ntGradient' ntLayout' ntDevice' ntDataType' ntShape')
applyUnfinishedSequences padTokenId unfinishedSequences nextTokens = do
  -- us * nt: tokens survive only where the mask is nonzero.
  keptNextTokens <- unfinishedSequences `mul` nextTokens
  -- (us - 1) * pad: -pad where the mask is 0, and 0 where it is 1.
  replacedNextTokens <- do
    finishedSequences <- unfinishedSequences `subScalar` (1 :: Int)
    finishedSequences `mulScalar` padTokenId
  keptNextTokens `sub` replacedNextTokens

-- | Combine the current @unfinishedSequences@ tensor with a mask that marks
-- which entries of @nextTokens@ differ from the end-of-sequence token.
--
-- The steps are:
--
-- 1. @eosTokenId@ is lifted into a tensor without gradient, using the layout
--    and device of @nextTokens@.
-- 2. @nextTokens@ is compared against that tensor with '/=.'.
-- 3. The comparison result is converted to the data type of
--    @unfinishedSequences@ with 'sSetDataType'.
-- 4. @unfinishedSequences@ is multiplied by the converted mask with 'mul'.
--
-- NOTE(review): the intent appears to be that entries whose next token equals
-- @eosTokenId@ are multiplied by the converted @False@ value (presumably
-- zero) and are thereby marked as finished -- confirm against the callers.
--
-- Every tensor operation used here runs in @m@ and may fail through
-- 'MonadThrow'.
updateUnfinishedSequences ::
  ( Catch (ntDataType <+> 'DataType 'Int64),
    Catch (BroadcastShapesF ntShape ('Shape '[])),
    Catch usShape',
    SGetDataType usDataType,
    SGetDevice ntDevice,
    SGetLayout ntLayout,
    MonadThrow m,
    BroadcastShapesF usShape (BroadcastShapesF ntShape ('Shape '[])) ~ usShape',
    (usGradient <|> 'Gradient 'WithoutGradient) ~ usGradient',
    (usLayout <+> ntLayout) ~ usLayout',
    (usDevice <+> ntDevice) ~ usDevice'
  ) =>
  -- | id of the end-of-sequence token
  Int ->
  -- | next tokens, compared element-wise against the end-of-sequence id
  Tensor ntGradient ntLayout ntDevice ntDataType ntShape ->
  -- | current unfinished-sequences tensor
  Tensor usGradient usLayout usDevice usDataType usShape ->
  -- | unfinished-sequences tensor multiplied by the not-end-of-sequence mask
  m (Tensor usGradient' usLayout' usDevice' usDataType usShape')
updateUnfinishedSequences eosTokenId nextTokens unfinishedSequences = do
  let gradient = SGradient SWithoutGradient
      -- Singletons are read off the inputs so that the end-of-sequence tensor
      -- and the mask are created with matching layout, device and data type.
      ntLayout = sGetLayout nextTokens
      ntDevice = sGetDevice nextTokens
      usDataType = sGetDataType unfinishedSequences
  eosTokenId' <- sToTensor gradient ntLayout ntDevice eosTokenId
  isNotEos <- nextTokens /=. eosTokenId'
  -- The comparison result is converted to the data type of
  -- 'unfinishedSequences' before the multiplication.
  isNotEos' <- sSetDataType usDataType isNotEos
  unfinishedSequences `mul` isNotEos'