{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PartialTypeSignatures #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeApplications #-}
{-# OPTIONS_GHC -fno-warn-partial-type-signatures #-}
module Torch.GraduallyTyped.NN.Transformer.T5.Generation where
import Control.Applicative (Alternative (..))
import Control.Monad (MonadPlus (..), guard)
import Control.Monad.Catch (MonadThrow)
import Control.Monad.State (MonadState (..), MonadTrans (..), StateT (..), evalStateT, gets, lift, modify)
import Control.Monad.Trans.Free (FreeF (..), FreeT (..), runFreeT)
import Data.Foldable (asum)
import Data.List (isInfixOf, nub, sortOn, uncons)
import qualified Data.Map as Map (Map, lookup)
import System.IO.Unsafe (unsafePerformIO)
import Text.Parser.Char (CharParsing (..), spaces)
import Text.Parser.Combinators (Parsing (..), between, manyTill)
import Text.Parser.Token (TokenParsing (..))
import Torch.Data.Parser (Parser, combine, isNotToken, isString, isToken, recurse, satisfy, scan, token)
import Torch.GraduallyTyped.DType (DType (..), DataType (..))
import Torch.GraduallyTyped.Device (Device (..), DeviceType (..), SDevice (..), SDeviceType (..))
import Torch.GraduallyTyped.Layout (Layout (..), LayoutType (..))
import Torch.GraduallyTyped.NN.Class (HasForward (..), HasStateDict (fromStateDict), stateDictFromFile)
import Torch.GraduallyTyped.NN.Functional.NonLinearActivation (logSoftmax)
import Torch.GraduallyTyped.NN.Transformer.GEncoderDecoder (SimplifiedEncoderDecoderTransformerGenerationInput (..), SimplifiedEncoderDecoderTransformerInput (..), SimplifiedEncoderDecoderTransformerOutput (..))
import Torch.GraduallyTyped.NN.Transformer.T5.Common (T5DataType, mkT5Input, t5EOSTokenId)
import Torch.GraduallyTyped.NN.Transformer.T5.Small (t5SmallSpec)
import Torch.GraduallyTyped.NN.Transformer.Type (STransformerHead (SWithLMHead))
import Torch.GraduallyTyped.NN.Type (SHasDropout (SWithDropout))
import Torch.GraduallyTyped.Random (Generator, sMkGenerator)
import Torch.GraduallyTyped.RequiresGradient (Gradient (..), RequiresGradient (..), SGradient (..), SRequiresGradient (..))
import Torch.GraduallyTyped.Shape.Class (BroadcastShapesF)
import Torch.GraduallyTyped.Shape.Type (By (..), Dim (..), Name (..), SBy (..), SName (..), SSelectDim (..), SShape (..), SSize (..), SelectDim (..), Shape (..), Size (..), pattern (:&:))
import Torch.GraduallyTyped.Tensor.IndexingSlicingJoining (sExpand)
import Torch.GraduallyTyped.Tensor.MathOperations.Comparison (Order (..), Sorted (..), sort)
import Torch.GraduallyTyped.Tensor.Type (SGetShape, Tensor (..))
import Torch.Language.SpiderSQL (SpiderSQL, spiderSQL)
import qualified Torch.Tensor
import Prelude hiding (Word, words)
-- | Tag describing whether a hypothesis is complete. The constructors are
-- used promoted (@'Finished@ / @'Unfinished@) as the type-level index of
-- 'Hypothesis'.
data IsFinished = Finished | Unfinished
-- | The hypotheses tracked by the beam search at one point in time, split
-- into those that are complete and those that can still be extended.
--
-- @a@ is the token type and @b@ is the type of the value carried by a
-- finished hypothesis.
data Beams a b where
  Beams ::
    forall a b.
    { -- hypotheses that have been completed
      finished :: [Hypothesis 'Finished a b],
      -- hypotheses that are still open for extension
      unfinished :: [Hypothesis 'Unfinished a b]
    } ->
    Beams a b
  deriving (Show)
-- | A chain of tokens built up one step at a time. The type index records
-- whether the chain is complete; every hypothesis links back to the
-- unfinished hypothesis it was extended from, ending in 'InitialHypothesis'.
data Hypothesis (isFinished :: IsFinished) a b where
  -- The empty starting point: no token and no predecessor.
  InitialHypothesis ::
    forall a b.
    Hypothesis 'Unfinished a b
  -- An open hypothesis: the most recent token, the score stored for this
  -- hypothesis, and the hypothesis it extends.
  UnfinishedHypothesis ::
    forall a b.
    { currentToken :: a,
      currentScore :: Float,
      previousHypothesis :: Hypothesis 'Unfinished a b
    } ->
    Hypothesis 'Unfinished a b
  -- A completed hypothesis: the last token, the final score, the unfinished
  -- hypothesis it completes, and the resulting value.
  FinishedHypothesis ::
    forall a b.
    { finalToken :: a,
      finalScore :: Float,
      penultimateHypothesis :: Hypothesis 'Unfinished a b,
      finalValue :: b
    } ->
    Hypothesis 'Finished a b

-- Instances are derived separately per index because an ordinary deriving
-- clause cannot be attached to an indexed GADT like this one.
deriving instance (Eq a, Eq b) => Eq (Hypothesis 'Unfinished a b)

deriving instance (Eq a, Eq b) => Eq (Hypothesis 'Finished a b)

deriving instance (Ord a, Ord b) => Ord (Hypothesis 'Unfinished a b)

deriving instance (Ord a, Ord b) => Ord (Hypothesis 'Finished a b)

deriving instance (Show a, Show b) => Show (Hypothesis 'Unfinished a b)

deriving instance (Show a, Show b) => Show (Hypothesis 'Finished a b)
-- | Collect the tokens of an unfinished hypothesis by walking its chain of
-- predecessors. The most recently added token comes first; the result is
-- empty for 'InitialHypothesis'.
getTokens :: forall a b. Hypothesis 'Unfinished a b -> [a]
getTokens InitialHypothesis = []
getTokens (UnfinishedHypothesis tok _ rest) = tok : getTokens rest
-- | Score stored in an unfinished hypothesis; 'InitialHypothesis' scores 0.
getScore :: forall a b. Hypothesis 'Unfinished a b -> Float
getScore InitialHypothesis = 0
getScore (UnfinishedHypothesis _ score _) = score
-- | A hypothesis whose finished/unfinished index is hidden, so that finished
-- and unfinished hypotheses can be stored in the same list. Because the index
-- is existential, 'unSomeHypothesis' cannot be applied as a function; get at
-- the contents by pattern matching on the 'SomeHypothesis' constructor.
data SomeHypothesis a b = forall isFinished. SomeHypothesis {unSomeHypothesis :: Hypothesis isFinished a b}
-- | Run a beam search and return the 'Beams' produced by every step, in
-- step order.
--
-- Arguments, in order:
--
-- * the maximum number of steps to perform;
-- * the beam size, i.e. how many copies of 'InitialHypothesis' the search
--   starts from;
-- * a continuation that expands the current unfinished hypotheses into
--   candidate successors.
--
-- The search stops as soon as no unfinished hypotheses are left or the step
-- budget is used up.
beamSearch ::
  forall a b m.
  Monad m =>
  Int ->
  Int ->
  ([Hypothesis 'Unfinished a b] -> m [SomeHypothesis a b]) ->
  m [Beams a b]
beamSearch maxSteps beamSize cont =
  go maxSteps (Beams [] (replicate beamSize InitialHypothesis))
  where
    go :: Int -> Beams a b -> m [Beams a b]
    -- Nothing left to extend. The step counter is forced here so it never
    -- accumulates as a chain of unevaluated decrements.
    go !_ (Beams _ []) = pure []
    go stepsLeft beams
      | stepsLeft <= 0 = pure []
      | otherwise = do
        beams' <- beamSearchStep cont beams
        (beams' :) <$> go (stepsLeft - 1) beams'
-- | Perform a single beam-search step.
--
-- The continuation is applied to the unfinished hypotheses of the given
-- beams. Its candidates are ordered by descending score, and only as many
-- as there were unfinished hypotheses going in are kept. Kept candidates that
-- are finished are added to the already finished hypotheses; the others form
-- the new unfinished list.
beamSearchStep ::
  forall a b m.
  Functor m =>
  ([Hypothesis 'Unfinished a b] -> m [SomeHypothesis a b]) ->
  Beams a b ->
  m (Beams a b)
beamSearchStep cont beam =
  foldl insertHypothesis (Beams (finished beam) []) <$> bestCandidates
  where
    open :: [Hypothesis 'Unfinished a b]
    open = unfinished beam

    -- Highest-scoring candidates first, truncated to the current beam width.
    bestCandidates :: m [SomeHypothesis a b]
    bestCandidates =
      take (length open) . reverse . sortOn scoreOf <$> cont open

    -- Score used for ranking, regardless of whether the candidate is finished.
    scoreOf :: SomeHypothesis a b -> Float
    scoreOf (SomeHypothesis InitialHypothesis) = 0
    scoreOf (SomeHypothesis (UnfinishedHypothesis _ score _)) = score
    scoreOf (SomeHypothesis (FinishedHypothesis _ score _ _)) = score

    -- Prepend a candidate to the list matching its finished/unfinished index.
    insertHypothesis :: Beams a b -> SomeHypothesis a b -> Beams a b
    insertHypothesis (Beams fs us) (SomeHypothesis u@InitialHypothesis) = Beams fs (u : us)
    insertHypothesis (Beams fs us) (SomeHypothesis u@UnfinishedHypothesis {}) = Beams fs (u : us)
    insertHypothesis (Beams fs us) (SomeHypothesis f@FinishedHypothesis {}) = Beams (f : fs) us
-- | Run 'beamSearch' with an encoder-decoder model as the expansion step.
--
-- Arguments, in order: the maximum number of steps, the beam size, the model,
-- the model input, and a random generator. The result is the list of 'Beams'
-- produced by 'beamSearch'. Tokens are 'Int' ids and a finished hypothesis
-- carries its complete token sequence in generation order.
--
-- The state threaded through the search holds the encoder output and input
-- padding mask from the first forward pass (initially 'Nothing') together
-- with the current generator.
--
-- NOTE(review): the branch that reuses the cached encoder output is not
-- implemented yet and aborts with an error, so any search that needs a second
-- model call will fail.
runBeamSearch ::
  forall model input decoderInput encoderOutput encoderOutputShape encoderOutput' inputPaddingMask decoderOutput generatorDevice.
  ( HasForward
      model
      (SimplifiedEncoderDecoderTransformerInput input decoderInput)
      generatorDevice
      (SimplifiedEncoderDecoderTransformerOutput decoderOutput encoderOutput decoderInput inputPaddingMask)
      generatorDevice,
    encoderOutput
      ~ Tensor
          ('Gradient 'WithGradient)
          ('Layout 'Dense)
          ('Device 'CPU)
          T5DataType
          encoderOutputShape,
    'UncheckedShape ~ BroadcastShapesF encoderOutputShape 'UncheckedShape,
    SGetShape encoderOutputShape,
    HasForward
      model
      (SimplifiedEncoderDecoderTransformerGenerationInput decoderInput encoderOutput' inputPaddingMask)
      generatorDevice
      (SimplifiedEncoderDecoderTransformerOutput decoderOutput encoderOutput' decoderInput inputPaddingMask)
      generatorDevice,
    encoderOutput'
      ~ Tensor
          ('Gradient 'WithGradient)
          ('Layout 'Dense)
          ('Device 'CPU)
          T5DataType
          'UncheckedShape,
    decoderInput
      ~ Tensor
          ('Gradient 'WithoutGradient)
          ('Layout 'Dense)
          ('Device 'CPU)
          ('DataType 'Int64)
          ('Shape '[ 'Dim ('Name "*") 'UncheckedSize, 'Dim ('Name "*") 'UncheckedSize]),
    decoderOutput
      ~ Tensor
          ('Gradient 'WithGradient)
          ('Layout 'Dense)
          ('Device 'CPU)
          T5DataType
          'UncheckedShape,
    generatorDevice ~ 'Device 'CPU
  ) =>
  Int ->
  Int ->
  model ->
  input ->
  Generator generatorDevice ->
  IO [Beams Int [Int]]
runBeamSearch maxSteps beamSize model input g =
  evalStateT (beamSearch maxSteps beamSize cont) (Nothing, g)
  where
    -- Expand the current unfinished hypotheses: build one decoder input row
    -- per distinct hypothesis, run the model, and turn the log-probabilities
    -- at the last sequence position into one candidate per token id.
    cont :: [Hypothesis 'Unfinished Int [Int]] -> StateT (Maybe (encoderOutput, inputPaddingMask), Generator generatorDevice) IO [SomeHypothesis Int [Int]]
    cont previousHypotheses = do
      -- Duplicates (e.g. the replicated initial hypothesis) are dropped so
      -- that each distinct hypothesis is expanded only once.
      let previousHypotheses' = nub previousHypotheses
      decoderInput :: decoderInput <- do
        let -- 'getTokens' yields the newest token first; reverse to get
            -- generation order.
            tokens = reverse . getTokens <$> previousHypotheses'
            batchSize = fromIntegral $ length previousHypotheses'
            seqSize =
              let go :: Hypothesis 'Unfinished Int [Int] -> Int
                  go InitialHypothesis = 0
                  go (UnfinishedHypothesis _ _ previousHypothesis') = 1 + go previousHypothesis'
               in fromIntegral . maximum $ go <$> previousHypotheses'
        mkT5Input
          (SName @"*" :&: SUncheckedSize batchSize)
          (SName @"*" :&: SUncheckedSize seqSize)
          (SDevice SCPU)
          tokens
      logProbs <- getLogProbs decoderInput
      pure $
        zip previousHypotheses' logProbs
          >>= uncurry
            ( \previousHypothesis ->
                -- Only the last sequence position is used; list positions
                -- double as token ids.
                zipWith (mkHypothesis previousHypothesis) [0, 1 ..] . last
            )

    -- Run the model on the decoder input and return the log-softmax over
    -- dimension 2 of the decoder output as nested lists. The encoder output,
    -- padding mask and generator are stored back into the state.
    getLogProbs :: decoderInput -> StateT (Maybe (encoderOutput, inputPaddingMask), Generator generatorDevice) IO [[[Float]]]
    getLogProbs decoderInput = do
      (maybeStuff, g) <- get
      (SimplifiedEncoderDecoderTransformerOutput decoderOutput encoderOutput _ inputPaddingMask, g') <- case maybeStuff of
        -- First call: nothing cached yet, so run the full model.
        Nothing -> forward model (SimplifiedEncoderDecoderTransformerInput input decoderInput) g
        -- Later calls: reuse the cached encoder output, expanded to the
        -- decoder input's batch dimension.
        Just (encoderOutput, inputPaddingMask) -> do
          -- TODO: these two values are not computed yet; the binds below
          -- abort when this branch is executed.
          decoderInputBatchDim <- error "runBeamSearch: batch dimension of the decoder input is not implemented"
          encoderOutputDims <- error "runBeamSearch: non-batch dimensions of the encoder output are not implemented"
          let encoderOutput' = sExpand (SUncheckedShape (decoderInputBatchDim : encoderOutputDims)) encoderOutput
          (SimplifiedEncoderDecoderTransformerOutput decoderOutput _ _ _, g') <-
            forward model (SimplifiedEncoderDecoderTransformerGenerationInput decoderInput encoderOutput' inputPaddingMask) g
          -- Return the original (unexpanded) encoder output so that it is
          -- what gets cached again.
          pure (SimplifiedEncoderDecoderTransformerOutput decoderOutput encoderOutput decoderInput inputPaddingMask, g')
      put (Just (encoderOutput, inputPaddingMask), g')
      probs <- logSoftmax (SSelectDim $ SByIndex @2) decoderOutput
      case probs of
        UnsafeTensor t -> pure . Torch.Tensor.asValue . Torch.Tensor.Unsafe $ t

    -- Extend a hypothesis by one token. The new score is the token's
    -- log-probability plus the score of the hypothesis being extended. The
    -- end-of-sequence token produces a finished hypothesis whose value is the
    -- full token sequence in generation order.
    mkHypothesis :: Hypothesis 'Unfinished Int [Int] -> Int -> Float -> SomeHypothesis Int [Int]
    mkHypothesis previousHypothesis tokenId logProb
      | tokenId == t5EOSTokenId =
        let allTokens = reverse $ tokenId : getTokens previousHypothesis
         in SomeHypothesis (FinishedHypothesis tokenId newScore previousHypothesis allTokens)
      | otherwise =
        SomeHypothesis (UnfinishedHypothesis tokenId newScore previousHypothesis)
      where
        newScore = logProb + getScore previousHypothesis
testBeamSearch :: IO ()
testBeamSearch = do
Tensor
('Gradient 'WithoutGradient)
('Layout 'Dense)
('Device 'CPU)
('DataType 'Int64)
('Shape
'[ 'Dim ('Name "*") ('Size 1), 'Dim ('Name "*") ('Size 19)])
input <- do
let tokens :: [[Int]]
tokens = [[Int
13959, Int
1566, Int
12, Int
2968, Int
10, Int
6536, Int
43, Int
2008, Int
24, Int
293, Int
53, Int
3, Int
9, Int
1782, Int
19, Int
207, Int
21, Int
25, Int
1]]
forall (batchDim :: Dim (Name Symbol) (Size Nat))
(seqDim :: Dim (Name Symbol) (Size Nat))
(device :: Device (DeviceType Nat)) (m :: * -> *) output.
(MonadThrow m, SGetDim batchDim, SGetDim seqDim,
Catch
('Shape
'[ 'Dim ('Name "*") 'UncheckedSize,
'Dim ('Name "*") 'UncheckedSize]
<+> 'Shape '[batchDim, seqDim]),
output
~ Tensor
('Gradient 'WithoutGradient)
('Layout 'Dense)
device
('DataType 'Int64)
('Shape '[batchDim, seqDim])) =>
SDim batchDim
-> SDim seqDim -> SDevice device -> [[Int]] -> m output
mkT5Input
(forall (name1 :: Symbol). KnownSymbol name1 => SName ('Name name1)
SName @"*" forall (name :: Name Symbol) (size :: Size Nat).
SName name -> SSize size -> SDim ('Dim name size)
:&: forall (size1 :: Nat). KnownNat size1 => SSize ('Size size1)
SSize @1)
(forall (name1 :: Symbol). KnownSymbol name1 => SName ('Name name1)
SName @"*" forall (name :: Name Symbol) (size :: Size Nat).
SName name -> SSize size -> SDim ('Dim name size)
:&: forall (size1 :: Nat). KnownNat size1 => SSize ('Size size1)
SSize @19)
(forall (deviceType1 :: DeviceType Nat).
SDeviceType deviceType1 -> SDevice ('Device deviceType1)
SDevice SDeviceType 'CPU
SCPU)
[[Int]]
tokens
StateDict
stateDict <- String -> IO StateDict
stateDictFromFile String
"/tmp/t5-small-state-dict.pt"
let spec :: ModelSpec
(T5Small
'WithLMHead ('Gradient 'WithGradient) ('Device 'CPU) 'WithDropout)
spec = forall (transformerHead :: TransformerHead)
(gradient :: Gradient RequiresGradient)
(device :: Device (DeviceType Nat)) (hasDropout :: HasDropout).
STransformerHead transformerHead
-> SGradient gradient
-> SDevice device
-> SHasDropout hasDropout
-> ModelSpec (T5Small transformerHead gradient device hasDropout)
t5SmallSpec STransformerHead 'WithLMHead
SWithLMHead (forall (requiresGradient :: RequiresGradient).
SRequiresGradient requiresGradient
-> SGradient ('Gradient requiresGradient)
SGradient SRequiresGradient 'WithGradient
SWithGradient) (forall (deviceType1 :: DeviceType Nat).
SDeviceType deviceType1 -> SDevice ('Device deviceType1)
SDevice SDeviceType 'CPU
SCPU) SHasDropout 'WithDropout
SWithDropout
GSimplifiedEncoderDecoderTransformer
(GEncoderDecoderTransformer
('Dim ('Name "*") ('Size 512))
(NamedModel
(GTransformer
()
(NamedModel
(Embedding
('Gradient 'WithGradient)
('Layout 'Dense)
('Device 'CPU)
T5DataType
('Dim ('Name "*") ('Size 32))
('Dim ('Name "*") ('Size 8))
'Nothing))
()
Dropout
(NamedModel
(GTransformerStack
(Vector
6
(GTransformerBlock
(NamedModel
(GSelfAttention
(NamedModel
(LayerNorm
'WithoutBias
('Gradient 'WithGradient)
('Device 'CPU)
T5DataType
('Shape '[ 'Dim ('Name "*") ('Size 512)])))
(NamedModel
(GMultiHeadAttention
('Dim ('Name "*") ('Size 8))
('Dim ('Name "*") ('Size 64))
('Dim ('Name "*") ('Size 512))
(NamedModel
(GLinear
(NamedModel
(Tensor
('Gradient 'WithGradient)
('Layout 'Dense)
('Device 'CPU)
T5DataType
('Shape
'[ 'Dim ('Name "*") ('Size 512),
'Dim ('Name "*") ('Size 512)])))
(NamedModel ())))
(NamedModel
(GLinear
(NamedModel
(Tensor
('Gradient 'WithGradient)
('Layout 'Dense)
('Device 'CPU)
T5DataType
('Shape
'[ 'Dim ('Name "*") ('Size 512),
'Dim ('Name "*") ('Size 512)])))
(NamedModel ())))
(NamedModel
(GLinear
(NamedModel
(Tensor
('Gradient 'WithGradient)
('Layout 'Dense)
('Device 'CPU)
T5DataType
('Shape
'[ 'Dim ('Name "*") ('Size 512),
'Dim ('Name "*") ('Size 512)])))
(NamedModel ())))
(NamedModel
(GLinear
(NamedModel
(Tensor
('Gradient 'WithGradient)
('Layout 'Dense)
('Device 'CPU)
T5DataType
('Shape
'[ 'Dim ('Name "*") ('Size 512),
'Dim ('Name "*") ('Size 512)])))
(NamedModel ())))
Dropout))
Dropout
()))
()
(NamedModel
(GTransformerFeedForwardNetwork
(NamedModel
(LayerNorm
'WithoutBias
('Gradient 'WithGradient)
('Device 'CPU)
T5DataType
('Shape '[ 'Dim ('Name "*") ('Size 512)])))
(NamedModel
(GLinear
(NamedModel
(Tensor
('Gradient 'WithGradient)
('Layout 'Dense)
('Device 'CPU)
T5DataType
('Shape
'[ 'Dim ('Name "*") ('Size 2048),
'Dim ('Name "*") ('Size 512)])))
(NamedModel ())))
Relu
Dropout
(NamedModel
(GLinear
(NamedModel
(Tensor
('Gradient 'WithGradient)
('Layout 'Dense)
('Device 'CPU)
T5DataType
('Shape
'[ 'Dim ('Name "*") ('Size 512),
'Dim ('Name "*") ('Size 2048)])))
(NamedModel ())))
Dropout
()))))))
(NamedModel
(LayerNorm
'WithoutBias
('Gradient 'WithGradient)
('Device 'CPU)
T5DataType
('Shape '[ 'Dim ('Name "*") ('Size 512)])))
Dropout))
(NamedModel
(GTransformer
()
(NamedModel
(Embedding
('Gradient 'WithGradient)
('Layout 'Dense)
('Device 'CPU)
T5DataType
('Dim ('Name "*") ('Size 32))
('Dim ('Name "*") ('Size 8))
'Nothing))
()
Dropout
(NamedModel
(GTransformerStack
(Vector
6
(GTransformerBlock
(NamedModel
(GSelfAttention
(NamedModel
(LayerNorm
'WithoutBias
('Gradient 'WithGradient)
('Device 'CPU)
T5DataType
('Shape '[ 'Dim ('Name "*") ('Size 512)])))
(NamedModel
(GMultiHeadAttention
('Dim ('Name "*") ('Size 8))
('Dim ('Name "*") ('Size 64))
('Dim ('Name "*") ('Size 512))
(NamedModel
(GLinear
(NamedModel
(Tensor
('Gradient 'WithGradient)
('Layout 'Dense)
('Device 'CPU)
T5DataType
('Shape
'[ 'Dim ('Name "*") ('Size 512),
'Dim ('Name "*") ('Size 512)])))
(NamedModel ())))
(NamedModel
(GLinear
(NamedModel
(Tensor
('Gradient 'WithGradient)
('Layout 'Dense)
('Device 'CPU)
T5DataType
('Shape
'[ 'Dim ('Name "*") ('Size 512),
'Dim ('Name "*") ('Size 512)])))
(NamedModel ())))
(NamedModel
(GLinear
(NamedModel
(Tensor
('Gradient 'WithGradient)
('Layout 'Dense)
('Device 'CPU)
T5DataType
('Shape
'[ 'Dim ('Name "*") ('Size 512),
'Dim ('Name "*") ('Size 512)])))
(NamedModel ())))
(NamedModel
(GLinear
(NamedModel
(Tensor
('Gradient 'WithGradient)
('Layout 'Dense)
('Device 'CPU)
T5DataType
('Shape
'[ 'Dim ('Name "*") ('Size 512),
'Dim ('Name "*") ('Size 512)])))
(NamedModel ())))
Dropout))
Dropout
()))
(NamedModel
(GCrossAttention
(NamedModel
(LayerNorm
'WithoutBias
('Gradient 'WithGradient)
('Device 'CPU)
T5DataType
('Shape '[ 'Dim ('Name "*") ('Size 512)])))
(NamedModel
(GMultiHeadAttention
('Dim ('Name "*") ('Size 8))
('Dim ('Name "*") ('Size 64))
('Dim ('Name "*") ('Size 512))
(NamedModel
(GLinear
(NamedModel
(Tensor
('Gradient 'WithGradient)
('Layout 'Dense)
('Device 'CPU)
T5DataType
('Shape
'[ 'Dim ('Name "*") ('Size 512),
'Dim ('Name "*") ('Size 512)])))
(NamedModel ())))
(NamedModel
(GLinear
(NamedModel
(Tensor
('Gradient 'WithGradient)
('Layout 'Dense)
('Device 'CPU)
T5DataType
('Shape
'[ 'Dim ('Name "*") ('Size 512),
'Dim ('Name "*") ('Size 512)])))
(NamedModel ())))
(NamedModel
(GLinear
(NamedModel
(Tensor
('Gradient 'WithGradient)
('Layout 'Dense)
('Device 'CPU)
T5DataType
('Shape
'[ 'Dim ('Name "*") ('Size 512),
'Dim ('Name "*") ('Size 512)])))
(NamedModel ())))
(NamedModel
(GLinear
(NamedModel
(Tensor
('Gradient 'WithGradient)
('Layout 'Dense)
('Device 'CPU)
T5DataType
('Shape
'[ 'Dim ('Name "*") ('Size 512),
'Dim ('Name "*") ('Size 512)])))
(NamedModel ())))
Dropout))
Dropout
()))
(NamedModel
(GTransformerFeedForwardNetwork
(NamedModel
(LayerNorm
'WithoutBias
('Gradient 'WithGradient)
('Device 'CPU)
T5DataType
('Shape '[ 'Dim ('Name "*") ('Size 512)])))
(NamedModel
(GLinear
(NamedModel
(Tensor
('Gradient 'WithGradient)
('Layout 'Dense)
('Device 'CPU)
T5DataType
('Shape
'[ 'Dim ('Name "*") ('Size 2048),
'Dim ('Name "*") ('Size 512)])))
(NamedModel ())))
Relu
Dropout
(NamedModel
(GLinear
(NamedModel
(Tensor
('Gradient 'WithGradient)
('Layout 'Dense)
('Device 'CPU)
T5DataType
('Shape
'[ 'Dim ('Name "*") ('Size 512),
'Dim ('Name "*") ('Size 2048)])))
(NamedModel ())))
Dropout
()))))))
(NamedModel
(LayerNorm
'WithoutBias
('Gradient 'WithGradient)
('Device 'CPU)
T5DataType
('Shape '[ 'Dim ('Name "*") ('Size 512)])))
Dropout))
(NamedModel
(Embedding
('Gradient 'WithGradient)
('Layout 'Dense)
('Device 'CPU)
T5DataType
('Dim ('Name "*") ('Size 32128))
('Dim ('Name "*") ('Size 512))
'Nothing))
(NamedModel
(GLMHead
('Dim ('Name "*") ('Size 512))
()
()
()
(NamedModel
(GLinear
(NamedModel
(Tensor
('Gradient 'WithGradient)
('Layout 'Dense)
('Device 'CPU)
T5DataType
('Shape
'[ 'Dim ('Name "*") ('Size 32128), 'Dim ('Name "*") ('Size 512)])))
(NamedModel ())))
(GBias ()))))
(MkRelPos ('Dim ('Name "*") ('Size 32)))
(MkRelPos ('Dim ('Name "*") ('Size 32)))
MkTransformerPaddingMask
(MkTransformerAttentionMask T5DataType)
(MkTransformerCrossAttentionMask T5DataType)
(MkTransformerDecoderAttentionMask T5DataType)
model <- forall a b c. (a -> b -> c) -> b -> a -> c
flip forall (m :: * -> *) s a. Monad m => StateT s m a -> s -> m a
evalStateT StateDict
stateDict forall a b. (a -> b) -> a -> b
$ forall model (m :: * -> *).
(HasStateDict model, MonadIO m, MonadThrow m,
MonadState StateDict m) =>
ModelSpec model -> StateDictKey -> m model
fromStateDict GSimplifiedEncoderDecoderTransformer
(GEncoderDecoderTransformer
('Dim ('Name "*") ('Size 512))
(NamedModel
(GTransformer
()
(NamedModel
(EmbeddingSpec
('Gradient 'WithGradient)
('Layout 'Dense)
('Device 'CPU)
T5DataType
('Dim ('Name "*") ('Size 32))
('Dim ('Name "*") ('Size 8))
'Nothing))
()
Dropout
(NamedModel
(GTransformerStack
(VectorSpec
6
(GTransformerBlock
(NamedModel
(GSelfAttention
(NamedModel
(LayerNorm
'WithoutBias
('Gradient 'WithGradient)
('Device 'CPU)
T5DataType
('Shape '[ 'Dim ('Name "*") ('Size 512)])))
(NamedModel
(GMultiHeadAttention
('Dim ('Name "*") ('Size 8))
('Dim ('Name "*") ('Size 64))
('Dim ('Name "*") ('Size 512))
(NamedModel
(GLinear
(NamedModel
(Tensor
('Gradient 'WithGradient)
('Layout 'Dense)
('Device 'CPU)
T5DataType
('Shape
'[ 'Dim ('Name "*") ('Size 512),
'Dim ('Name "*") ('Size 512)])))
(NamedModel ())))
(NamedModel
(GLinear
(NamedModel
(Tensor
('Gradient 'WithGradient)
('Layout 'Dense)
('Device 'CPU)
T5DataType
('Shape
'[ 'Dim ('Name "*") ('Size 512),
'Dim ('Name "*") ('Size 512)])))
(NamedModel ())))
(NamedModel
(GLinear
(NamedModel
(Tensor
('Gradient 'WithGradient)
('Layout 'Dense)
('Device 'CPU)
T5DataType
('Shape
'[ 'Dim ('Name "*") ('Size 512),
'Dim ('Name "*") ('Size 512)])))
(NamedModel ())))
(NamedModel
(GLinear
(NamedModel
(Tensor
('Gradient 'WithGradient)
('Layout 'Dense)
('Device 'CPU)
T5DataType
('Shape
'[ 'Dim ('Name "*") ('Size 512),
'Dim ('Name "*") ('Size 512)])))
(NamedModel ())))
Dropout))
Dropout
()))
()
(NamedModel
(GTransformerFeedForwardNetwork
(NamedModel
(LayerNorm
'WithoutBias
('Gradient 'WithGradient)
('Device 'CPU)
T5DataType
('Shape '[ 'Dim ('Name "*") ('Size 512)])))
(NamedModel
(GLinear
(NamedModel
(Tensor
('Gradient 'WithGradient)
('Layout 'Dense)
('Device 'CPU)
T5DataType
('Shape
'[ 'Dim ('Name "*") ('Size 2048),
'Dim ('Name "*") ('Size 512)])))
(NamedModel ())))
Relu
Dropout
(NamedModel
(GLinear
(NamedModel
(Tensor
('Gradient 'WithGradient)
('Layout 'Dense)
('Device 'CPU)
T5DataType
('Shape
'[ 'Dim ('Name "*") ('Size 512),
'Dim ('Name "*") ('Size 2048)])))
(NamedModel ())))
Dropout
()))))))
(NamedModel
(LayerNormSpec
'WithoutBias
('Gradient 'WithGradient)
('Device 'CPU)
T5DataType
('Shape '[ 'Dim ('Name "*") ('Size 512)])))
Dropout))
(NamedModel
(GTransformer
()
(NamedModel
(EmbeddingSpec
('Gradient 'WithGradient)
('Layout 'Dense)
('Device 'CPU)
T5DataType
('Dim ('Name "*") ('Size 32))
('Dim ('Name "*") ('Size 8))
'Nothing))
()
Dropout
(NamedModel
(GTransformerStack
(VectorSpec
6
(GTransformerBlock
(NamedModel
(GSelfAttention
(NamedModel
(LayerNorm
'WithoutBias
('Gradient 'WithGradient)
('Device 'CPU)
T5DataType
('Shape '[ 'Dim ('Name "*") ('Size 512)])))
(NamedModel
(GMultiHeadAttention
('Dim ('Name "*") ('Size 8))
('Dim ('Name "*") ('Size 64))
('Dim ('Name "*") ('Size 512))
(NamedModel
(GLinear
(NamedModel
(Tensor
('Gradient 'WithGradient)
('Layout 'Dense)
('Device 'CPU)
T5DataType
('Shape
'[ 'Dim ('Name "*") ('Size 512),
'Dim ('Name "*") ('Size 512)])))
(NamedModel ())))
(NamedModel
(GLinear
(NamedModel
(Tensor
('Gradient 'WithGradient)
('Layout 'Dense)
('Device 'CPU)
T5DataType
('Shape
'[ 'Dim ('Name "*") ('Size 512),
'Dim ('Name "*") ('Size 512)])))
(NamedModel ())))
(NamedModel
(GLinear
(NamedModel
(Tensor
('Gradient 'WithGradient)
('Layout 'Dense)
('Device 'CPU)
T5DataType
('Shape
'[ 'Dim ('Name "*") ('Size 512),
'Dim ('Name "*") ('Size 512)])))
(NamedModel ())))
(NamedModel
(GLinear
(NamedModel
(Tensor
('Gradient 'WithGradient)
('Layout 'Dense)
('Device 'CPU)
T5DataType
('Shape
'[ 'Dim ('Name "*") ('Size 512),
'Dim ('Name "*") ('Size 512)])))
(NamedModel ())))
Dropout))
Dropout
()))
(NamedModel
(GCrossAttention
(NamedModel
(LayerNorm
'WithoutBias
('Gradient 'WithGradient)
('Device 'CPU)
T5DataType
('Shape '[ 'Dim ('Name "*") ('Size 512)])))
(NamedModel
(GMultiHeadAttention
('Dim ('Name "*") ('Size 8))
('Dim ('Name "*") ('Size 64))
('Dim ('Name "*") ('Size 512))
(NamedModel
(GLinear
(NamedModel
(Tensor
('Gradient 'WithGradient)
('Layout 'Dense)
('Device 'CPU)
T5DataType
('Shape
'[ 'Dim ('Name "*") ('Size 512),
'Dim ('Name "*") ('Size 512)])))
(NamedModel ())))
(NamedModel
(GLinear
(NamedModel
(Tensor
('Gradient 'WithGradient)
('Layout 'Dense)
('Device 'CPU)
T5DataType
('Shape
'[ 'Dim ('Name "*") ('Size 512),
'Dim ('Name "*") ('Size 512)])))
(NamedModel ())))
(NamedModel
(GLinear
(NamedModel
(Tensor
('Gradient 'WithGradient)
('Layout 'Dense)
('Device 'CPU)
T5DataType
('Shape
'[ 'Dim ('Name "*") ('Size 512),
'Dim ('Name "*") ('Size 512)])))
(NamedModel ())))
(NamedModel
(GLinear
(NamedModel
(Tensor
('Gradient 'WithGradient)
('Layout 'Dense)
('Device 'CPU)
T5DataType
('Shape
'[ 'Dim ('Name "*") ('Size 512),
'Dim ('Name "*") ('Size 512)])))
(NamedModel ())))
Dropout))
Dropout
()))
(NamedModel
(GTransformerFeedForwardNetwork
(NamedModel
(LayerNorm
'WithoutBias
('Gradient 'WithGradient)
('Device 'CPU)
T5DataType
('Shape '[ 'Dim ('Name "*") ('Size 512)])))
(NamedModel
(GLinear
(NamedModel
(Tensor
('Gradient 'WithGradient)
('Layout 'Dense)
('Device 'CPU)
T5DataType
('Shape
'[ 'Dim ('Name "*") ('Size 2048),
'Dim ('Name "*") ('Size 512)])))
(NamedModel ())))
Relu
Dropout
(NamedModel
(GLinear
(NamedModel
(Tensor
('Gradient 'WithGradient)
('Layout 'Dense)
('Device 'CPU)
T5DataType
('Shape
'[ 'Dim ('Name "*") ('Size 512),
'Dim ('Name "*") ('Size 2048)])))
(NamedModel ())))
Dropout
()))))))
(NamedModel
(LayerNormSpec
'WithoutBias
('Gradient 'WithGradient)
('Device 'CPU)
T5DataType
('Shape '[ 'Dim ('Name "*") ('Size 512)])))
Dropout))
(NamedModel
(EmbeddingSpec
('Gradient 'WithGradient)
('Layout 'Dense)
('Device 'CPU)
T5DataType
('Dim ('Name "*") ('Size 32128))
('Dim ('Name "*") ('Size 512))
'Nothing))
(NamedModel
(GLMHead
('Dim ('Name "*") ('Size 512))
()
()
()
(NamedModel
(GLinear
(NamedModel
(TensorSpec
('Gradient 'WithGradient)
('Layout 'Dense)
('Device 'CPU)
T5DataType
('Shape
'[ 'Dim ('Name "*") ('Size 32128), 'Dim ('Name "*") ('Size 512)])))
(NamedModel ())))
(GBias ()))))
(MkRelPos ('Dim ('Name "*") ('Size 32)))
(MkRelPos ('Dim ('Name "*") ('Size 32)))
MkTransformerPaddingMask
(MkTransformerAttentionMask T5DataType)
(MkTransformerCrossAttentionMask T5DataType)
(MkTransformerDecoderAttentionMask T5DataType)
spec forall a. Monoid a => a
mempty
Generator ('Device 'CPU)
g <- forall (m :: * -> *) (device :: Device (DeviceType Nat)).
MonadThrow m =>
SDevice device -> Word64 -> m (Generator device)
sMkGenerator (forall (deviceType1 :: DeviceType Nat).
SDeviceType deviceType1 -> SDevice ('Device deviceType1)
SDevice SDeviceType 'CPU
SCPU) Word64
0
Beams [Hypothesis 'Finished Int [Int]]
finished [Hypothesis 'Unfinished Int [Int]]
_ <- forall a. [a] -> a
last forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> forall model input decoderInput encoderOutput
(encoderOutputShape :: Shape [Dim (Name Symbol) (Size Nat)])
encoderOutput' inputPaddingMask decoderOutput
(generatorDevice :: Device (DeviceType Nat)).
(HasForward
model
(SimplifiedEncoderDecoderTransformerInput input decoderInput)
generatorDevice
(SimplifiedEncoderDecoderTransformerOutput
decoderOutput encoderOutput decoderInput inputPaddingMask)
generatorDevice,
encoderOutput
~ Tensor
('Gradient 'WithGradient)
('Layout 'Dense)
('Device 'CPU)
T5DataType
encoderOutputShape,
'UncheckedShape
~ BroadcastShapesF encoderOutputShape 'UncheckedShape,
SGetShape encoderOutputShape,
HasForward
model
(SimplifiedEncoderDecoderTransformerGenerationInput
decoderInput encoderOutput' inputPaddingMask)
generatorDevice
(SimplifiedEncoderDecoderTransformerOutput
decoderOutput encoderOutput' decoderInput inputPaddingMask)
generatorDevice,
encoderOutput'
~ Tensor
('Gradient 'WithGradient)
('Layout 'Dense)
('Device 'CPU)
T5DataType
'UncheckedShape,
decoderInput
~ Tensor
('Gradient 'WithoutGradient)
('Layout 'Dense)
('Device 'CPU)
('DataType 'Int64)
('Shape
'[ 'Dim ('Name "*") 'UncheckedSize,
'Dim ('Name "*") 'UncheckedSize]),
decoderOutput
~ Tensor
('Gradient 'WithGradient)
('Layout 'Dense)
('Device 'CPU)
T5DataType
'UncheckedShape,
generatorDevice ~ 'Device 'CPU) =>
Int
-> Int
-> model
-> input
-> Generator generatorDevice
-> IO [Beams Int [Int]]
runBeamSearch Int
50 Int
1 GSimplifiedEncoderDecoderTransformer
(GEncoderDecoderTransformer
('Dim ('Name "*") ('Size 512))
(NamedModel
(GTransformer
()
(NamedModel
(Embedding
('Gradient 'WithGradient)
('Layout 'Dense)
('Device 'CPU)
T5DataType
('Dim ('Name "*") ('Size 32))
('Dim ('Name "*") ('Size 8))
'Nothing))
()
Dropout
(NamedModel
(GTransformerStack
(Vector
6
(GTransformerBlock
(NamedModel
(GSelfAttention
(NamedModel
(LayerNorm
'WithoutBias
('Gradient 'WithGradient)
('Device 'CPU)
T5DataType
('Shape '[ 'Dim ('Name "*") ('Size 512)])))
(NamedModel
(GMultiHeadAttention
('Dim ('Name "*") ('Size 8))
('Dim ('Name "*") ('Size 64))
('Dim ('Name "*") ('Size 512))
(NamedModel
(GLinear
(NamedModel
(Tensor
('Gradient 'WithGradient)
('Layout 'Dense)
('Device 'CPU)
T5DataType
('Shape
'[ 'Dim ('Name "*") ('Size 512),
'Dim ('Name "*") ('Size 512)])))
(NamedModel ())))
(NamedModel
(GLinear
(NamedModel
(Tensor
('Gradient 'WithGradient)
('Layout 'Dense)
('Device 'CPU)
T5DataType
('Shape
'[ 'Dim ('Name "*") ('Size 512),
'Dim ('Name "*") ('Size 512)])))
(NamedModel ())))
(NamedModel
(GLinear
(NamedModel
(Tensor
('Gradient 'WithGradient)
('Layout 'Dense)
('Device 'CPU)
T5DataType
('Shape
'[ 'Dim ('Name "*") ('Size 512),
'Dim ('Name "*") ('Size 512)])))
(NamedModel ())))
(NamedModel
(GLinear
(NamedModel
(Tensor
('Gradient 'WithGradient)
('Layout 'Dense)
('Device 'CPU)
T5DataType
('Shape
'[ 'Dim ('Name "*") ('Size 512),
'Dim ('Name "*") ('Size 512)])))
(NamedModel ())))
Dropout))
Dropout
()))
()
(NamedModel
(GTransformerFeedForwardNetwork
(NamedModel
(LayerNorm
'WithoutBias
('Gradient 'WithGradient)
('Device 'CPU)
T5DataType
('Shape '[ 'Dim ('Name "*") ('Size 512)])))
(NamedModel
(GLinear
(NamedModel
(Tensor
('Gradient 'WithGradient)
('Layout 'Dense)
('Device 'CPU)
T5DataType
('Shape
'[ 'Dim ('Name "*") ('Size 2048),
'Dim ('Name "*") ('Size 512)])))
(NamedModel ())))
Relu
Dropout
(NamedModel
(GLinear
(NamedModel
(Tensor
('Gradient 'WithGradient)
('Layout 'Dense)
('Device 'CPU)
T5DataType
('Shape
'[ 'Dim ('Name "*") ('Size 512),
'Dim ('Name "*") ('Size 2048)])))
(NamedModel ())))
Dropout
()))))))
(NamedModel
(LayerNorm
'WithoutBias
('Gradient 'WithGradient)
('Device 'CPU)
T5DataType
('Shape '[ 'Dim ('Name "*") ('Size 512)])))
Dropout))
(NamedModel
(GTransformer
()
(NamedModel
(Embedding
('Gradient 'WithGradient)
('Layout 'Dense)
('Device 'CPU)
T5DataType
('Dim ('Name "*") ('Size 32))
('Dim ('Name "*") ('Size 8))
'Nothing))
()
Dropout
(NamedModel
(GTransformerStack
(Vector
6
(GTransformerBlock
(NamedModel
(GSelfAttention
(NamedModel
(LayerNorm
'WithoutBias
('Gradient 'WithGradient)
('Device 'CPU)
T5DataType
('Shape '[ 'Dim ('Name "*") ('Size 512)])))
(NamedModel
(GMultiHeadAttention
('Dim ('Name "*") ('Size 8))
('Dim ('Name "*") ('Size 64))
('Dim ('Name "*") ('Size 512))
(NamedModel
(GLinear
(NamedModel
(Tensor
('Gradient 'WithGradient)
('Layout 'Dense)
('Device 'CPU)
T5DataType
('Shape
'[ 'Dim ('Name "*") ('Size 512),
'Dim ('Name "*") ('Size 512)])))
(NamedModel ())))
(NamedModel
(GLinear
(NamedModel
(Tensor
('Gradient 'WithGradient)
('Layout 'Dense)
('Device 'CPU)
T5DataType
('Shape
'[ 'Dim ('Name "*") ('Size 512),
'Dim ('Name "*") ('Size 512)])))
(NamedModel ())))
(NamedModel
(GLinear
(NamedModel
(Tensor
('Gradient 'WithGradient)
('Layout 'Dense)
('Device 'CPU)
T5DataType
('Shape
'[ 'Dim ('Name "*") ('Size 512),
'Dim ('Name "*") ('Size 512)])))
(NamedModel ())))
(NamedModel
(GLinear
(NamedModel
(Tensor
('Gradient 'WithGradient)
('Layout 'Dense)
('Device 'CPU)
T5DataType
('Shape
'[ 'Dim ('Name "*") ('Size 512),
'Dim ('Name "*") ('Size 512)])))
(NamedModel ())))
Dropout))
Dropout
()))
(NamedModel
(GCrossAttention
(NamedModel
(LayerNorm
'WithoutBias
('Gradient 'WithGradient)
('Device 'CPU)
T5DataType
('Shape '[ 'Dim ('Name "*") ('Size 512)])))
(NamedModel
(GMultiHeadAttention
('Dim ('Name "*") ('Size 8))
('Dim ('Name "*") ('Size 64))
('Dim ('Name "*") ('Size 512))
(NamedModel
(GLinear
(NamedModel
(Tensor
('Gradient 'WithGradient)
('Layout 'Dense)
('Device 'CPU)
T5DataType
('Shape
'[ 'Dim ('Name "*") ('Size 512),
'Dim ('Name "*") ('Size 512)])))
(NamedModel ())))
(NamedModel
(GLinear
(NamedModel
(Tensor
('Gradient 'WithGradient)
('Layout 'Dense)
('Device 'CPU)
T5DataType
('Shape
'[ 'Dim ('Name "*") ('Size 512),
'Dim ('Name "*") ('Size 512)])))
(NamedModel ())))
(NamedModel
(GLinear
(NamedModel
(Tensor
('Gradient 'WithGradient)
('Layout 'Dense)
('Device 'CPU)
T5DataType
('Shape
'[ 'Dim ('Name "*") ('Size 512),
'Dim ('Name "*") ('Size 512)])))
(NamedModel ())))
(NamedModel
(GLinear
(NamedModel
(Tensor
('Gradient 'WithGradient)
('Layout 'Dense)
('Device 'CPU)
T5DataType
('Shape
'[ 'Dim ('Name "*") ('Size 512),
'Dim ('Name "*") ('Size 512)])))
(NamedModel ())))
Dropout))
Dropout
()))
(NamedModel
(GTransformerFeedForwardNetwork
(NamedModel
(LayerNorm
'WithoutBias
('Gradient 'WithGradient)
('Device 'CPU)
T5DataType
('Shape '[ 'Dim ('Name "*") ('Size 512)])))
(NamedModel
(GLinear
(NamedModel
(Tensor
('Gradient 'WithGradient)
('Layout 'Dense)
('Device 'CPU)
T5DataType
('Shape
'[ 'Dim ('Name "*") ('Size 2048),
'Dim ('Name "*") ('Size 512)])))
(NamedModel ())))
Relu
Dropout
(NamedModel
(GLinear
(NamedModel
(Tensor
('Gradient 'WithGradient)
('Layout 'Dense)
('Device 'CPU)
T5DataType
('Shape
'[ 'Dim ('Name "*") ('Size 512),
'Dim ('Name "*") ('Size 2048)])))
(NamedModel ())))
Dropout
()))))))
(NamedModel
(LayerNorm
'WithoutBias
('Gradient 'WithGradient)
('Device 'CPU)
T5DataType
('Shape '[ 'Dim ('Name "*") ('Size 512)])))
Dropout))
(NamedModel
(Embedding
('Gradient 'WithGradient)
('Layout 'Dense)
('Device 'CPU)
T5DataType
('Dim ('Name "*") ('Size 32128))
('Dim ('Name "*") ('Size 512))
'Nothing))
(NamedModel
(GLMHead
('Dim ('Name "*") ('Size 512))
()
()
()
(NamedModel
(GLinear
(NamedModel
(Tensor
('Gradient 'WithGradient)
('Layout 'Dense)
('Device 'CPU)
T5DataType
('Shape
'[ 'Dim ('Name "*") ('Size 32128), 'Dim ('Name "*") ('Size 512)])))
(NamedModel ())))
(GBias ()))))
(MkRelPos ('Dim ('Name "*") ('Size 32)))
(MkRelPos ('Dim ('Name "*") ('Size 32)))
MkTransformerPaddingMask
(MkTransformerAttentionMask T5DataType)
(MkTransformerCrossAttentionMask T5DataType)
(MkTransformerDecoderAttentionMask T5DataType)
model Tensor
('Gradient 'WithoutGradient)
('Layout 'Dense)
('Device 'CPU)
('DataType 'Int64)
('Shape
'[ 'Dim ('Name "*") ('Size 1), 'Dim ('Name "*") ('Size 19)])
input Generator ('Device 'CPU)
g
forall a. Show a => a -> IO ()
print forall a b. (a -> b) -> a -> b
$ forall a b. Hypothesis 'Finished a b -> b
finalValue forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> [Hypothesis 'Finished Int [Int]]
finished
-- | Perform one step of token-driven parsing.
--
-- The step proceeds as follows:
--
-- 1. The branch is abandoned (via 'guard') when 'notNull' reports that the
--    parser yields nothing for the current token state.
-- 2. A token @i@ is drawn from the token source @is@.
-- 3. The token is prepended to the token list held in the inner 'StateT'.
-- 4. The parser is unrolled by one layer with 'runFreeT'. A finished parser
--    ('Pure') returns its result; a parser awaiting input ('Free') is fed
--    the token and handed to the continuation @cont@.
--
-- Parameters:
--
-- * @is@: action producing the next candidate token.
-- * @parser@: the parser to advance.
-- * @cont@: continuation invoked with the parser obtained after feeding
--   the token.
next ::
  forall t b i a.
  ( i ~ Int,
    Show i,
    MonadTrans t,
    Monad (t (StateT [i] b)),
    Alternative (t (StateT [i] b)),
    Monad b,
    Foldable b
  ) =>
  t (StateT [i] b) i ->
  Parser (StateT [i] b) i a ->
  (Parser (StateT [i] b) i a -> t (StateT [i] b) a) ->
  t (StateT [i] b) a
next is parser cont = do
  -- Give up on this branch early if the parser cannot produce anything.
  lift (notNull parser) >>= guard
  i <- is
  -- The most recent token is kept at the head of the list.
  lift $ modify (i :)
  lift (runFreeT parser) >>= \case
    Pure a -> pure a
    -- The continuation is applied directly; no IO is involved here, so no
    -- 'unsafePerformIO' wrapper is needed.
    Free feed -> cont (feed i)
-- | Check whether unrolling the parser by one layer against the current
-- token state produces at least one outcome in the underlying
-- 'Foldable' monad @b@. The state itself is only read, never modified.
notNull ::
  (Monad b, Foldable b) =>
  Parser (StateT [i] b) i a ->
  StateT [i] b Bool
notNull parser =
  gets $ \consumed ->
    let outcomes = evalStateT (runFreeT parser) consumed
     in not (null outcomes)
-- | Check whether unrolling the parser by one layer against the current
-- token state yields at least one outcome that is still waiting for
-- input, i.e. a 'Free' layer. The state is only read, never modified.
hasFree ::
  (Monad b, Foldable b) =>
  Parser (StateT [i] b) i a ->
  StateT [i] b Bool
hasFree parser = gets (any awaitsInput . evalStateT (runFreeT parser))
  where
    -- A 'Free' layer is a parser that has not finished yet.
    awaitsInput (Free _) = True
    awaitsInput _ = False
-- | Turn a character-level parser into a parser over vocabulary token ids.
--
-- Each incoming token id is looked up in @vocab@. A leading @'▁'@ in the
-- resulting string is replaced by a space, and the characters are then fed
-- to the underlying character parser one after another.
--
-- The translated parser fails with 'empty' when the id is missing from
-- @vocab@ or when it maps to the empty string. If the character parser
-- completes before every character of a token has been fed, the remaining
-- characters of that token are discarded.
transParser :: MonadPlus b => Map.Map Int String -> Parser b Char a -> Parser b Int a
transParser vocab parser =
  FreeT (fmap (transParser vocab) . liftLayer <$> runFreeT parser)
  where
    -- Re-target a single parser layer from characters to token ids.
    liftLayer (Pure result) = Pure result
    liftLayer (Free consume) = Free (consumeToken consume)

    -- Resolve the token id to its text and push that text through the
    -- character parser.
    consumeToken consume tokenId = do
      piece <- maybe empty (pure . unmark) (Map.lookup tokenId vocab)
      (firstChar, restChars) <- maybe empty pure (uncons piece)
      feedAll restChars (consume firstChar)

    -- Replace a leading '▁' marker by a plain space.
    unmark ('▁' : piece) = ' ' : piece
    unmark piece = piece

    -- Feed the remaining characters for as long as the parser asks for more.
    feedAll [] p = p
    feedAll (ch : chars) p =
      lift (runFreeT p) >>= \case
        Pure result -> pure result
        Free consume -> feedAll chars (consume ch)
-- | Minimal 'Parsing' instance for character parsers that run on top of the
-- token state.
--
-- * 'try' is the identity and '<?>' ignores its label.
-- * 'unexpected' ignores its message and fails with 'empty'.
-- * 'eof' and 'notFollowedBy' are not implemented; evaluating either one
--   raises an error that names the missing method.
instance
  (Alternative b, Foldable b, MonadPlus b) =>
  Parsing (FreeT ((->) Char) (StateT [Int] b))
  where
  try = id
  (<?>) = const
  skipMany p = loop
    where
      -- Named 'loop' so that it does not shadow the imported 'scan'.
      loop = (p *> loop) <|> pure ()
  skipSome p = p *> skipMany p
  unexpected = const empty
  eof =
    error
      "Parsing (FreeT ((->) Char) (StateT [Int] b)): eof is not implemented"
  notFollowedBy =
    error
      "Parsing (FreeT ((->) Char) (StateT [Int] b)): notFollowedBy is not implemented"
-- | 'CharParsing' instance in which every method delegates to the matching
-- generic token combinator from "Torch.Data.Parser", with 'Char' as the
-- token type.
instance
  (Alternative b, Foldable b, MonadPlus b) =>
  CharParsing (FreeT ((->) Char) (StateT [Int] b))
  where
  satisfy predicate = Torch.Data.Parser.satisfy predicate
  char expected = isToken expected
  notChar forbidden = isNotToken forbidden
  anyChar = Torch.Data.Parser.token
  string expected = isString expected
-- | 'TokenParsing' instance with an empty body: no method is overridden
-- here, so only the class's default method definitions apply.
instance
  (Alternative b, Foldable b, MonadPlus b) =>
  TokenParsing (FreeT ((->) Char) (StateT [Int] b))
-- | Propose candidates for the next decoder token.
--
-- The action:
--
-- 1. Reads the token list from the inner 'StateT' and reverses it.
-- 2. Builds a decoder input with batch size 1 from those tokens via
--    'mkT5Input'.
-- 3. Runs the model. While the outer state holds no cached tensors
--    ('Nothing'), the model is called with @input@ and the decoder input.
--    Afterwards the cached encoder output and input padding mask are
--    reused through 'SimplifiedEncoderDecoderTransformerGenerationInput'.
--    The encoder output, padding mask and updated generator are written
--    back to the outer state.
-- 4. Applies 'logSoftmax' along dimension 2 and sorts the result in
--    descending order along the same dimension.
-- 5. Takes the first @n@ sorted indices of the last sequence position of
--    the first batch entry and offers each one as an alternative in @b@.
--
-- If the index tensor converts to an empty batch or to an empty sequence,
-- no candidate is offered and the action fails with 'empty' in @b@.
--
-- Parameters:
--
-- * @n@: maximum number of candidate tokens to offer.
-- * @model@: the encoder-decoder model to run.
-- * @input@: the encoder input, used only while nothing is cached.
getIs ::
  forall model input generatorDevice b decoderInput encoderOutput decoderOutput inputPaddingMask s.
  ( Alternative b,
    MonadThrow b,
    s ~ (Maybe (encoderOutput, inputPaddingMask), Generator generatorDevice),
    decoderInput
      ~ Tensor
          ('Gradient 'WithoutGradient)
          ('Layout 'Dense)
          ('Device 'CPU)
          ('DataType 'Int64)
          ( 'Shape
              '[ 'Dim ('Name "*") ('Size 1), 'Dim ('Name "*") 'UncheckedSize]
          ),
    decoderOutput
      ~ Tensor
          ('Gradient 'WithGradient)
          ('Layout 'Dense)
          ('Device 'CPU)
          T5DataType
          'UncheckedShape,
    HasForward
      model
      (SimplifiedEncoderDecoderTransformerInput input decoderInput)
      generatorDevice
      (SimplifiedEncoderDecoderTransformerOutput decoderOutput encoderOutput decoderInput inputPaddingMask)
      generatorDevice,
    HasForward
      model
      (SimplifiedEncoderDecoderTransformerGenerationInput decoderInput encoderOutput inputPaddingMask)
      generatorDevice
      (SimplifiedEncoderDecoderTransformerOutput decoderOutput encoderOutput decoderInput inputPaddingMask)
      generatorDevice
  ) =>
  Int ->
  model ->
  input ->
  StateT s (StateT [Int] b) Int
getIs n model input = do
  -- The inner state's token list is reversed before it is handed to the model.
  tokens <- reverse <$> lift get
  decoderInput :: decoderInput <-
    mkT5Input
      (SName @"*" :&: SSize @1)
      (SName @"*" :&: SUncheckedSize (fromIntegral $ length tokens))
      (SDevice SCPU)
      [tokens]
  decoderOutput <- do
    (mTensors, g) <- get
    (SimplifiedEncoderDecoderTransformerOutput decoderOutput encoderOutput _ inputPaddingMask, g') <-
      case mTensors of
        -- Nothing cached yet: run the model on the encoder input as well.
        Nothing ->
          forward model (SimplifiedEncoderDecoderTransformerInput input decoderInput) g
        -- Reuse the cached encoder output and padding mask.
        Just (encoderOutput, inputPaddingMask) ->
          forward
            model
            (SimplifiedEncoderDecoderTransformerGenerationInput decoderInput encoderOutput inputPaddingMask)
            g
    put (Just (encoderOutput, inputPaddingMask), g')
    pure decoderOutput
  probs <- logSoftmax (SSelectDim $ SByIndex @2) decoderOutput
  case sort @('SelectDim ('ByIndex 2)) Descending probs of
    Sorted _ (UnsafeTensor indices) ->
      let indices' =
            -- Total replacement for @take n . last . head@: an empty batch
            -- or an empty sequence yields no candidates instead of a crash.
            case Torch.Tensor.asValue @[[[Int]]] (Torch.Tensor.Unsafe indices) of
              (positions@(_ : _) : _) -> take n (last positions)
              _ -> []
       in lift . lift . asum $ pure <$> indices'
-- | Run a token-level parser against the output of a model.
--
-- The parser is driven by 'recurse' with 'next' as the stepping function;
-- 'next' obtains candidate tokens from @getIs n model input@.
--
-- Two state layers are peeled off, innermost first:
--
-- * the generation state, initialised to @(Nothing, g)@ and discarded via
--   'evalStateT' once parsing is finished;
-- * the @[Int]@ state, initialised to @[]@ and returned next to the result.
--
-- Arguments:
--
-- * @n@: forwarded to @getIs@, which keeps only the first @n@ entries of the
--   descending-sorted candidate indices at every step.
-- * @model@, @input@: forwarded unchanged to @getIs@.
-- * @g@: initial random generator stored in the generation state.
--
-- The constraint is left as a wildcard on purpose; it is inferred from
-- @getIs@ and @next@ (see @PartialTypeSignatures@ at the top of the module).
runParser ::
  forall model input generatorDevice b a.
  _ =>
  Int ->
  model ->
  input ->
  Generator generatorDevice ->
  Parser (StateT [Int] b) Int a ->
  b (a, [Int])
runParser n model input g =
  flip runStateT []
    . flip evalStateT (Nothing, g)
    . recurse (next (getIs n model input))
-- | Parser for free text terminated by the end-of-sequence marker.
--
-- Consumes arbitrary characters one at a time until the literal string
-- @"</s>"@ can be matched, and returns the characters read before it.
t5Text :: MonadPlus b => Parser b Char String
t5Text = manyTill Torch.Data.Parser.token (isString "</s>")
-- | Parser for one specific test sentence.
--
-- The sentence is described as an alternation of bounded free-text gaps and
-- fixed anchor strings, glued together with 'combine':
--
-- > <gap 25> "belegt" <gap 5> "dass es" <gap 25> "haben" "</s>"
--
-- NOTE(review): 'combine' is assumed to concatenate the pieces through their
-- 'Semigroup' instance; confirm against "Torch.Data.Parser".
t5Test :: MonadPlus b => Parser b Char String
t5Test =
  notEnd 25
    `combine` isString "belegt"
    `combine` notEnd 5
    `combine` isString "dass es"
    `combine` notEnd 25
    `combine` isString "haben"
    `combine` isString "</s>"
  where
    -- A free-text gap: scans single characters while the step function
    -- keeps accepting them. No local type signature is given because the
    -- top-level signature has no explicit @forall@, so @b@ is not in scope
    -- here.
    notEnd n = scan f "" Torch.Data.Parser.token
      where
        -- Extend the accumulated text by one character and reject the
        -- extension when it would contain the end-of-sequence marker or
        -- grow beyond @n@ characters.
        f :: String -> Char -> Maybe String
        f s a = case s ++ [a] of
          s'
            | "</s>" `isInfixOf` s' -> Nothing
            | length s' > n -> Nothing
            | otherwise -> Just s'
-- | Parser for a double-quoted Spider SQL query followed by the
-- end-of-sequence marker.
--
-- The query itself is parsed by @spiderSQL@ and must sit between two
-- double-quote characters, each of which may be surrounded by whitespace.
-- The literal string @"</s>"@ has to follow the closing quote; only the
-- parsed query is returned.
t5Sql ::
  (TokenParsing (FreeT ((->) Char) b), MonadPlus b) =>
  Parser b Char SpiderSQL
t5Sql =
  let -- A double quote with optional whitespace on either side.
      q = spaces *> char '\"' <* spaces
   in between q q spiderSQL <* isString "</s>"