{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PartialTypeSignatures #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE OverloadedRecordDot #-}
{-# LANGUAGE DuplicateRecordFields #-}
{-# OPTIONS_GHC -Wno-partial-type-signatures #-}

module Torch.Typed.NN.Linear where

import GHC.Generics
import GHC.TypeLits
import qualified Torch.DType as D
import qualified Torch.Device as D
import Torch.NN (HasForward (..), Randomizable (..))
import Torch.Typed.Factories
import Torch.Typed.Functional
import Torch.Typed.Parameter
import Torch.Typed.Tensor

-- | Specification used to construct a 'Linear' layer.
--
-- The constructor carries no runtime data. The number of input and output
-- features, the element 'D.DType' and the device are all encoded in the type
-- parameters. Passing a 'LinearSpec' to 'sample' (via the 'Randomizable'
-- instance in this module) yields a freshly initialised 'Linear'.
data
  LinearSpec
    (inputFeatures :: Nat)
    (outputFeatures :: Nat)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  = LinearSpec
  deriving (Show, Eq)

-- | A fully connected layer with a learnable weight matrix and bias vector.
--
-- 'Show' and 'Generic' are derived with the stock strategy. 'Parameterized'
-- is derived through @DeriveAnyClass@ on top of 'Generic'. Its
-- 'flattenParameters' / 'replaceParameters' presumably traverse the two
-- 'Parameter' fields below; confirm against "Torch.Typed.Parameter".
data
  Linear
    (inputFeatures :: Nat)
    (outputFeatures :: Nat)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  where
  Linear ::
    forall inputFeatures outputFeatures dtype device.
    { -- | Weight matrix, stored with shape @[outputFeatures, inputFeatures]@.
      weight :: Parameter device dtype '[outputFeatures, inputFeatures],
      -- | Bias vector, one entry per output feature.
      bias :: Parameter device dtype '[outputFeatures]
    } ->
    Linear inputFeatures outputFeatures dtype device
  deriving (Show, Generic, Parameterized)

-- | Apply a 'Linear' layer to an input tensor.
--
-- Unwraps the 'weight' and 'bias' parameters with 'toDependent' and passes
-- them to 'linear''. From the type of 'linear'', the result shape is
-- @Broadcast shape'' shape''@ where
-- @shape'' ~ MatMul shape '[inputFeatures, outputFeatures]@. The trailing
-- @inputFeatures@ dimension of the input is therefore contracted against the
-- weight, which is used transposed relative to its stored
-- @[outputFeatures, inputFeatures]@ layout.
--
-- NOTE(review): presumably this computes @input `matmul` transpose weight + bias@
-- like PyTorch's @F.linear@; confirm against 'linear''.
--
-- The constraints GHC infers here are /very/ involved (type-level 'MatMul'
-- and 'Broadcast' shape arithmetic). A partial type signature lets GHC fill
-- them in, and the module-level @-Wno-partial-type-signatures@ silences the
-- resulting warnings.
linearForward ::
  _ =>
  Linear _ _ _ _ ->
  Tensor _ _ _ ->
  Tensor _ _ _
linearForward Linear {..} input =
  linear' (toDependent weight) (toDependent bias) input

-- | Forward pass for 'Linear'.
--
-- The equality constraints restate the output-shape computation performed by
-- 'linearForward' (and 'linear''). The layer has no stochastic behaviour, so
-- 'forwardStoch' simply lifts the deterministic 'forward' into 'IO'.
instance
  ( shape'' ~ MatMul shape '[inputFeatures, outputFeatures],
    shape' ~ Broadcast shape'' shape''
  ) =>
  HasForward (Linear inputFeatures outputFeatures dtype device) (Tensor device dtype shape) (Tensor device dtype shape')
  where
  forward = linearForward

  -- Pointed form of the original @(pure .) . forward@; same behaviour.
  forwardStoch layer = pure . forward layer

-- | Initialise a 'Linear' layer from its 'LinearSpec'.
--
-- Weight and bias are each drawn with 'randn' (weight first, then bias) and
-- wrapped as trainable parameters with 'makeIndependent'. All sizes come
-- from the type parameters, hence the 'KnownNat' / 'KnownDType' /
-- 'KnownDevice' constraints.
--
-- NOTE(review): the 'randn' draws are used without any fan-in scaling. If
-- 'randn' is standard-normal, this differs from PyTorch's default
-- @nn.Linear@ initialisation (Kaiming-uniform); confirm this is intended.
instance
  ( KnownNat inputFeatures,
    KnownNat outputFeatures,
    KnownDType dtype,
    KnownDevice device,
    RandDTypeIsValid device dtype
  ) =>
  Randomizable
    (LinearSpec inputFeatures outputFeatures dtype device)
    (Linear inputFeatures outputFeatures dtype device)
  where
  sample LinearSpec =
    Linear
      <$> (makeIndependent =<< randn) -- weight :: [outputFeatures, inputFeatures]
      <*> (makeIndependent =<< randn) -- bias   :: [outputFeatures]