Safe Haskell | Safe-Inferred |
---|---|
Language | Haskell2010 |
Documentation
`sedtOutputToInput :: Monad m => Lens (SimplifiedEncoderDecoderTransformerOutput logits encoderOutput decoderInput inputPaddingMask) (m (SimplifiedEncoderDecoderTransformerGenerationInput decoderInput' encoderOutput inputPaddingMask)) (logits, decoderInput) (m decoderInput')`
`prepNext :: (logits ~ Tensor logitsGradient logitsLayout logitsDevice logitsDataType logitsShape, ntShape' ~ UnsqueezeF ('SelectDim ('ByIndex 1)) ntShape, Catch ntShape', tensors ~ '[decoderInput, Tensor ntGradient ntLayout ntDevice ntDataType ntShape'], decoderInput' ~ CatHListF ('SelectDim ('ByIndex 1)) tensors, Castable decoderInput' (ForeignPtr Tensor), Castable (HList tensors) (ForeignPtr TensorList), MonadThrow m) => Lens (logits, decoderInput) (m decoderInput') logits (m (Tensor ntGradient ntLayout ntDevice ntDataType ntShape))`
`greedyNextTokens :: (nextTokenLogitsShape ~ IndexDims ('Indices '['SliceAll, 'SliceAt ('NegativeIndex 1), 'SliceAll]) logitsShape, nextTokensShape ~ ArgmaxF ('SelectDim ('ByIndex 1)) nextTokenLogitsShape, Catch nextTokensShape, nextTokensShape' ~ SqueezeDimF ('SelectDim ('ByIndex 1)) nextTokensShape, ntShape ~ 'Shape '[ntDim], Catch (nextTokensShape' <+> ntShape), SGetShape nextTokensShape', SGetDim ntDim, Catch ntDim, Catch nextTokensShape', MonadThrow m, MonadState (Tensor ('Gradient 'WithoutGradient) logitsLayout logitsDevice ('DataType 'Int64) ntShape) m, SGetDevice logitsDevice, SGetLayout logitsLayout) => Int -> Int -> Tensor logitsGradient logitsLayout logitsDevice logitsDataType logitsShape -> m (Tensor ('Gradient 'WithoutGradient) logitsLayout logitsDevice ('DataType 'Int64) ntShape)`
`allSequencesFinished :: (SGetLayout usLayout, SGetDevice usDevice, MonadThrow m, Catch (usDataType <+> 'DataType 'Int64), Catch (BroadcastShapesF usShape ('Shape '[])), MaxAllCheckF usShape) => Tensor usGradient usLayout usDevice usDataType usShape -> m Bool`
`applyUnfinishedSequences :: (MonadThrow m, kntDataType ~ (usDataType <+> ntDataType), kntShape ~ BroadcastShapesF usShape ntShape, Catch kntShape, ntGradient' ~ (usGradient <|> ntGradient), ntLayout' ~ ((usLayout <+> ntLayout) <+> usLayout), ntDevice' ~ ((usDevice <+> ntDevice) <+> usDevice), ntDataType' ~ ((usDataType <+> ntDataType) <+> usDataType), ntShape' ~ BroadcastShapesF kntShape usShape, Catch ntShape') => Int -> Tensor usGradient usLayout usDevice usDataType usShape -> Tensor ntGradient ntLayout ntDevice ntDataType ntShape -> m (Tensor ntGradient' ntLayout' ntDevice' ntDataType' ntShape')`
`updateUnfinishedSequences :: (Catch (ntDataType <+> 'DataType 'Int64), Catch (BroadcastShapesF ntShape ('Shape '[])), Catch usShape', SGetDataType usDataType, SGetDevice ntDevice, SGetLayout ntLayout, MonadThrow m, BroadcastShapesF usShape (BroadcastShapesF ntShape ('Shape '[])) ~ usShape', (usGradient <|> 'Gradient 'WithoutGradient) ~ usGradient', (usLayout <+> ntLayout) ~ usLayout', (usDevice <+> ntDevice) ~ usDevice') => Int -> Tensor ntGradient ntLayout ntDevice ntDataType ntShape -> Tensor usGradient usLayout usDevice usDataType usShape -> m (Tensor usGradient' usLayout' usDevice' usDataType usShape')`