{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}

module Torch.GraduallyTyped.NN.Transformer.GPooler where

import Control.Monad.Indexed (IxPointed (..), (>>>=))
import Control.Monad.Indexed.State (IxStateT (..))
import Data.Kind (Type)
import GHC.TypeLits (Nat, Symbol)
import Torch.GraduallyTyped.DType (DType, DataType, SDataType)
import Torch.GraduallyTyped.Device (Device, DeviceType, SDevice)
import Torch.GraduallyTyped.NN.Activation (Tanh)
import Torch.GraduallyTyped.NN.Class (HasForward (..), ModelSpec, NamedModel)
import Torch.GraduallyTyped.NN.Linear (GLinearF)
import Torch.GraduallyTyped.NN.Transformer.Type (STransformerStyle, TransformerStyle (..))
import Torch.GraduallyTyped.NN.Type (HasBias (..))
import Torch.GraduallyTyped.RequiresGradient (Gradient, RequiresGradient, SGradient)
import Torch.GraduallyTyped.Shape.Type (Dim, Name, SDim, Size)
import Torch.GraduallyTyped.Tensor.Type (Tensor)

-- | Generic pooler layer.
--
-- A 'GPooler' pairs a dense sub-model with an activation sub-model. Its
-- 'HasForward' instance in this module feeds the input through 'poolerDense'
-- and then feeds that result through 'poolerActivation'.
data
  GPooler
    (dense :: Type)
    (activation :: Type)
  where
  GPooler ::
    forall dense activation.
    { -- | Dense sub-model, applied first.
      poolerDense :: dense,
      -- | Activation sub-model, applied to the output of 'poolerDense'.
      poolerActivation :: activation
    } ->
    GPooler dense activation

-- | The specification of a 'GPooler' is a 'GPooler' whose two components are
-- the specifications of the dense sub-model and of the activation.
type instance ModelSpec (GPooler d a) = GPooler (ModelSpec d) (ModelSpec a)

-- | Specification of a 'GPooler' for the given transformer style, gradient
-- setting, device, data type and input embedding dimension.
--
-- The dense component's type is 'PoolerDenseF' and the activation's type is
-- 'PoolerActivationF', both selected by @style@.
--
-- NOTE(review): this is currently a stub. It ignores all of its singleton
-- arguments and evaluates to bottom, which was 'undefined' before; it now
-- fails with a descriptive message. A real implementation still has to be
-- written.
poolerSpec ::
  forall style gradient device dataType inputEmbedDim.
  STransformerStyle style ->
  SGradient gradient ->
  SDevice device ->
  SDataType dataType ->
  SDim inputEmbedDim ->
  GPooler
    (PoolerDenseF style gradient device dataType inputEmbedDim)
    (PoolerActivationF style)
poolerSpec _style _gradient _device _dataType _inputEmbedDim =
  error "Torch.GraduallyTyped.NN.Transformer.GPooler.poolerSpec: not implemented"

-- | Type of the pooler's dense sub-model for a given transformer style.
--
-- For 'RoBERTa' it is a named, biased linear layer whose input and output
-- dimensions are both the input embedding dimension. No other style has an
-- equation, so for any other style this closed family does not reduce.
type family
  PoolerDenseF
    (style :: TransformerStyle)
    (gradient :: Gradient RequiresGradient)
    (device :: Device (DeviceType Nat))
    (dataType :: DataType DType)
    (inputEmbedDim :: Dim (Name Symbol) (Size Nat)) ::
    Type
  where
  PoolerDenseF 'RoBERTa grad dev dtype embedDim =
    NamedModel
      (GLinearF 'WithBias grad dev dtype embedDim embedDim)

-- | Type of the pooler's activation for a given transformer style.
--
-- Only 'RoBERTa' has an equation, and it maps to 'Tanh'.
type family PoolerActivationF (style :: TransformerStyle) :: Type where
  PoolerActivationF 'RoBERTa = Tanh

-- | Forward pass of a 'GPooler'.
--
-- The input tensor is run through 'poolerDense'. The dense layer's output,
-- of type @tensor0@, together with the generator it returns, is then run
-- through 'poolerActivation'. The generator is threaded from one step to the
-- next with an indexed state transformer ('IxStateT'), so each step may
-- change the generator's device index.
instance
  ( HasForward
      dense
      (Tensor gradient layout device dataType shape)
      generatorDevice
      tensor0
      generatorDevice0,
    HasForward
      activation
      tensor0
      generatorDevice0
      output
      generatorOutputDevice
  ) =>
  HasForward
    (GPooler dense activation)
    (Tensor gradient layout device dataType shape)
    generatorDevice
    output
    generatorOutputDevice
  where
  forward GPooler {..} input =
    runIxStateT $
      ireturn input
        >>>= IxStateT . forward poolerDense
        >>>= IxStateT . forward poolerActivation