{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE PartialTypeSignatures #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}

-- | Building blocks for step-by-step (greedy) token generation with the
-- simplified encoder-decoder transformer interface.
module Torch.GraduallyTyped.NN.Transformer.Generation where

import Control.Lens (Lens)
import Control.Monad.Catch (MonadThrow)
import Control.Monad.State (MonadState (..))
import Data.Function (fix)
import Foreign.ForeignPtr (ForeignPtr)
import Torch.GraduallyTyped.DType (DType (..), DataType (..))
import Torch.GraduallyTyped.Index.Type (Index (NegativeIndex), SIndex (..))
import Torch.GraduallyTyped.NN.Transformer.GEncoderDecoder (SimplifiedEncoderDecoderTransformerGenerationInput (..), SimplifiedEncoderDecoderTransformerOutput (..))
import Torch.GraduallyTyped.Prelude (Catch, pattern (:|:))
import Torch.GraduallyTyped.Prelude.List (SList (SNil))
import Torch.GraduallyTyped.RequiresGradient (Gradient (..), RequiresGradient (..), SGradient (..), SRequiresGradient (..))
import Torch.GraduallyTyped.Shape.Class (BroadcastShapesF)
import Torch.GraduallyTyped.Shape.Type (By (..), SBy (..), SSelectDim (..), SelectDim (..), Shape (..))
import Torch.GraduallyTyped.Tensor.Indexing (IndexDims, IndexType (..), Indices (..), SIndexType (..), SIndices (..), (!))
import Torch.GraduallyTyped.Tensor.IndexingSlicingJoining (CatHListF, HasCat (..), SqueezeDimF, UnsqueezeF, sSqueezeDim, sUnsqueeze)
import Torch.GraduallyTyped.Tensor.MathOperations.Comparison ((/=.), (==.))
import Torch.GraduallyTyped.Tensor.MathOperations.Pointwise (mul, mulScalar, sub, subScalar)
import Torch.GraduallyTyped.Tensor.MathOperations.Reduction (ArgmaxF, MaxAllCheckF, argmax, maxAll)
import Torch.GraduallyTyped.Tensor.Type (SGetDataType (..), SGetDevice (..), SGetDim, SGetLayout (..), SGetShape (..), Tensor, TensorLike (..), sCheckedShape, sSetDataType)
import Torch.GraduallyTyped.Unify (type (<+>), type (<|>))
import Torch.HList (HList (HNil), pattern (:.))
import qualified Torch.Internal.Class as ATen
import qualified Torch.Internal.Type as ATen
import Prelude hiding (all)

-- | Iterate a monadic step function until it signals termination.
--
-- The step @f@ receives the current value and state. When it yields
-- @'Just' (x', s')@ the loop continues from that pair; when it yields
-- 'Nothing' the pair that was passed to that final call is returned.
decode ::
  Monad m =>
  -- | step function; 'Nothing' stops the iteration
  (x -> s -> m (Maybe (x, s))) ->
  -- | initial value
  x ->
  -- | initial state
  s ->
  -- | last value and state for which the step returned 'Nothing'
  m (x, s)
decode f x s =
  flip fix (x, s) $ \loop (x', s') -> do
    r <- f x' s'
    case r of
      Nothing -> pure (x', s')
      Just (x'', s'') -> loop (x'', s'')

-- | Lens from a transformer output to the next generation input.
--
-- The focus is the pair @(sedtDecoderOutput, sedtOriginalDecoderInput)@.
-- The value written back is a monadic new decoder input; it is combined with
-- the unchanged @sedtEncoderOutput@ and @sedtInputPaddingMask@ of the output
-- into a 'SimplifiedEncoderDecoderTransformerGenerationInput' in the same
-- monad.
sedtOutputToInput ::
  Monad m =>
  Lens
    (SimplifiedEncoderDecoderTransformerOutput logits encoderOutput decoderInput inputPaddingMask)
    (m (SimplifiedEncoderDecoderTransformerGenerationInput decoderInput' encoderOutput inputPaddingMask))
    (logits, decoderInput)
    (m decoderInput')
sedtOutputToInput f SimplifiedEncoderDecoderTransformerOutput {..} =
  ( \decoderInput' ->
      SimplifiedEncoderDecoderTransformerGenerationInput
        <$> decoderInput'
        <*> pure sedtEncoderOutput
        <*> pure sedtInputPaddingMask
  )
    <$> f (sedtDecoderOutput, sedtOriginalDecoderInput)

-- | Lens from @(logits, decoderInput)@ onto the logits.
--
-- The value written back is a monadic tensor of next tokens. It is unsqueezed
-- at dimension index 1 and concatenated after @decoderInput@ along dimension
-- index 1, yielding the extended decoder input.
--
-- NOTE(review): presumably dimension 1 is the sequence dimension of the
-- decoder input; confirm against the callers.
prepNext ::
  ( logits ~ Tensor logitsGradient logitsLayout logitsDevice logitsDataType logitsShape,
    ntShape' ~ UnsqueezeF ('SelectDim ('ByIndex 1)) ntShape,
    Catch ntShape',
    tensors ~ '[decoderInput, Tensor ntGradient ntLayout ntDevice ntDataType ntShape'],
    decoderInput' ~ CatHListF ('SelectDim ('ByIndex 1)) tensors,
    ATen.Castable decoderInput' (ForeignPtr ATen.Tensor),
    ATen.Castable (HList tensors) (ForeignPtr ATen.TensorList),
    MonadThrow m
  ) =>
  Lens
    (logits, decoderInput)
    (m decoderInput')
    logits
    (m (Tensor ntGradient ntLayout ntDevice ntDataType ntShape))
prepNext f (logits, decoderInput) =
  ( \nextTokens -> do
      nextTokens' <- nextTokens
      nextTokens'' <- sUnsqueeze (SSelectDim $ SByIndex @1) nextTokens'
      sCat (SSelectDim $ SByIndex @1) (decoderInput :. nextTokens'' :. HNil)
  )
    <$> f logits

-- | Pick the next tokens greedily from the logits.
--
-- The logits are sliced at index @-1@ of their second dimension, the argmax
-- is taken along dimension index 1 of the slice, and that dimension is
-- squeezed away. The result is shape-checked against the tensor held in the
-- 'MonadState' state (the unfinished-sequences flags), passed through
-- 'applyUnfinishedSequences', and then used by 'updateUnfinishedSequences' to
-- compute the new state, which is stored with 'put'.
--
-- NOTE(review): the slicing pattern suggests logits laid out as
-- (batch, sequence, vocabulary); confirm against the model that produces them.
greedyNextTokens ::
  ( nextTokenLogitsShape ~ IndexDims ('Indices '[ 'SliceAll, 'SliceAt ('NegativeIndex 1), 'SliceAll]) logitsShape,
    nextTokensShape ~ ArgmaxF ('SelectDim ('ByIndex 1)) nextTokenLogitsShape,
    Catch nextTokensShape,
    nextTokensShape' ~ SqueezeDimF ('SelectDim ('ByIndex 1)) nextTokensShape,
    ntShape ~ 'Shape '[ntDim],
    Catch (nextTokensShape' <+> ntShape),
    SGetShape nextTokensShape',
    SGetDim ntDim,
    Catch ntDim,
    Catch nextTokensShape',
    MonadThrow m,
    MonadState (Tensor ('Gradient 'WithoutGradient) logitsLayout logitsDevice ('DataType 'Int64) ntShape) m,
    SGetDevice logitsDevice,
    SGetLayout logitsLayout
  ) =>
  -- | id of the padding token
  Int ->
  -- | id of the end-of-sequence token
  Int ->
  -- | logits
  Tensor logitsGradient logitsLayout logitsDevice logitsDataType logitsShape ->
  -- | next tokens, one per entry of the state tensor
  m (Tensor ('Gradient 'WithoutGradient) logitsLayout logitsDevice ('DataType 'Int64) ntShape)
greedyNextTokens padTokenId eosTokenId logits = do
  nextTokenLogits <- logits ! SIndices (SSliceAll :|: SSliceAt (SNegativeIndex @1) :|: SSliceAll :|: SNil)
  nextTokens <- argmax (SSelectDim $ SByIndex @1) nextTokenLogits
  nextTokens' <- sSqueezeDim (SSelectDim $ SByIndex @1) nextTokens
  unfinishedSequences <- get
  -- The state tensor dictates the shape the freshly computed tokens must have.
  let usShape = sGetShape unfinishedSequences
  nextTokens'' <- sCheckedShape usShape nextTokens'
  nextTokens''' <- applyUnfinishedSequences padTokenId unfinishedSequences nextTokens''
  -- The update uses the tokens after padding has been applied.
  unfinishedSequences' <- updateUnfinishedSequences eosTokenId nextTokens''' unfinishedSequences
  put unfinishedSequences'
  pure nextTokens'''

-- | Test whether the maximum entry of the given tensor equals zero.
--
-- NOTE(review): this means "all sequences finished" only if the entries are
-- non-negative flags (1 = unfinished, 0 = finished); confirm with the code
-- that initialises the state.
allSequencesFinished ::
  ( SGetLayout usLayout,
    SGetDevice usDevice,
    MonadThrow m,
    Catch (usDataType <+> 'DataType 'Int64),
    Catch (BroadcastShapesF usShape ('Shape '[])),
    MaxAllCheckF usShape
  ) =>
  -- | unfinished-sequences tensor
  Tensor usGradient usLayout usDevice usDataType usShape ->
  -- | 'True' iff the maximum over all entries is equal to @0@
  m Bool
allSequencesFinished unfinishedSequences = do
  let gradient = SGradient SWithoutGradient
      layout = sGetLayout unfinishedSequences
      device = sGetDevice unfinishedSequences
  -- Scalar zero with the layout and device of the input tensor.
  zero <- sToTensor gradient layout device (0 :: Int)
  isZero <- maxAll unfinishedSequences ==. zero
  pure $ fromTensor isZero

-- | Combine next tokens with the unfinished-sequences tensor.
--
-- Computes @unfinishedSequences * nextTokens - (unfinishedSequences - 1) * padTokenId@
-- element-wise. Where @unfinishedSequences@ is @1@ this is the next token,
-- where it is @0@ this is @padTokenId@.
--
-- NOTE(review): assumes the entries of @unfinishedSequences@ are 0 or 1;
-- nothing in this function enforces that.
applyUnfinishedSequences ::
  ( MonadThrow m,
    kntDataType ~ (usDataType <+> ntDataType),
    kntShape ~ BroadcastShapesF usShape ntShape,
    Catch kntShape,
    ntGradient' ~ (usGradient <|> ntGradient),
    ntLayout' ~ ((usLayout <+> ntLayout) <+> usLayout),
    ntDevice' ~ ((usDevice <+> ntDevice) <+> usDevice),
    ntDataType' ~ ((usDataType <+> ntDataType) <+> usDataType),
    ntShape' ~ BroadcastShapesF kntShape usShape,
    Catch ntShape'
  ) =>
  -- | id of the padding token
  Int ->
  -- | unfinished-sequences tensor
  Tensor usGradient usLayout usDevice usDataType usShape ->
  -- | candidate next tokens
  Tensor ntGradient ntLayout ntDevice ntDataType ntShape ->
  -- | next tokens with the padding term applied
  m (Tensor ntGradient' ntLayout' ntDevice' ntDataType' ntShape')
applyUnfinishedSequences padTokenId unfinishedSequences nextTokens = do
  keptNextTokens <- unfinishedSequences `mul` nextTokens
  replacedNextTokens <- do
    -- (unfinishedSequences - 1) is -1 for a 0 flag and 0 for a 1 flag, so
    -- subtracting this product below adds padTokenId where the flag is 0.
    finishedSequences <- unfinishedSequences `subScalar` (1 :: Int)
    finishedSequences `mulScalar` padTokenId
  keptNextTokens `sub` replacedNextTokens

-- | Update the unfinished-sequences tensor from the chosen next tokens.
--
-- Compares @nextTokens@ with @eosTokenId@ for inequality, converts the boolean
-- result to the data type of @unfinishedSequences@, and multiplies it into
-- @unfinishedSequences@. An entry whose next token equals @eosTokenId@ is
-- thereby multiplied by the converted value of 'False'.
updateUnfinishedSequences ::
  ( Catch (ntDataType <+> 'DataType 'Int64),
    Catch (BroadcastShapesF ntShape ('Shape '[])),
    Catch usShape',
    SGetDataType usDataType,
    SGetDevice ntDevice,
    SGetLayout ntLayout,
    MonadThrow m,
    BroadcastShapesF usShape (BroadcastShapesF ntShape ('Shape '[])) ~ usShape',
    (usGradient <|> 'Gradient 'WithoutGradient) ~ usGradient',
    (usLayout <+> ntLayout) ~ usLayout',
    (usDevice <+> ntDevice) ~ usDevice'
  ) =>
  -- | id of the end-of-sequence token
  Int ->
  -- | chosen next tokens
  Tensor ntGradient ntLayout ntDevice ntDataType ntShape ->
  -- | current unfinished-sequences tensor
  Tensor usGradient usLayout usDevice usDataType usShape ->
  -- | updated unfinished-sequences tensor
  m (Tensor usGradient' usLayout' usDevice' usDataType usShape')
updateUnfinishedSequences eosTokenId nextTokens unfinishedSequences = do
  let gradient = SGradient SWithoutGradient
      ntLayout = sGetLayout nextTokens
      ntDevice = sGetDevice nextTokens
      usDataType = sGetDataType unfinishedSequences
  -- Scalar EOS id with the layout and device of the next-token tensor.
  eosTokenId' <- sToTensor gradient ntLayout ntDevice eosTokenId
  isNotEos <- nextTokens /=. eosTokenId'
  -- Convert the boolean mask to the data type of the state before multiplying.
  isNotEos' <- sSetDataType usDataType isNotEos
  unfinishedSequences `mul` isNotEos'