{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PartialTypeSignatures #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeApplications #-}
{-# OPTIONS_GHC -fno-warn-partial-type-signatures #-}

module Torch.GraduallyTyped.NN.Transformer.T5.Generation where

import Control.Applicative (Alternative (..))
import Control.Monad (MonadPlus (..), guard)
import Control.Monad.Catch (MonadThrow)
import Control.Monad.State (MonadState (..), MonadTrans (..), StateT (..), evalStateT, gets, lift, modify)
import Control.Monad.Trans.Free (FreeF (..), FreeT (..), runFreeT)
import Data.Foldable (asum)
import Data.List (foldl', isInfixOf, nub, sortOn, uncons)
import qualified Data.Map as Map (Map, lookup)
import System.IO.Unsafe (unsafePerformIO)
import Text.Parser.Char (CharParsing (..), spaces)
import Text.Parser.Combinators (Parsing (..), between, manyTill)
import Text.Parser.Token (TokenParsing (..))
import Torch.Data.Parser (Parser, combine, isNotToken, isString, isToken, recurse, satisfy, scan, token)
import Torch.GraduallyTyped.DType (DType (..), DataType (..))
import Torch.GraduallyTyped.Device (Device (..), DeviceType (..), SDevice (..), SDeviceType (..))
import Torch.GraduallyTyped.Layout (Layout (..), LayoutType (..))
import Torch.GraduallyTyped.NN.Class (HasForward (..), HasStateDict (fromStateDict), stateDictFromFile)
import Torch.GraduallyTyped.NN.Functional.NonLinearActivation (logSoftmax)
import Torch.GraduallyTyped.NN.Transformer.GEncoderDecoder (SimplifiedEncoderDecoderTransformerGenerationInput (..), SimplifiedEncoderDecoderTransformerInput (..), SimplifiedEncoderDecoderTransformerOutput (..))
import Torch.GraduallyTyped.NN.Transformer.T5.Common (T5DataType, mkT5Input, t5EOSTokenId)
import Torch.GraduallyTyped.NN.Transformer.T5.Small (t5SmallSpec)
import Torch.GraduallyTyped.NN.Transformer.Type (STransformerHead (SWithLMHead))
import Torch.GraduallyTyped.NN.Type (SHasDropout (SWithDropout))
import Torch.GraduallyTyped.Random (Generator, sMkGenerator)
import Torch.GraduallyTyped.RequiresGradient (Gradient (..), RequiresGradient (..), SGradient (..), SRequiresGradient (..))
import Torch.GraduallyTyped.Shape.Class (BroadcastShapesF)
import Torch.GraduallyTyped.Shape.Type (By (..), Dim (..), Name (..), SBy (..), SName (..), SSelectDim (..), SShape (..), SSize (..), SelectDim (..), Shape (..), Size (..), pattern (:&:))
import Torch.GraduallyTyped.Tensor.IndexingSlicingJoining (sExpand)
import Torch.GraduallyTyped.Tensor.MathOperations.Comparison (Order (..), Sorted (..), sort)
import Torch.GraduallyTyped.Tensor.Type (SGetShape, Tensor (..))
import Torch.Language.SpiderSQL (SpiderSQL, spiderSQL)
import qualified Torch.Tensor
import Prelude hiding (Word, words)

-- | Type-level tag saying whether a 'Hypothesis' has emitted its final token.
-- It is used promoted, as @'Finished@ and @'Unfinished@.
data IsFinished
  = Finished
  | Unfinished

-- | The state of a beam search after one step.
--
-- It holds the hypotheses that have already finished and those that are still
-- extended by later steps.
data Beams a b = Beams
  { -- | Hypotheses that ended with a final token.
    finished :: [Hypothesis 'Finished a b],
    -- | Hypotheses that later steps will keep extending.
    unfinished :: [Hypothesis 'Unfinished a b]
  }
  deriving (Show)

-- | A search hypothesis, indexed by whether it has finished.
--
-- Unfinished hypotheses form a backwards-linked chain ending in
-- 'InitialHypothesis'. Each link stores the token it appended and a score.
-- A finished hypothesis wraps the last unfinished link together with the final
-- token, the final score, and a result value of type @b@.
data Hypothesis (isFinished :: IsFinished) a b where
  -- The empty hypothesis every search starts from.
  InitialHypothesis ::
    forall a b.
    Hypothesis 'Unfinished a b
  -- Appends 'currentToken' to 'previousHypothesis'.
  UnfinishedHypothesis ::
    forall a b.
    { currentToken :: a,
      currentScore :: Float,
      previousHypothesis :: Hypothesis 'Unfinished a b
    } ->
    Hypothesis 'Unfinished a b
  -- Closes 'penultimateHypothesis' with 'finalToken' and carries the result.
  FinishedHypothesis ::
    forall a b.
    { finalToken :: a,
      finalScore :: Float,
      penultimateHypothesis :: Hypothesis 'Unfinished a b,
      finalValue :: b
    } ->
    Hypothesis 'Finished a b

-- Structural instances for unfinished hypotheses. The instances are derived
-- per index, because a GADT cannot use a single deriving clause here.
deriving instance (Eq a, Eq b) => Eq (Hypothesis 'Unfinished a b)
deriving instance (Ord a, Ord b) => Ord (Hypothesis 'Unfinished a b)
deriving instance (Show a, Show b) => Show (Hypothesis 'Unfinished a b)

-- Structural instances for finished hypotheses.
deriving instance (Eq a, Eq b) => Eq (Hypothesis 'Finished a b)
deriving instance (Ord a, Ord b) => Ord (Hypothesis 'Finished a b)
deriving instance (Show a, Show b) => Show (Hypothesis 'Finished a b)

-- | Tokens of an unfinished hypothesis, most recent first (the reverse of
-- generation order). 'InitialHypothesis' contributes no tokens.
getTokens :: forall a b. Hypothesis 'Unfinished a b -> [a]
getTokens = \case
  InitialHypothesis -> []
  UnfinishedHypothesis newest _ earlier -> newest : getTokens earlier

-- | The score stored on an unfinished hypothesis ('currentScore').
-- 'InitialHypothesis' scores 0.
getScore :: forall a b. Hypothesis 'Unfinished a b -> Float
getScore hypothesis = case hypothesis of
  InitialHypothesis -> 0
  UnfinishedHypothesis {currentScore = score} -> score

-- | A 'Hypothesis' whose finished/unfinished index is hidden, so that both
-- kinds can be returned in one list. Pattern match on the constructor to
-- recover the index. Because the field's type mentions the existential, the
-- field cannot be used as a selector function.
data SomeHypothesis a b
  = forall (isFinished :: IsFinished).
    SomeHypothesis {unSomeHypothesis :: Hypothesis isFinished a b}

-- | Iterated beam search.
--
-- The search starts from a beam of @beamSize@ copies of 'InitialHypothesis'.
-- It applies 'beamSearchStep' with @cont@ at most @maxSteps@ times, and stops
-- early as soon as the beam has no unfinished hypotheses left.
--
-- Returns the beam produced by every step that ran, in order. The starting
-- beam itself is not included.
beamSearch ::
  forall a b m.
  Monad m =>
  Int ->
  Int ->
  ([Hypothesis 'Unfinished a b] -> m [SomeHypothesis a b]) ->
  m [Beams a b]
beamSearch maxSteps beamSize cont =
  loop maxSteps (Beams [] (replicate beamSize InitialHypothesis))
  where
    loop :: Int -> Beams a b -> m [Beams a b]
    loop !stepsLeft current
      | null (unfinished current) = pure []
      | stepsLeft <= 0 = pure []
      | otherwise = do
          next <- beamSearchStep cont current
          (next :) <$> loop (stepsLeft - 1) next

-- | Advance a beam by one step.
--
-- @cont@ receives the beam's unfinished hypotheses and returns candidate
-- successors. Candidates are ranked by score, highest first. Only as many are
-- kept as there were unfinished hypotheses going in.
--
-- Kept finished candidates are consed onto the beam's existing finished list.
-- Kept unfinished candidates replace the unfinished list. Because candidates
-- are consed one at a time, both newly added groups end up in ascending score
-- order.
beamSearchStep ::
  forall a b m.
  Functor m =>
  ([Hypothesis 'Unfinished a b] -> m [SomeHypothesis a b]) ->
  Beams a b ->
  m (Beams a b)
beamSearchStep cont beam =
  foldl' place (Beams (finished beam) []) . keepBest <$> cont (unfinished beam)
  where
    -- Number of candidates that survive this step.
    width :: Int
    width = length (unfinished beam)

    -- Highest-scoring candidates first. 'sortOn' is stable, so after the
    -- reverse, tied candidates appear in the reverse of the order 'cont'
    -- produced them.
    keepBest :: [SomeHypothesis a b] -> [SomeHypothesis a b]
    keepBest = take width . reverse . sortOn @Float score

    -- Ranking key. The initial hypothesis scores 0.
    score :: SomeHypothesis a b -> Float
    score (SomeHypothesis hypothesis) = case hypothesis of
      InitialHypothesis -> 0
      UnfinishedHypothesis {} -> currentScore hypothesis
      FinishedHypothesis {} -> finalScore hypothesis

    -- Sort one candidate into the finished or unfinished list. The GADT match
    -- refines the hidden index, so each candidate lands in a correctly typed
    -- list.
    place :: Beams a b -> SomeHypothesis a b -> Beams a b
    place (Beams fs us) (SomeHypothesis hypothesis) = case hypothesis of
      InitialHypothesis -> Beams fs (hypothesis : us)
      UnfinishedHypothesis {} -> Beams fs (hypothesis : us)
      FinishedHypothesis {} -> Beams (hypothesis : fs) us

-- | Beam search over token ids with a CPU encoder-decoder model such as T5.
--
-- Arguments, in order:
--
-- * the maximum number of search steps,
-- * the beam size,
-- * the model,
-- * the encoder input,
-- * the generator threaded through every forward pass.
--
-- Returns the beam produced by each step (see 'beamSearch').
--
-- The first step runs the full encoder-decoder forward pass on @input@. It
-- caches the resulting encoder output and input padding mask in the search
-- state. Later steps expand the cached encoder output along its first
-- dimension to the current decoder batch size, then run only the generation
-- forward pass.
--
-- A hypothesis finishes when it emits 't5EOSTokenId'. Its 'finalValue' is its
-- full token sequence in generation order, including that token.
runBeamSearch ::
  forall model input decoderInput encoderOutput encoderOutputShape encoderOutput' inputPaddingMask decoderOutput generatorDevice.
  ( HasForward
      model
      (SimplifiedEncoderDecoderTransformerInput input decoderInput)
      generatorDevice
      (SimplifiedEncoderDecoderTransformerOutput decoderOutput encoderOutput decoderInput inputPaddingMask)
      generatorDevice,
    encoderOutput
      ~ Tensor
          ('Gradient 'WithGradient)
          ('Layout 'Dense)
          ('Device 'CPU)
          T5DataType
          encoderOutputShape,
    'UncheckedShape ~ BroadcastShapesF encoderOutputShape 'UncheckedShape,
    SGetShape encoderOutputShape,
    HasForward
      model
      (SimplifiedEncoderDecoderTransformerGenerationInput decoderInput encoderOutput' inputPaddingMask)
      generatorDevice
      (SimplifiedEncoderDecoderTransformerOutput decoderOutput encoderOutput' decoderInput inputPaddingMask)
      generatorDevice,
    encoderOutput'
      ~ Tensor
          ('Gradient 'WithGradient)
          ('Layout 'Dense)
          ('Device 'CPU)
          T5DataType
          'UncheckedShape,
    decoderInput
      ~ Tensor
          ('Gradient 'WithoutGradient)
          ('Layout 'Dense)
          ('Device 'CPU)
          ('DataType 'Int64)
          ('Shape '[ 'Dim ('Name "*") 'UncheckedSize, 'Dim ('Name "*") 'UncheckedSize]),
    decoderOutput
      ~ Tensor
          ('Gradient 'WithGradient)
          ('Layout 'Dense)
          ('Device 'CPU)
          T5DataType
          'UncheckedShape,
    generatorDevice ~ 'Device 'CPU
  ) =>
  Int ->
  Int ->
  model ->
  input ->
  Generator generatorDevice ->
  IO [Beams Int [Int]]
runBeamSearch maxSteps beamSize model input g =
  evalStateT (beamSearch maxSteps beamSize cont) (Nothing, g)
  where
    -- Score every possible next token for each distinct unfinished hypothesis.
    cont ::
      [Hypothesis 'Unfinished Int [Int]] ->
      StateT (Maybe (encoderOutput, inputPaddingMask), Generator generatorDevice) IO [SomeHypothesis Int [Int]]
    cont previousHypotheses = do
      -- The first beam holds 'beamSize' identical 'InitialHypothesis' values.
      -- Deduplicating avoids scoring the same prefix more than once.
      let previousHypotheses' = nub previousHypotheses
      decoderInput :: decoderInput <- do
        let tokens = reverse . getTokens <$> previousHypotheses'
            batchSize = fromIntegral $ length previousHypotheses'
            -- Length of the longest prefix. 'beamSearch' only calls 'cont'
            -- while unfinished hypotheses remain, so 'maximum' sees a
            -- non-empty list.
            seqSize =
              let go :: Hypothesis 'Unfinished Int [Int] -> Int
                  go InitialHypothesis = 0
                  go (UnfinishedHypothesis _ _ previousHypothesis') = 1 + go previousHypothesis'
               in fromIntegral . maximum $ go <$> previousHypotheses'
        mkT5Input (SName @"*" :&: SUncheckedSize batchSize) (SName @"*" :&: SUncheckedSize seqSize) (SDevice SCPU) tokens
      logProbs <- getLogProbs decoderInput
      -- Only the last sequence position is used: its entry at index i becomes
      -- the candidate that appends token i.
      pure $
        zip previousHypotheses' logProbs
          >>= uncurry (\previousHypothesis -> zipWith (mkHypothesis previousHypothesis) [0, 1 ..] . last)

    -- Run the model on the decoder input and return the log-softmax over axis 2
    -- as nested lists. The axes are presumably (batch, position, vocabulary);
    -- TODO confirm against the model's output layout.
    getLogProbs ::
      decoderInput ->
      StateT (Maybe (encoderOutput, inputPaddingMask), Generator generatorDevice) IO [[[Float]]]
    getLogProbs decoderInput = do
      (cached, gen) <- get
      (SimplifiedEncoderDecoderTransformerOutput decoderOutput encoderOutput _ inputPaddingMask, gen') <-
        case cached of
          Nothing ->
            -- First step: full encoder-decoder pass. Its encoder output is cached below.
            forward model (SimplifiedEncoderDecoderTransformerInput input decoderInput) gen
          Just (encoderOutput, inputPaddingMask) -> do
            -- Later steps: expand the cached encoder output so that its first
            -- dimension matches the decoder input's first dimension, keeping its
            -- remaining dimensions. (Previously both values were `undefined`,
            -- which crashed every step after the first.)
            (decoderInputBatchDim, encoderOutputDims) <-
              case (runtimeDims decoderInput, runtimeDims encoderOutput) of
                (batchDim : _, _encoderOutputBatchDim : otherDims) -> pure (batchDim, otherDims)
                _ -> fail "runBeamSearch: decoder input and encoder output must have at least one dimension"
            let encoderOutput' = sExpand (SUncheckedShape (decoderInputBatchDim : encoderOutputDims)) encoderOutput
            (SimplifiedEncoderDecoderTransformerOutput decoderOutput _ _ _, gen') <-
              forward model (SimplifiedEncoderDecoderTransformerGenerationInput decoderInput encoderOutput' inputPaddingMask) gen
            -- Keep the unexpanded encoder output so that the cache is reused as-is.
            pure (SimplifiedEncoderDecoderTransformerOutput decoderOutput encoderOutput decoderInput inputPaddingMask, gen')
      put (Just (encoderOutput, inputPaddingMask), gen')
      probs <- logSoftmax (SSelectDim $ SByIndex @2) decoderOutput
      case probs of
        UnsafeTensor t -> pure . Torch.Tensor.asValue . Torch.Tensor.Unsafe $ t

    -- Runtime dimensions of a tensor, read from its untyped view. Only sizes are
    -- recovered, and every dimension is labelled "*", the name used by the
    -- decoder-input type above. NOTE(review): assumes `sExpand` with an
    -- unchecked shape only uses the sizes; confirm.
    runtimeDims ::
      forall gradient layout device dataType shape.
      Tensor gradient layout device dataType shape ->
      [Dim String Integer]
    runtimeDims (UnsafeTensor t) =
      Dim "*" . fromIntegral <$> Torch.Tensor.shape (Torch.Tensor.Unsafe t)

    -- Extend a hypothesis by one token. The token's log-probability is added to
    -- the previous hypothesis' score. Emitting 't5EOSTokenId' finishes the
    -- hypothesis, and the whole sequence is recorded in generation order.
    mkHypothesis :: Hypothesis 'Unfinished Int [Int] -> Int -> Float -> SomeHypothesis Int [Int]
    mkHypothesis previousHypothesis token logProb
      | token == t5EOSTokenId =
          SomeHypothesis
            (FinishedHypothesis token score previousHypothesis (reverse (token : getTokens previousHypothesis)))
      | otherwise =
          SomeHypothesis (UnfinishedHypothesis token score previousHypothesis)
      where
        score = logProb + getScore previousHypothesis

testBeamSearch :: IO ()
testBeamSearch = do
  Tensor
  ('Gradient 'WithoutGradient)
  ('Layout 'Dense)
  ('Device 'CPU)
  ('DataType 'Int64)
  ('Shape
     '[ 'Dim ('Name "*") ('Size 1), 'Dim ('Name "*") ('Size 19)])
input <- do
    let tokens :: [[Int]]
tokens = [[Int
13959, Int
1566, Int
12, Int
2968, Int
10, Int
6536, Int
43, Int
2008, Int
24, Int
293, Int
53, Int
3, Int
9, Int
1782, Int
19, Int
207, Int
21, Int
25, Int
1]]
    -- let tokens = [[13959, 1566, 12, 2968, 10, 148, 31, 60, 423, 13, 3, 7, 10536, 55, 1]]
    -- print $ ((t5Vocab Map.!) <$>) <$> tokens
    forall (batchDim :: Dim (Name Symbol) (Size Nat))
       (seqDim :: Dim (Name Symbol) (Size Nat))
       (device :: Device (DeviceType Nat)) (m :: * -> *) output.
(MonadThrow m, SGetDim batchDim, SGetDim seqDim,
 Catch
   ('Shape
      '[ 'Dim ('Name "*") 'UncheckedSize,
         'Dim ('Name "*") 'UncheckedSize]
    <+> 'Shape '[batchDim, seqDim]),
 output
 ~ Tensor
     ('Gradient 'WithoutGradient)
     ('Layout 'Dense)
     device
     ('DataType 'Int64)
     ('Shape '[batchDim, seqDim])) =>
SDim batchDim
-> SDim seqDim -> SDevice device -> [[Int]] -> m output
mkT5Input
      (forall (name1 :: Symbol). KnownSymbol name1 => SName ('Name name1)
SName @"*" forall (name :: Name Symbol) (size :: Size Nat).
SName name -> SSize size -> SDim ('Dim name size)
:&: forall (size1 :: Nat). KnownNat size1 => SSize ('Size size1)
SSize @1)
      (forall (name1 :: Symbol). KnownSymbol name1 => SName ('Name name1)
SName @"*" forall (name :: Name Symbol) (size :: Size Nat).
SName name -> SSize size -> SDim ('Dim name size)
:&: forall (size1 :: Nat). KnownNat size1 => SSize ('Size size1)
SSize @19)
      (forall (deviceType1 :: DeviceType Nat).
SDeviceType deviceType1 -> SDevice ('Device deviceType1)
SDevice SDeviceType 'CPU
SCPU)
      [[Int]]
tokens
  StateDict
stateDict <- String -> IO StateDict
stateDictFromFile String
"/tmp/t5-small-state-dict.pt"
  let spec :: ModelSpec
  (T5Small
     'WithLMHead ('Gradient 'WithGradient) ('Device 'CPU) 'WithDropout)
spec = forall (transformerHead :: TransformerHead)
       (gradient :: Gradient RequiresGradient)
       (device :: Device (DeviceType Nat)) (hasDropout :: HasDropout).
STransformerHead transformerHead
-> SGradient gradient
-> SDevice device
-> SHasDropout hasDropout
-> ModelSpec (T5Small transformerHead gradient device hasDropout)
t5SmallSpec STransformerHead 'WithLMHead
SWithLMHead (forall (requiresGradient :: RequiresGradient).
SRequiresGradient requiresGradient
-> SGradient ('Gradient requiresGradient)
SGradient SRequiresGradient 'WithGradient
SWithGradient) (forall (deviceType1 :: DeviceType Nat).
SDeviceType deviceType1 -> SDevice ('Device deviceType1)
SDevice SDeviceType 'CPU
SCPU) SHasDropout 'WithDropout
SWithDropout
  GSimplifiedEncoderDecoderTransformer
  (GEncoderDecoderTransformer
     ('Dim ('Name "*") ('Size 512))
     (NamedModel
        (GTransformer
           ()
           (NamedModel
              (Embedding
                 ('Gradient 'WithGradient)
                 ('Layout 'Dense)
                 ('Device 'CPU)
                 T5DataType
                 ('Dim ('Name "*") ('Size 32))
                 ('Dim ('Name "*") ('Size 8))
                 'Nothing))
           ()
           Dropout
           (NamedModel
              (GTransformerStack
                 (Vector
                    6
                    (GTransformerBlock
                       (NamedModel
                          (GSelfAttention
                             (NamedModel
                                (LayerNorm
                                   'WithoutBias
                                   ('Gradient 'WithGradient)
                                   ('Device 'CPU)
                                   T5DataType
                                   ('Shape '[ 'Dim ('Name "*") ('Size 512)])))
                             (NamedModel
                                (GMultiHeadAttention
                                   ('Dim ('Name "*") ('Size 8))
                                   ('Dim ('Name "*") ('Size 64))
                                   ('Dim ('Name "*") ('Size 512))
                                   (NamedModel
                                      (GLinear
                                         (NamedModel
                                            (Tensor
                                               ('Gradient 'WithGradient)
                                               ('Layout 'Dense)
                                               ('Device 'CPU)
                                               T5DataType
                                               ('Shape
                                                  '[ 'Dim ('Name "*") ('Size 512),
                                                     'Dim ('Name "*") ('Size 512)])))
                                         (NamedModel ())))
                                   (NamedModel
                                      (GLinear
                                         (NamedModel
                                            (Tensor
                                               ('Gradient 'WithGradient)
                                               ('Layout 'Dense)
                                               ('Device 'CPU)
                                               T5DataType
                                               ('Shape
                                                  '[ 'Dim ('Name "*") ('Size 512),
                                                     'Dim ('Name "*") ('Size 512)])))
                                         (NamedModel ())))
                                   (NamedModel
                                      (GLinear
                                         (NamedModel
                                            (Tensor
                                               ('Gradient 'WithGradient)
                                               ('Layout 'Dense)
                                               ('Device 'CPU)
                                               T5DataType
                                               ('Shape
                                                  '[ 'Dim ('Name "*") ('Size 512),
                                                     'Dim ('Name "*") ('Size 512)])))
                                         (NamedModel ())))
                                   (NamedModel
                                      (GLinear
                                         (NamedModel
                                            (Tensor
                                               ('Gradient 'WithGradient)
                                               ('Layout 'Dense)
                                               ('Device 'CPU)
                                               T5DataType
                                               ('Shape
                                                  '[ 'Dim ('Name "*") ('Size 512),
                                                     'Dim ('Name "*") ('Size 512)])))
                                         (NamedModel ())))
                                   Dropout))
                             Dropout
                             ()))
                       ()
                       (NamedModel
                          (GTransformerFeedForwardNetwork
                             (NamedModel
                                (LayerNorm
                                   'WithoutBias
                                   ('Gradient 'WithGradient)
                                   ('Device 'CPU)
                                   T5DataType
                                   ('Shape '[ 'Dim ('Name "*") ('Size 512)])))
                             (NamedModel
                                (GLinear
                                   (NamedModel
                                      (Tensor
                                         ('Gradient 'WithGradient)
                                         ('Layout 'Dense)
                                         ('Device 'CPU)
                                         T5DataType
                                         ('Shape
                                            '[ 'Dim ('Name "*") ('Size 2048),
                                               'Dim ('Name "*") ('Size 512)])))
                                   (NamedModel ())))
                             Relu
                             Dropout
                             (NamedModel
                                (GLinear
                                   (NamedModel
                                      (Tensor
                                         ('Gradient 'WithGradient)
                                         ('Layout 'Dense)
                                         ('Device 'CPU)
                                         T5DataType
                                         ('Shape
                                            '[ 'Dim ('Name "*") ('Size 512),
                                               'Dim ('Name "*") ('Size 2048)])))
                                   (NamedModel ())))
                             Dropout
                             ()))))))
           (NamedModel
              (LayerNorm
                 'WithoutBias
                 ('Gradient 'WithGradient)
                 ('Device 'CPU)
                 T5DataType
                 ('Shape '[ 'Dim ('Name "*") ('Size 512)])))
           Dropout))
     (NamedModel
        (GTransformer
           ()
           (NamedModel
              (Embedding
                 ('Gradient 'WithGradient)
                 ('Layout 'Dense)
                 ('Device 'CPU)
                 T5DataType
                 ('Dim ('Name "*") ('Size 32))
                 ('Dim ('Name "*") ('Size 8))
                 'Nothing))
           ()
           Dropout
           (NamedModel
              (GTransformerStack
                 (Vector
                    6
                    (GTransformerBlock
                       (NamedModel
                          (GSelfAttention
                             (NamedModel
                                (LayerNorm
                                   'WithoutBias
                                   ('Gradient 'WithGradient)
                                   ('Device 'CPU)
                                   T5DataType
                                   ('Shape '[ 'Dim ('Name "*") ('Size 512)])))
                             (NamedModel
                                (GMultiHeadAttention
                                   ('Dim ('Name "*") ('Size 8))
                                   ('Dim ('Name "*") ('Size 64))
                                   ('Dim ('Name "*") ('Size 512))
                                   (NamedModel
                                      (GLinear
                                         (NamedModel
                                            (Tensor
                                               ('Gradient 'WithGradient)
                                               ('Layout 'Dense)
                                               ('Device 'CPU)
                                               T5DataType
                                               ('Shape
                                                  '[ 'Dim ('Name "*") ('Size 512),
                                                     'Dim ('Name "*") ('Size 512)])))
                                         (NamedModel ())))
                                   (NamedModel
                                      (GLinear
                                         (NamedModel
                                            (Tensor
                                               ('Gradient 'WithGradient)
                                               ('Layout 'Dense)
                                               ('Device 'CPU)
                                               T5DataType
                                               ('Shape
                                                  '[ 'Dim ('Name "*") ('Size 512),
                                                     'Dim ('Name "*") ('Size 512)])))
                                         (NamedModel ())))
                                   (NamedModel
                                      (GLinear
                                         (NamedModel
                                            (Tensor
                                               ('Gradient 'WithGradient)
                                               ('Layout 'Dense)
                                               ('Device 'CPU)
                                               T5DataType
                                               ('Shape
                                                  '[ 'Dim ('Name "*") ('Size 512),
                                                     'Dim ('Name "*") ('Size 512)])))
                                         (NamedModel ())))
                                   (NamedModel
                                      (GLinear
                                         (NamedModel
                                            (Tensor
                                               ('Gradient 'WithGradient)
                                               ('Layout 'Dense)
                                               ('Device 'CPU)
                                               T5DataType
                                               ('Shape
                                                  '[ 'Dim ('Name "*") ('Size 512),
                                                     'Dim ('Name "*") ('Size 512)])))
                                         (NamedModel ())))
                                   Dropout))
                             Dropout
                             ()))
                       (NamedModel
                          (GCrossAttention
                             (NamedModel
                                (LayerNorm
                                   'WithoutBias
                                   ('Gradient 'WithGradient)
                                   ('Device 'CPU)
                                   T5DataType
                                   ('Shape '[ 'Dim ('Name "*") ('Size 512)])))
                             (NamedModel
                                (GMultiHeadAttention
                                   ('Dim ('Name "*") ('Size 8))
                                   ('Dim ('Name "*") ('Size 64))
                                   ('Dim ('Name "*") ('Size 512))
                                   (NamedModel
                                      (GLinear
                                         (NamedModel
                                            (Tensor
                                               ('Gradient 'WithGradient)
                                               ('Layout 'Dense)
                                               ('Device 'CPU)
                                               T5DataType
                                               ('Shape
                                                  '[ 'Dim ('Name "*") ('Size 512),
                                                     'Dim ('Name "*") ('Size 512)])))
                                         (NamedModel ())))
                                   (NamedModel
                                      (GLinear
                                         (NamedModel
                                            (Tensor
                                               ('Gradient 'WithGradient)
                                               ('Layout 'Dense)
                                               ('Device 'CPU)
                                               T5DataType
                                               ('Shape
                                                  '[ 'Dim ('Name "*") ('Size 512),
                                                     'Dim ('Name "*") ('Size 512)])))
                                         (NamedModel ())))
                                   (NamedModel
                                      (GLinear
                                         (NamedModel
                                            (Tensor
                                               ('Gradient 'WithGradient)
                                               ('Layout 'Dense)
                                               ('Device 'CPU)
                                               T5DataType
                                               ('Shape
                                                  '[ 'Dim ('Name "*") ('Size 512),
                                                     'Dim ('Name "*") ('Size 512)])))
                                         (NamedModel ())))
                                   (NamedModel
                                      (GLinear
                                         (NamedModel
                                            (Tensor
                                               ('Gradient 'WithGradient)
                                               ('Layout 'Dense)
                                               ('Device 'CPU)
                                               T5DataType
                                               ('Shape
                                                  '[ 'Dim ('Name "*") ('Size 512),
                                                     'Dim ('Name "*") ('Size 512)])))
                                         (NamedModel ())))
                                   Dropout))
                             Dropout
                             ()))
                       (NamedModel
                          (GTransformerFeedForwardNetwork
                             (NamedModel
                                (LayerNorm
                                   'WithoutBias
                                   ('Gradient 'WithGradient)
                                   ('Device 'CPU)
                                   T5DataType
                                   ('Shape '[ 'Dim ('Name "*") ('Size 512)])))
                             (NamedModel
                                (GLinear
                                   (NamedModel
                                      (Tensor
                                         ('Gradient 'WithGradient)
                                         ('Layout 'Dense)
                                         ('Device 'CPU)
                                         T5DataType
                                         ('Shape
                                            '[ 'Dim ('Name "*") ('Size 2048),
                                               'Dim ('Name "*") ('Size 512)])))
                                   (NamedModel ())))
                             Relu
                             Dropout
                             (NamedModel
                                (GLinear
                                   (NamedModel
                                      (Tensor
                                         ('Gradient 'WithGradient)
                                         ('Layout 'Dense)
                                         ('Device 'CPU)
                                         T5DataType
                                         ('Shape
                                            '[ 'Dim ('Name "*") ('Size 512),
                                               'Dim ('Name "*") ('Size 2048)])))
                                   (NamedModel ())))
                             Dropout
                             ()))))))
           (NamedModel
              (LayerNorm
                 'WithoutBias
                 ('Gradient 'WithGradient)
                 ('Device 'CPU)
                 T5DataType
                 ('Shape '[ 'Dim ('Name "*") ('Size 512)])))
           Dropout))
     (NamedModel
        (Embedding
           ('Gradient 'WithGradient)
           ('Layout 'Dense)
           ('Device 'CPU)
           T5DataType
           ('Dim ('Name "*") ('Size 32128))
           ('Dim ('Name "*") ('Size 512))
           'Nothing))
     (NamedModel
        (GLMHead
           ('Dim ('Name "*") ('Size 512))
           ()
           ()
           ()
           (NamedModel
              (GLinear
                 (NamedModel
                    (Tensor
                       ('Gradient 'WithGradient)
                       ('Layout 'Dense)
                       ('Device 'CPU)
                       T5DataType
                       ('Shape
                          '[ 'Dim ('Name "*") ('Size 32128), 'Dim ('Name "*") ('Size 512)])))
                 (NamedModel ())))
           (GBias ()))))
  (MkRelPos ('Dim ('Name "*") ('Size 32)))
  (MkRelPos ('Dim ('Name "*") ('Size 32)))
  MkTransformerPaddingMask
  (MkTransformerAttentionMask T5DataType)
  (MkTransformerCrossAttentionMask T5DataType)
  (MkTransformerDecoderAttentionMask T5DataType)
model <- forall a b c. (a -> b -> c) -> b -> a -> c
flip forall (m :: * -> *) s a. Monad m => StateT s m a -> s -> m a
evalStateT StateDict
stateDict forall a b. (a -> b) -> a -> b
$ forall model (m :: * -> *).
(HasStateDict model, MonadIO m, MonadThrow m,
 MonadState StateDict m) =>
ModelSpec model -> StateDictKey -> m model
fromStateDict GSimplifiedEncoderDecoderTransformer
  (GEncoderDecoderTransformer
     ('Dim ('Name "*") ('Size 512))
     (NamedModel
        (GTransformer
           ()
           (NamedModel
              (EmbeddingSpec
                 ('Gradient 'WithGradient)
                 ('Layout 'Dense)
                 ('Device 'CPU)
                 T5DataType
                 ('Dim ('Name "*") ('Size 32))
                 ('Dim ('Name "*") ('Size 8))
                 'Nothing))
           ()
           Dropout
           (NamedModel
              (GTransformerStack
                 (VectorSpec
                    6
                    (GTransformerBlock
                       (NamedModel
                          (GSelfAttention
                             (NamedModel
                                (LayerNorm
                                   'WithoutBias
                                   ('Gradient 'WithGradient)
                                   ('Device 'CPU)
                                   T5DataType
                                   ('Shape '[ 'Dim ('Name "*") ('Size 512)])))
                             (NamedModel
                                (GMultiHeadAttention
                                   ('Dim ('Name "*") ('Size 8))
                                   ('Dim ('Name "*") ('Size 64))
                                   ('Dim ('Name "*") ('Size 512))
                                   (NamedModel
                                      (GLinear
                                         (NamedModel
                                            (Tensor
                                               ('Gradient 'WithGradient)
                                               ('Layout 'Dense)
                                               ('Device 'CPU)
                                               T5DataType
                                               ('Shape
                                                  '[ 'Dim ('Name "*") ('Size 512),
                                                     'Dim ('Name "*") ('Size 512)])))
                                         (NamedModel ())))
                                   (NamedModel
                                      (GLinear
                                         (NamedModel
                                            (Tensor
                                               ('Gradient 'WithGradient)
                                               ('Layout 'Dense)
                                               ('Device 'CPU)
                                               T5DataType
                                               ('Shape
                                                  '[ 'Dim ('Name "*") ('Size 512),
                                                     'Dim ('Name "*") ('Size 512)])))
                                         (NamedModel ())))
                                   (NamedModel
                                      (GLinear
                                         (NamedModel
                                            (Tensor
                                               ('Gradient 'WithGradient)
                                               ('Layout 'Dense)
                                               ('Device 'CPU)
                                               T5DataType
                                               ('Shape
                                                  '[ 'Dim ('Name "*") ('Size 512),
                                                     'Dim ('Name "*") ('Size 512)])))
                                         (NamedModel ())))
                                   (NamedModel
                                      (GLinear
                                         (NamedModel
                                            (Tensor
                                               ('Gradient 'WithGradient)
                                               ('Layout 'Dense)
                                               ('Device 'CPU)
                                               T5DataType
                                               ('Shape
                                                  '[ 'Dim ('Name "*") ('Size 512),
                                                     'Dim ('Name "*") ('Size 512)])))
                                         (NamedModel ())))
                                   Dropout))
                             Dropout
                             ()))
                       ()
                       (NamedModel
                          (GTransformerFeedForwardNetwork
                             (NamedModel
                                (LayerNorm
                                   'WithoutBias
                                   ('Gradient 'WithGradient)
                                   ('Device 'CPU)
                                   T5DataType
                                   ('Shape '[ 'Dim ('Name "*") ('Size 512)])))
                             (NamedModel
                                (GLinear
                                   (NamedModel
                                      (Tensor
                                         ('Gradient 'WithGradient)
                                         ('Layout 'Dense)
                                         ('Device 'CPU)
                                         T5DataType
                                         ('Shape
                                            '[ 'Dim ('Name "*") ('Size 2048),
                                               'Dim ('Name "*") ('Size 512)])))
                                   (NamedModel ())))
                             Relu
                             Dropout
                             (NamedModel
                                (GLinear
                                   (NamedModel
                                      (Tensor
                                         ('Gradient 'WithGradient)
                                         ('Layout 'Dense)
                                         ('Device 'CPU)
                                         T5DataType
                                         ('Shape
                                            '[ 'Dim ('Name "*") ('Size 512),
                                               'Dim ('Name "*") ('Size 2048)])))
                                   (NamedModel ())))
                             Dropout
                             ()))))))
           (NamedModel
              (LayerNormSpec
                 'WithoutBias
                 ('Gradient 'WithGradient)
                 ('Device 'CPU)
                 T5DataType
                 ('Shape '[ 'Dim ('Name "*") ('Size 512)])))
           Dropout))
     (NamedModel
        (GTransformer
           ()
           (NamedModel
              (EmbeddingSpec
                 ('Gradient 'WithGradient)
                 ('Layout 'Dense)
                 ('Device 'CPU)
                 T5DataType
                 ('Dim ('Name "*") ('Size 32))
                 ('Dim ('Name "*") ('Size 8))
                 'Nothing))
           ()
           Dropout
           (NamedModel
              (GTransformerStack
                 (VectorSpec
                    6
                    (GTransformerBlock
                       (NamedModel
                          (GSelfAttention
                             (NamedModel
                                (LayerNorm
                                   'WithoutBias
                                   ('Gradient 'WithGradient)
                                   ('Device 'CPU)
                                   T5DataType
                                   ('Shape '[ 'Dim ('Name "*") ('Size 512)])))
                             (NamedModel
                                (GMultiHeadAttention
                                   ('Dim ('Name "*") ('Size 8))
                                   ('Dim ('Name "*") ('Size 64))
                                   ('Dim ('Name "*") ('Size 512))
                                   (NamedModel
                                      (GLinear
                                         (NamedModel
                                            (Tensor
                                               ('Gradient 'WithGradient)
                                               ('Layout 'Dense)
                                               ('Device 'CPU)
                                               T5DataType
                                               ('Shape
                                                  '[ 'Dim ('Name "*") ('Size 512),
                                                     'Dim ('Name "*") ('Size 512)])))
                                         (NamedModel ())))
                                   (NamedModel
                                      (GLinear
                                         (NamedModel
                                            (Tensor
                                               ('Gradient 'WithGradient)
                                               ('Layout 'Dense)
                                               ('Device 'CPU)
                                               T5DataType
                                               ('Shape
                                                  '[ 'Dim ('Name "*") ('Size 512),
                                                     'Dim ('Name "*") ('Size 512)])))
                                         (NamedModel ())))
                                   (NamedModel
                                      (GLinear
                                         (NamedModel
                                            (Tensor
                                               ('Gradient 'WithGradient)
                                               ('Layout 'Dense)
                                               ('Device 'CPU)
                                               T5DataType
                                               ('Shape
                                                  '[ 'Dim ('Name "*") ('Size 512),
                                                     'Dim ('Name "*") ('Size 512)])))
                                         (NamedModel ())))
                                   (NamedModel
                                      (GLinear
                                         (NamedModel
                                            (Tensor
                                               ('Gradient 'WithGradient)
                                               ('Layout 'Dense)
                                               ('Device 'CPU)
                                               T5DataType
                                               ('Shape
                                                  '[ 'Dim ('Name "*") ('Size 512),
                                                     'Dim ('Name "*") ('Size 512)])))
                                         (NamedModel ())))
                                   Dropout))
                             Dropout
                             ()))
                       (NamedModel
                          (GCrossAttention
                             (NamedModel
                                (LayerNorm
                                   'WithoutBias
                                   ('Gradient 'WithGradient)
                                   ('Device 'CPU)
                                   T5DataType
                                   ('Shape '[ 'Dim ('Name "*") ('Size 512)])))
                             (NamedModel
                                (GMultiHeadAttention
                                   ('Dim ('Name "*") ('Size 8))
                                   ('Dim ('Name "*") ('Size 64))
                                   ('Dim ('Name "*") ('Size 512))
                                   (NamedModel
                                      (GLinear
                                         (NamedModel
                                            (Tensor
                                               ('Gradient 'WithGradient)
                                               ('Layout 'Dense)
                                               ('Device 'CPU)
                                               T5DataType
                                               ('Shape
                                                  '[ 'Dim ('Name "*") ('Size 512),
                                                     'Dim ('Name "*") ('Size 512)])))
                                         (NamedModel ())))
                                   (NamedModel
                                      (GLinear
                                         (NamedModel
                                            (Tensor
                                               ('Gradient 'WithGradient)
                                               ('Layout 'Dense)
                                               ('Device 'CPU)
                                               T5DataType
                                               ('Shape
                                                  '[ 'Dim ('Name "*") ('Size 512),
                                                     'Dim ('Name "*") ('Size 512)])))
                                         (NamedModel ())))
                                   (NamedModel
                                      (GLinear
                                         (NamedModel
                                            (Tensor
                                               ('Gradient 'WithGradient)
                                               ('Layout 'Dense)
                                               ('Device 'CPU)
                                               T5DataType
                                               ('Shape
                                                  '[ 'Dim ('Name "*") ('Size 512),
                                                     'Dim ('Name "*") ('Size 512)])))
                                         (NamedModel ())))
                                   (NamedModel
                                      (GLinear
                                         (NamedModel
                                            (Tensor
                                               ('Gradient 'WithGradient)
                                               ('Layout 'Dense)
                                               ('Device 'CPU)
                                               T5DataType
                                               ('Shape
                                                  '[ 'Dim ('Name "*") ('Size 512),
                                                     'Dim ('Name "*") ('Size 512)])))
                                         (NamedModel ())))
                                   Dropout))
                             Dropout
                             ()))
                       (NamedModel
                          (GTransformerFeedForwardNetwork
                             (NamedModel
                                (LayerNorm
                                   'WithoutBias
                                   ('Gradient 'WithGradient)
                                   ('Device 'CPU)
                                   T5DataType
                                   ('Shape '[ 'Dim ('Name "*") ('Size 512)])))
                             (NamedModel
                                (GLinear
                                   (NamedModel
                                      (Tensor
                                         ('Gradient 'WithGradient)
                                         ('Layout 'Dense)
                                         ('Device 'CPU)
                                         T5DataType
                                         ('Shape
                                            '[ 'Dim ('Name "*") ('Size 2048),
                                               'Dim ('Name "*") ('Size 512)])))
                                   (NamedModel ())))
                             Relu
                             Dropout
                             (NamedModel
                                (GLinear
                                   (NamedModel
                                      (Tensor
                                         ('Gradient 'WithGradient)
                                         ('Layout 'Dense)
                                         ('Device 'CPU)
                                         T5DataType
                                         ('Shape
                                            '[ 'Dim ('Name "*") ('Size 512),
                                               'Dim ('Name "*") ('Size 2048)])))
                                   (NamedModel ())))
                             Dropout
                             ()))))))
           (NamedModel
              (LayerNormSpec
                 'WithoutBias
                 ('Gradient 'WithGradient)
                 ('Device 'CPU)
                 T5DataType
                 ('Shape '[ 'Dim ('Name "*") ('Size 512)])))
           Dropout))
     (NamedModel
        (EmbeddingSpec
           ('Gradient 'WithGradient)
           ('Layout 'Dense)
           ('Device 'CPU)
           T5DataType
           ('Dim ('Name "*") ('Size 32128))
           ('Dim ('Name "*") ('Size 512))
           'Nothing))
     (NamedModel
        (GLMHead
           ('Dim ('Name "*") ('Size 512))
           ()
           ()
           ()
           (NamedModel
              (GLinear
                 (NamedModel
                    (TensorSpec
                       ('Gradient 'WithGradient)
                       ('Layout 'Dense)
                       ('Device 'CPU)
                       T5DataType
                       ('Shape
                          '[ 'Dim ('Name "*") ('Size 32128), 'Dim ('Name "*") ('Size 512)])))
                 (NamedModel ())))
           (GBias ()))))
  (MkRelPos ('Dim ('Name "*") ('Size 32)))
  (MkRelPos ('Dim ('Name "*") ('Size 32)))
  MkTransformerPaddingMask
  (MkTransformerAttentionMask T5DataType)
  (MkTransformerCrossAttentionMask T5DataType)
  (MkTransformerDecoderAttentionMask T5DataType)
spec forall a. Monoid a => a
mempty
  Generator ('Device 'CPU)
g <- forall (m :: * -> *) (device :: Device (DeviceType Nat)).
MonadThrow m =>
SDevice device -> Word64 -> m (Generator device)
sMkGenerator (forall (deviceType1 :: DeviceType Nat).
SDeviceType deviceType1 -> SDevice ('Device deviceType1)
SDevice SDeviceType 'CPU
SCPU) Word64
0
  Beams [Hypothesis 'Finished Int [Int]]
finished [Hypothesis 'Unfinished Int [Int]]
_ <- forall a. [a] -> a
last forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> forall model input decoderInput encoderOutput
       (encoderOutputShape :: Shape [Dim (Name Symbol) (Size Nat)])
       encoderOutput' inputPaddingMask decoderOutput
       (generatorDevice :: Device (DeviceType Nat)).
(HasForward
   model
   (SimplifiedEncoderDecoderTransformerInput input decoderInput)
   generatorDevice
   (SimplifiedEncoderDecoderTransformerOutput
      decoderOutput encoderOutput decoderInput inputPaddingMask)
   generatorDevice,
 encoderOutput
 ~ Tensor
     ('Gradient 'WithGradient)
     ('Layout 'Dense)
     ('Device 'CPU)
     T5DataType
     encoderOutputShape,
 'UncheckedShape
 ~ BroadcastShapesF encoderOutputShape 'UncheckedShape,
 SGetShape encoderOutputShape,
 HasForward
   model
   (SimplifiedEncoderDecoderTransformerGenerationInput
      decoderInput encoderOutput' inputPaddingMask)
   generatorDevice
   (SimplifiedEncoderDecoderTransformerOutput
      decoderOutput encoderOutput' decoderInput inputPaddingMask)
   generatorDevice,
 encoderOutput'
 ~ Tensor
     ('Gradient 'WithGradient)
     ('Layout 'Dense)
     ('Device 'CPU)
     T5DataType
     'UncheckedShape,
 decoderInput
 ~ Tensor
     ('Gradient 'WithoutGradient)
     ('Layout 'Dense)
     ('Device 'CPU)
     ('DataType 'Int64)
     ('Shape
        '[ 'Dim ('Name "*") 'UncheckedSize,
           'Dim ('Name "*") 'UncheckedSize]),
 decoderOutput
 ~ Tensor
     ('Gradient 'WithGradient)
     ('Layout 'Dense)
     ('Device 'CPU)
     T5DataType
     'UncheckedShape,
 generatorDevice ~ 'Device 'CPU) =>
Int
-> Int
-> model
-> input
-> Generator generatorDevice
-> IO [Beams Int [Int]]
runBeamSearch Int
50 Int
1 GSimplifiedEncoderDecoderTransformer
  (GEncoderDecoderTransformer
     ('Dim ('Name "*") ('Size 512))
     (NamedModel
        (GTransformer
           ()
           (NamedModel
              (Embedding
                 ('Gradient 'WithGradient)
                 ('Layout 'Dense)
                 ('Device 'CPU)
                 T5DataType
                 ('Dim ('Name "*") ('Size 32))
                 ('Dim ('Name "*") ('Size 8))
                 'Nothing))
           ()
           Dropout
           (NamedModel
              (GTransformerStack
                 (Vector
                    6
                    (GTransformerBlock
                       (NamedModel
                          (GSelfAttention
                             (NamedModel
                                (LayerNorm
                                   'WithoutBias
                                   ('Gradient 'WithGradient)
                                   ('Device 'CPU)
                                   T5DataType
                                   ('Shape '[ 'Dim ('Name "*") ('Size 512)])))
                             (NamedModel
                                (GMultiHeadAttention
                                   ('Dim ('Name "*") ('Size 8))
                                   ('Dim ('Name "*") ('Size 64))
                                   ('Dim ('Name "*") ('Size 512))
                                   (NamedModel
                                      (GLinear
                                         (NamedModel
                                            (Tensor
                                               ('Gradient 'WithGradient)
                                               ('Layout 'Dense)
                                               ('Device 'CPU)
                                               T5DataType
                                               ('Shape
                                                  '[ 'Dim ('Name "*") ('Size 512),
                                                     'Dim ('Name "*") ('Size 512)])))
                                         (NamedModel ())))
                                   (NamedModel
                                      (GLinear
                                         (NamedModel
                                            (Tensor
                                               ('Gradient 'WithGradient)
                                               ('Layout 'Dense)
                                               ('Device 'CPU)
                                               T5DataType
                                               ('Shape
                                                  '[ 'Dim ('Name "*") ('Size 512),
                                                     'Dim ('Name "*") ('Size 512)])))
                                         (NamedModel ())))
                                   (NamedModel
                                      (GLinear
                                         (NamedModel
                                            (Tensor
                                               ('Gradient 'WithGradient)
                                               ('Layout 'Dense)
                                               ('Device 'CPU)
                                               T5DataType
                                               ('Shape
                                                  '[ 'Dim ('Name "*") ('Size 512),
                                                     'Dim ('Name "*") ('Size 512)])))
                                         (NamedModel ())))
                                   (NamedModel
                                      (GLinear
                                         (NamedModel
                                            (Tensor
                                               ('Gradient 'WithGradient)
                                               ('Layout 'Dense)
                                               ('Device 'CPU)
                                               T5DataType
                                               ('Shape
                                                  '[ 'Dim ('Name "*") ('Size 512),
                                                     'Dim ('Name "*") ('Size 512)])))
                                         (NamedModel ())))
                                   Dropout))
                             Dropout
                             ()))
                       ()
                       (NamedModel
                          (GTransformerFeedForwardNetwork
                             (NamedModel
                                (LayerNorm
                                   'WithoutBias
                                   ('Gradient 'WithGradient)
                                   ('Device 'CPU)
                                   T5DataType
                                   ('Shape '[ 'Dim ('Name "*") ('Size 512)])))
                             (NamedModel
                                (GLinear
                                   (NamedModel
                                      (Tensor
                                         ('Gradient 'WithGradient)
                                         ('Layout 'Dense)
                                         ('Device 'CPU)
                                         T5DataType
                                         ('Shape
                                            '[ 'Dim ('Name "*") ('Size 2048),
                                               'Dim ('Name "*") ('Size 512)])))
                                   (NamedModel ())))
                             Relu
                             Dropout
                             (NamedModel
                                (GLinear
                                   (NamedModel
                                      (Tensor
                                         ('Gradient 'WithGradient)
                                         ('Layout 'Dense)
                                         ('Device 'CPU)
                                         T5DataType
                                         ('Shape
                                            '[ 'Dim ('Name "*") ('Size 512),
                                               'Dim ('Name "*") ('Size 2048)])))
                                   (NamedModel ())))
                             Dropout
                             ()))))))
           (NamedModel
              (LayerNorm
                 'WithoutBias
                 ('Gradient 'WithGradient)
                 ('Device 'CPU)
                 T5DataType
                 ('Shape '[ 'Dim ('Name "*") ('Size 512)])))
           Dropout))
     (NamedModel
        (GTransformer
           ()
           (NamedModel
              (Embedding
                 ('Gradient 'WithGradient)
                 ('Layout 'Dense)
                 ('Device 'CPU)
                 T5DataType
                 ('Dim ('Name "*") ('Size 32))
                 ('Dim ('Name "*") ('Size 8))
                 'Nothing))
           ()
           Dropout
           (NamedModel
              (GTransformerStack
                 (Vector
                    6
                    (GTransformerBlock
                       (NamedModel
                          (GSelfAttention
                             (NamedModel
                                (LayerNorm
                                   'WithoutBias
                                   ('Gradient 'WithGradient)
                                   ('Device 'CPU)
                                   T5DataType
                                   ('Shape '[ 'Dim ('Name "*") ('Size 512)])))
                             (NamedModel
                                (GMultiHeadAttention
                                   ('Dim ('Name "*") ('Size 8))
                                   ('Dim ('Name "*") ('Size 64))
                                   ('Dim ('Name "*") ('Size 512))
                                   (NamedModel
                                      (GLinear
                                         (NamedModel
                                            (Tensor
                                               ('Gradient 'WithGradient)
                                               ('Layout 'Dense)
                                               ('Device 'CPU)
                                               T5DataType
                                               ('Shape
                                                  '[ 'Dim ('Name "*") ('Size 512),
                                                     'Dim ('Name "*") ('Size 512)])))
                                         (NamedModel ())))
                                   (NamedModel
                                      (GLinear
                                         (NamedModel
                                            (Tensor
                                               ('Gradient 'WithGradient)
                                               ('Layout 'Dense)
                                               ('Device 'CPU)
                                               T5DataType
                                               ('Shape
                                                  '[ 'Dim ('Name "*") ('Size 512),
                                                     'Dim ('Name "*") ('Size 512)])))
                                         (NamedModel ())))
                                   (NamedModel
                                      (GLinear
                                         (NamedModel
                                            (Tensor
                                               ('Gradient 'WithGradient)
                                               ('Layout 'Dense)
                                               ('Device 'CPU)
                                               T5DataType
                                               ('Shape
                                                  '[ 'Dim ('Name "*") ('Size 512),
                                                     'Dim ('Name "*") ('Size 512)])))
                                         (NamedModel ())))
                                   (NamedModel
                                      (GLinear
                                         (NamedModel
                                            (Tensor
                                               ('Gradient 'WithGradient)
                                               ('Layout 'Dense)
                                               ('Device 'CPU)
                                               T5DataType
                                               ('Shape
                                                  '[ 'Dim ('Name "*") ('Size 512),
                                                     'Dim ('Name "*") ('Size 512)])))
                                         (NamedModel ())))
                                   Dropout))
                             Dropout
                             ()))
                       (NamedModel
                          (GCrossAttention
                             (NamedModel
                                (LayerNorm
                                   'WithoutBias
                                   ('Gradient 'WithGradient)
                                   ('Device 'CPU)
                                   T5DataType
                                   ('Shape '[ 'Dim ('Name "*") ('Size 512)])))
                             (NamedModel
                                (GMultiHeadAttention
                                   ('Dim ('Name "*") ('Size 8))
                                   ('Dim ('Name "*") ('Size 64))
                                   ('Dim ('Name "*") ('Size 512))
                                   (NamedModel
                                      (GLinear
                                         (NamedModel
                                            (Tensor
                                               ('Gradient 'WithGradient)
                                               ('Layout 'Dense)
                                               ('Device 'CPU)
                                               T5DataType
                                               ('Shape
                                                  '[ 'Dim ('Name "*") ('Size 512),
                                                     'Dim ('Name "*") ('Size 512)])))
                                         (NamedModel ())))
                                   (NamedModel
                                      (GLinear
                                         (NamedModel
                                            (Tensor
                                               ('Gradient 'WithGradient)
                                               ('Layout 'Dense)
                                               ('Device 'CPU)
                                               T5DataType
                                               ('Shape
                                                  '[ 'Dim ('Name "*") ('Size 512),
                                                     'Dim ('Name "*") ('Size 512)])))
                                         (NamedModel ())))
                                   (NamedModel
                                      (GLinear
                                         (NamedModel
                                            (Tensor
                                               ('Gradient 'WithGradient)
                                               ('Layout 'Dense)
                                               ('Device 'CPU)
                                               T5DataType
                                               ('Shape
                                                  '[ 'Dim ('Name "*") ('Size 512),
                                                     'Dim ('Name "*") ('Size 512)])))
                                         (NamedModel ())))
                                   (NamedModel
                                      (GLinear
                                         (NamedModel
                                            (Tensor
                                               ('Gradient 'WithGradient)
                                               ('Layout 'Dense)
                                               ('Device 'CPU)
                                               T5DataType
                                               ('Shape
                                                  '[ 'Dim ('Name "*") ('Size 512),
                                                     'Dim ('Name "*") ('Size 512)])))
                                         (NamedModel ())))
                                   Dropout))
                             Dropout
                             ()))
                       (NamedModel
                          (GTransformerFeedForwardNetwork
                             (NamedModel
                                (LayerNorm
                                   'WithoutBias
                                   ('Gradient 'WithGradient)
                                   ('Device 'CPU)
                                   T5DataType
                                   ('Shape '[ 'Dim ('Name "*") ('Size 512)])))
                             (NamedModel
                                (GLinear
                                   (NamedModel
                                      (Tensor
                                         ('Gradient 'WithGradient)
                                         ('Layout 'Dense)
                                         ('Device 'CPU)
                                         T5DataType
                                         ('Shape
                                            '[ 'Dim ('Name "*") ('Size 2048),
                                               'Dim ('Name "*") ('Size 512)])))
                                   (NamedModel ())))
                             Relu
                             Dropout
                             (NamedModel
                                (GLinear
                                   (NamedModel
                                      (Tensor
                                         ('Gradient 'WithGradient)
                                         ('Layout 'Dense)
                                         ('Device 'CPU)
                                         T5DataType
                                         ('Shape
                                            '[ 'Dim ('Name "*") ('Size 512),
                                               'Dim ('Name "*") ('Size 2048)])))
                                   (NamedModel ())))
                             Dropout
                             ()))))))
           (NamedModel
              (LayerNorm
                 'WithoutBias
                 ('Gradient 'WithGradient)
                 ('Device 'CPU)
                 T5DataType
                 ('Shape '[ 'Dim ('Name "*") ('Size 512)])))
           Dropout))
     (NamedModel
        (Embedding
           ('Gradient 'WithGradient)
           ('Layout 'Dense)
           ('Device 'CPU)
           T5DataType
           ('Dim ('Name "*") ('Size 32128))
           ('Dim ('Name "*") ('Size 512))
           'Nothing))
     (NamedModel
        (GLMHead
           ('Dim ('Name "*") ('Size 512))
           ()
           ()
           ()
           (NamedModel
              (GLinear
                 (NamedModel
                    (Tensor
                       ('Gradient 'WithGradient)
                       ('Layout 'Dense)
                       ('Device 'CPU)
                       T5DataType
                       ('Shape
                          '[ 'Dim ('Name "*") ('Size 32128), 'Dim ('Name "*") ('Size 512)])))
                 (NamedModel ())))
           (GBias ()))))
  (MkRelPos ('Dim ('Name "*") ('Size 32)))
  (MkRelPos ('Dim ('Name "*") ('Size 32)))
  MkTransformerPaddingMask
  (MkTransformerAttentionMask T5DataType)
  (MkTransformerCrossAttentionMask T5DataType)
  (MkTransformerDecoderAttentionMask T5DataType)
model Tensor
  ('Gradient 'WithoutGradient)
  ('Layout 'Dense)
  ('Device 'CPU)
  ('DataType 'Int64)
  ('Shape
     '[ 'Dim ('Name "*") ('Size 1), 'Dim ('Name "*") ('Size 19)])
input Generator ('Device 'CPU)
g
  forall a. Show a => a -> IO ()
print forall a b. (a -> b) -> a -> b
$ forall a b. Hypothesis 'Finished a b -> b
finalValue forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> [Hypothesis 'Finished Int [Int]]
finished

-- let tmp = parseString @[] (transParser t5Vocab t5Text) . finalValue <$> finished
-- print tmp

-- | Draw the next token index from @is@, feed it to @parser@, and hand the
-- resulting parser to the continuation @cont@.
--
-- Steps:
--
-- * Before drawing a token, 'notNull' is used to check that @parser@, run
--   against the current token history, produces at least one outcome in
--   @b@. Otherwise the step fails through 'guard'.
-- * The drawn token is prepended to the token history kept in the inner
--   'StateT' before @parser@ is stepped.
-- * If the parser has already finished ('Pure'), its result is returned
--   directly. The drawn token still remains in the history.
next ::
  forall t b i a.
  ( i ~ Int,
    Show i,
    MonadTrans t,
    Monad (t (StateT [i] b)),
    Alternative (t (StateT [i] b)),
    Monad b,
    Foldable b
  ) =>
  t (StateT [i] b) i ->
  Parser (StateT [i] b) i a ->
  (Parser (StateT [i] b) i a -> t (StateT [i] b) a) ->
  t (StateT [i] b) a
next is parser cont = do
  lift (notNull parser) >>= guard
  i <- is
  lift $ modify (i :)
  val <- lift $ runFreeT parser
  case val of
    Pure a -> pure a
    -- Previously written as @unsafePerformIO (pure ...)@. That is a no-op
    -- kept only for a commented-out debug print, so step the parser purely.
    Free feed -> cont (feed i)

-- | Check whether @parser@, run against the current token history held in
-- the 'StateT' state, produces at least one outcome in the underlying
-- monad @b@. The state is only read, never modified.
notNull ::
  (Monad b, Foldable b) =>
  Parser (StateT [i] b) i a ->
  StateT [i] b Bool
notNull parser = gets hasOutcome
  where
    -- Run the parser's first layer from the given history and test for
    -- emptiness of the resulting container.
    hasOutcome history = not (null (evalStateT (runFreeT parser) history))

-- | Check whether running @parser@ against the current token history yields,
-- in any outcome of the underlying monad @b@, a 'Free' layer. A 'Free' layer
-- is a step that still takes another input. The state is only read.
hasFree ::
  (Monad b, Foldable b) =>
  Parser (StateT [i] b) i a ->
  StateT [i] b Bool
hasFree parser = gets (any isFree . evalStateT (runFreeT parser))
  where
    -- Distinguish a suspended step from a finished result.
    isFree (Free _) = True
    isFree (Pure _) = False

-- | @transParser vocab p@ transforms a parser @p@ over characters 'Char'
-- into a parser over token indices 'Int' using the vocabulary @vocab@.
--
-- Behavior:
--
-- * Each incoming token index is looked up in @vocab@.
-- * A leading @▁@ in the vocabulary entry is rewritten to a single space.
-- * The entry's characters are fed to @p@ in order.
-- * Unknown indices and empty entries make the step fail with 'empty'.
-- * If @p@ finishes before all characters of an entry have been fed, the
--   remaining characters are dropped.
transParser :: MonadPlus b => Map.Map Int String -> Parser b Char a -> Parser b Int a
transParser vocab = FreeT . fmap step . runFreeT
  where
    -- Translate one layer of the character parser, then recurse into
    -- whatever parser that layer continues with.
    step layer = transParser vocab <$> translate layer

    -- Swap a continuation over characters for one over token indices.
    translate (Pure a) = Pure a
    translate (Free onChar) = Free (onToken onChar)

    -- Expand a token index into its characters and feed them in order.
    onToken onChar idx = do
      piece <- case Map.lookup idx vocab of
        Nothing -> empty
        Just raw -> pure (decodeSpace raw)
      (first, rest) <- case uncons piece of
        Nothing -> empty
        Just split -> pure split
      feedChars rest (onChar first)

    -- A leading '▁' in a vocabulary entry stands for a space.
    decodeSpace ('▁' : str) = ' ' : str
    decodeSpace str = str

    -- Push the remaining characters through the parser, stopping early as
    -- soon as it has produced a result.
    feedChars [] p = p
    feedChars (ch : chs) p = do
      layer <- lift (runFreeT p)
      case layer of
        Pure a -> pure a
        Free k -> feedChars chs (k ch)

-- | 'Parsing' for character parsers over a @StateT [Int] b@ state.
--
-- Method behavior:
--
-- * 'try' is the identity. Alternatives are presumably already explored by
--   the underlying 'MonadPlus' @b@, so no extra backtracking is added.
-- * Labels given with '<?>' are discarded.
-- * 'unexpected' simply fails with 'empty'.
-- * End-of-input detection and negative lookahead are not implemented here.
--   They raise a descriptive 'error' when used.
instance
  (Alternative b, Foldable b, MonadPlus b) =>
  Parsing (FreeT ((->) Char) (StateT [Int] b))
  where
  try = id
  (<?>) = const
  -- Named 'loop' rather than 'scan' to avoid shadowing Torch.Data.Parser.scan.
  skipMany p = loop
    where
      loop = (p *> loop) <|> pure ()
  skipSome p = p *> skipMany p
  unexpected = const empty
  eof =
    error
      "Parsing.eof: not supported for FreeT ((->) Char) (StateT [Int] b) parsers"
  notFollowedBy _ =
    error
      "Parsing.notFollowedBy: not supported for FreeT ((->) Char) (StateT [Int] b) parsers"

-- | 'CharParsing' for character parsers over a @StateT [Int] b@ state.
-- Every primitive delegates to a combinator from "Torch.Data.Parser":
--
-- * 'satisfy' uses @Torch.Data.Parser.satisfy@.
-- * 'char' uses 'isToken'.
-- * 'notChar' uses 'isNotToken'.
-- * 'anyChar' uses @Torch.Data.Parser.token@.
-- * 'string' uses 'isString'.
instance
  (Alternative b, Foldable b, MonadPlus b) =>
  CharParsing (FreeT ((->) Char) (StateT [Int] b))
  where
  satisfy predicate = Torch.Data.Parser.satisfy predicate
  char expected = isToken expected
  notChar rejected = isNotToken rejected
  -- Qualified because TokenParsing also exports a method named 'token'.
  anyChar = Torch.Data.Parser.token
  string expected = isString expected

-- | Token-level parsing for the same parser type as the 'CharParsing' instance.
-- No methods are overridden, so every 'TokenParsing' method uses the class
-- defaults from "Text.Parser.Token".
instance (Alternative b, Foldable b, MonadPlus b) => TokenParsing (FreeT ((->) Char) (StateT [Int] b))

-- | Get continuations from the model.
--
-- Builds a decoder input from the tokens held in the inner 'StateT' (reversed
-- before use) and runs one 'forward' pass:
--
-- * On the first call the outer state holds 'Nothing', and the full
--   'SimplifiedEncoderDecoderTransformerInput' is used.
-- * After that, the encoder output and input padding mask returned by the
--   model are stored in the outer state. They are fed back through
--   'SimplifiedEncoderDecoderTransformerGenerationInput'.
--
-- The 'Generator' is threaded through the outer state as well.
--
-- The result offers, as alternatives in @b@, the @n@ token ids with the
-- highest log-probability (sorted in descending order along dimension 2) at
-- the last decoder position. An empty index tensor yields no alternatives.
getIs ::
  forall model input generatorDevice b decoderInput encoderOutput decoderOutput inputPaddingMask s.
  ( Alternative b,
    MonadThrow b,
    s ~ (Maybe (encoderOutput, inputPaddingMask), Generator generatorDevice),
    decoderInput
      ~ Tensor
          ('Gradient 'WithoutGradient)
          ('Layout 'Dense)
          ('Device 'CPU)
          ('DataType 'Int64)
          ( 'Shape
              '[ 'Dim ('Name "*") ('Size 1), 'Dim ('Name "*") 'UncheckedSize]
          ),
    decoderOutput
      ~ Tensor
          ('Gradient 'WithGradient)
          ('Layout 'Dense)
          ('Device 'CPU)
          T5DataType
          'UncheckedShape,
    HasForward
      model
      (SimplifiedEncoderDecoderTransformerInput input decoderInput)
      generatorDevice
      (SimplifiedEncoderDecoderTransformerOutput decoderOutput encoderOutput decoderInput inputPaddingMask)
      generatorDevice,
    HasForward
      model
      (SimplifiedEncoderDecoderTransformerGenerationInput decoderInput encoderOutput inputPaddingMask)
      generatorDevice
      (SimplifiedEncoderDecoderTransformerOutput decoderOutput encoderOutput decoderInput inputPaddingMask)
      generatorDevice
  ) =>
  Int ->
  model ->
  input ->
  StateT s (StateT [Int] b) Int
getIs n model input = do
  -- The inner state holds the tokens produced so far; they are reversed
  -- before use (presumably they are pushed most-recent-first — confirm
  -- against the parser driver).
  tokens <- reverse <$> lift get
  -- A single-row (batch size 1) decoder input whose sequence length is the
  -- number of tokens.
  decoderInput :: decoderInput <-
    mkT5Input
      (SName @"*" :&: SSize @1)
      (SName @"*" :&: SUncheckedSize (fromIntegral $ length tokens))
      (SDevice SCPU)
      [tokens]
  decoderOutput <- do
    (mTensors, g) <- get
    (SimplifiedEncoderDecoderTransformerOutput logits encoderOutput _ inputPaddingMask, g') <-
      case mTensors of
        -- Nothing cached yet: run the model on the raw input.
        Nothing ->
          forward model (SimplifiedEncoderDecoderTransformerInput input decoderInput) g
        -- Reuse the encoder output and padding mask from a previous step.
        Just (cachedEncoderOutput, cachedPaddingMask) ->
          forward
            model
            (SimplifiedEncoderDecoderTransformerGenerationInput decoderInput cachedEncoderOutput cachedPaddingMask)
            g
    put (Just (encoderOutput, inputPaddingMask), g')
    pure logits
  probs <- logSoftmax (SSelectDim $ SByIndex @2) decoderOutput
  case sort @('SelectDim ('ByIndex 2)) Descending probs of
    Sorted _ (UnsafeTensor indices) ->
      -- NOTE(review): assumes the index tensor is (batch, seq, vocab); the
      -- first batch entry and the last sequence position are used. Pattern
      -- matching replaces the former partial head/last, so an empty tensor
      -- produces no alternatives instead of a runtime crash.
      let topK = case Torch.Tensor.asValue @[[[Int]]] (Torch.Tensor.Unsafe indices) of
            (positions@(_ : _) : _) -> take n (last positions)
            _ -> []
       in lift . lift . asum $ pure <$> topK

-- | Drive a token-level 'Parser' with continuations proposed by the model.
--
-- The parser is stepped with 'recurse' and 'next'. Both are defined
-- elsewhere; 'next' is given @getIs n model input@, which offers the @n@
-- highest-scoring next token ids. Generation state starts with no cached
-- encoder output and the supplied 'Generator', and is discarded at the end.
-- The inner token state starts as @[]@; its final value is returned
-- alongside the parse result.
runParser ::
  forall model input generatorDevice b a.
  _ =>
  Int ->
  model ->
  input ->
  Generator generatorDevice ->
  Parser (StateT [Int] b) Int a ->
  b (a, [Int])
runParser n model input g parser =
  let continuations = getIs n model input
      driven = recurse (next continuations) parser
      generation = evalStateT driven (Nothing, g)
   in runStateT generation []

-- testParser = do
--   input <- do
--     let tokens = [[13959, 1566, 12, 2968, 10, 6536, 43, 2008, 24, 293, 53, 3, 9, 1782, 19, 207, 21, 25, 1]]
--     -- let tokens = [[13959, 1566, 12, 2968, 10, 148, 31, 60, 423, 13, 3, 7, 10536, 55, 1]]
--     -- let tokens = [[13959, 1566, 12, 2968, 10, 3, 31, 7, 15, 3437, 3, 17, 4416, 4350, 6, 3476, 599, 1935, 61, 45, 4219, 38, 3, 17, 536, 1715, 14939, 38, 3, 17, 357, 30, 3, 17, 5411, 2427, 12925, 834, 23, 26, 3274, 3, 17, 4416, 2427, 12925, 834, 23, 26, 563, 57, 3, 17, 5411, 2427, 12925, 834, 23, 26, 31, 1]]
--     -- let tokens = [[13959, 1566, 12, 2968, 10, 96, 3, 23143, 14196, 332, 4416, 4350, 6, 2847, 17161, 599, 1935, 61, 21680, 4219, 6157, 332, 536, 3, 15355, 3162, 14939, 6157, 332, 357, 9191, 332, 5411, 2427, 12925, 834, 23, 26, 3274, 332, 4416, 2427, 12925, 834, 23, 26, 350, 4630, 6880, 272, 476, 3, 17, 5411, 2427, 12925, 834, 23, 26, 96, 1]]
--     print $ length <$> tokens
--     print $ ((t5Vocab Map.!) <$>) <$> tokens
--     mkT5Input
--       @('Dim ('Name "*") ('Size 1))
--       @('Dim ('Name "*") ('Size 61))
--       tokens
--   model <-
--     initialize
--       @(T5Small ('Device 'CPU))
--       "/Users/torsten.scholak/Projects/thirdParty/hasktorch/hasktorch/src/Torch/GraduallyTyped/NN/Transformer/t5-small.pt"
--   g <- mkGenerator @('Device CPU) 0
--   let outputs = runParser 5 model input g (transParser t5Vocab t5Test)
--   pure . fst $ observe outputs

-- | @t5Text@ reads 'Char's one at a time with 'manyTill' until the @</s>@
-- marker matches. It returns the characters read before the marker as a
-- 'String'; the marker itself is consumed but not returned.
t5Text :: MonadPlus b => Parser b Char String
t5Text = Torch.Data.Parser.token `manyTill` isString "</s>"

-- >>> head $ parseString @[] t5Test "Studien haben belegt, dass es gut ist, einen Hund zu haben</s>"
-- []
-- | Example parser for text of the shape:
--
-- 1. at most 25 characters,
-- 2. @belegt@,
-- 3. at most 5 characters,
-- 4. @dass es@,
-- 5. at most 25 characters,
-- 6. @haben@,
-- 7. @</s>@.
--
-- The pieces are chained left to right with 'combine'.
t5Test :: MonadPlus b => Parser b Char String
t5Test =
  notEnd 25
    `combine` isString "belegt"
    `combine` notEnd 5
    `combine` isString "dass es"
    `combine` notEnd 25
    `combine` isString "haben"
    `combine` isString "</s>"
  where
    -- Scan characters with an accumulated-text state. The step function
    -- rejects a character once the accumulated text would contain "</s>"
    -- or grow beyond @limit@ characters.
    notEnd limit = scan step "" Torch.Data.Parser.token
      where
        step seen c
          | "</s>" `isInfixOf` seen' || length seen' > limit = Nothing
          | otherwise = Just seen'
          where
            seen' = seen ++ [c]

-- | @t5Sql@ parses a 'SpiderSQL' query that is wrapped in double quotes
-- (each quote may be surrounded by whitespace) and followed by @</s>@.
t5Sql ::
  (TokenParsing (FreeT ((->) Char) b), MonadPlus b) =>
  Parser b Char SpiderSQL
t5Sql = between quote quote spiderSQL <* isString "</s>"
  where
    -- A double quote with optional surrounding whitespace.
    quote = spaces *> char '\"' <* spaces