{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}

-- | Generic pooler block: a dense layer followed by an activation.
module Torch.GraduallyTyped.NN.Transformer.GPooler where

import Control.Monad.Indexed (IxPointed (..), (>>>=))
import Control.Monad.Indexed.State (IxStateT (..))
import Data.Kind (Type)
import GHC.TypeLits (Nat, Symbol)
import Torch.GraduallyTyped.DType (DType, DataType, SDataType)
import Torch.GraduallyTyped.Device (Device, DeviceType, SDevice)
import Torch.GraduallyTyped.NN.Activation (Tanh)
import Torch.GraduallyTyped.NN.Class (HasForward (..), ModelSpec, NamedModel)
import Torch.GraduallyTyped.NN.Linear (GLinearF)
import Torch.GraduallyTyped.NN.Transformer.Type (STransformerStyle, TransformerStyle (..))
import Torch.GraduallyTyped.NN.Type (HasBias (..))
import Torch.GraduallyTyped.RequiresGradient (Gradient, RequiresGradient, SGradient)
import Torch.GraduallyTyped.Shape.Type (Dim, Name, SDim, Size)
import Torch.GraduallyTyped.Tensor.Type (Tensor)

-- | Generic pooler, parameterised over the types of its two sub-models.
--
-- - @dense@ is the type of the dense layer, applied first.
-- - @activation@ is the type of the activation, applied to the dense output.
data
  GPooler
    (dense :: Type)
    (activation :: Type)
  where
  GPooler ::
    forall dense activation.
    { -- | Dense layer of the pooler.
      poolerDense :: dense,
      -- | Activation applied after the dense layer.
      poolerActivation :: activation
    } ->
    GPooler dense activation

-- The specification of a pooler is a pooler whose components are the
-- specifications of the respective sub-models.
type instance
  ModelSpec (GPooler dense activation) =
    GPooler (ModelSpec dense) (ModelSpec activation)

-- | Specification of a pooler for the given transformer style.
--
-- NOTE: this function is currently an unimplemented stub. All arguments are
-- ignored and evaluating the result throws an 'ErrorCall'.
poolerSpec ::
  forall style gradient device dataType inputEmbedDim.
  STransformerStyle style ->
  SGradient gradient ->
  SDevice device ->
  SDataType dataType ->
  SDim inputEmbedDim ->
  GPooler
    (PoolerDenseF style gradient device dataType inputEmbedDim)
    (PoolerActivationF style)
poolerSpec _style _gradient _device _dataType _inputEmbedDim =
  -- Kept as a bottom value, exactly like the previous bare 'undefined', but
  -- with a message that tells the caller which function is missing.
  error "Torch.GraduallyTyped.NN.Transformer.GPooler.poolerSpec: not implemented"

-- | Type of the dense layer of the pooler for a given transformer style.
--
-- Only the @'RoBERTa@ equation is defined; for any other style the family
-- does not reduce. For @'RoBERTa@ it is a named linear model with bias whose
-- two dimension arguments are both @inputEmbedDim@.
type family
  PoolerDenseF
    (style :: TransformerStyle)
    (gradient :: Gradient RequiresGradient)
    (device :: Device (DeviceType Nat))
    (dataType :: DataType DType)
    (inputEmbedDim :: Dim (Name Symbol) (Size Nat)) ::
    Type
  where
  PoolerDenseF 'RoBERTa gradient device dataType inputEmbedDim =
    NamedModel (GLinearF 'WithBias gradient device dataType inputEmbedDim inputEmbedDim)

-- | Type of the activation of the pooler for a given transformer style.
--
-- Only the @'RoBERTa@ equation is defined, which selects 'Tanh'.
type family
  PoolerActivationF
    (style :: TransformerStyle) ::
    Type
  where
  PoolerActivationF 'RoBERTa = Tanh

-- | Forward pass of the pooler.
--
-- The input tensor is passed through 'poolerDense' and the result is then
-- passed through 'poolerActivation'. The generator is threaded through both
-- steps, so its device type may change from @generatorDevice@ via
-- @generatorDevice0@ to @generatorOutputDevice@.
instance
  ( HasForward
      dense
      (Tensor gradient layout device dataType shape)
      generatorDevice
      tensor0
      generatorDevice0,
    HasForward
      activation
      tensor0
      generatorDevice0
      output
      generatorOutputDevice
  ) =>
  HasForward
    (GPooler dense activation)
    (Tensor gradient layout device dataType shape)
    generatorDevice
    output
    generatorOutputDevice
  where
  forward GPooler {..} input =
    -- Each 'forward' call has shape @input -> generator -> m (output, generator)@,
    -- which is exactly what 'IxStateT' wraps; the indexed bind chains the
    -- two steps while letting the generator's type index change in between.
    runIxStateT $
      ireturn input
        >>>= IxStateT . forward poolerDense
        >>>= IxStateT . forward poolerActivation