{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PartialTypeSignatures #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE OverloadedRecordDot #-}
{-# LANGUAGE DuplicateRecordFields #-}
{-# OPTIONS_GHC -Wno-partial-type-signatures #-}
module Torch.Typed.NN.Linear where
import GHC.Generics
import GHC.TypeLits
import qualified Torch.DType as D
import qualified Torch.Device as D
import Torch.NN (HasForward (..), Randomizable (..))
import Torch.Typed.Factories
import Torch.Typed.Functional
import Torch.Typed.Parameter
import Torch.Typed.Tensor
-- | Specification used to sample a fresh 'Linear' layer.
--
-- The spec carries no runtime data: the input and output feature counts,
-- the element dtype and the device are all fixed by its type parameters.
data
  LinearSpec
    (inputFeatures :: Nat)
    (outputFeatures :: Nat)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  = LinearSpec
  deriving (Show, Eq)
-- | A fully connected layer whose feature counts, dtype and device are
-- tracked at the type level.
--
-- * @weight@ is a parameter of shape @'[outputFeatures, inputFeatures]@.
-- * @bias@ is a parameter of shape @'[outputFeatures]@.
--
-- 'Parameterized' is derived with @DeriveAnyClass@, so it relies on the
-- class's default methods (presumably 'Generic'-based, which is why
-- 'Generic' is derived as well).
data
  Linear
    (inputFeatures :: Nat)
    (outputFeatures :: Nat)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  where
  Linear ::
    forall inputFeatures outputFeatures dtype device.
    { weight :: Parameter device dtype '[outputFeatures, inputFeatures],
      bias :: Parameter device dtype '[outputFeatures]
    } ->
    Linear inputFeatures outputFeatures dtype device
  deriving (Show, Generic, Parameterized)
-- | Run a 'Linear' layer on an input tensor.
--
-- Both parameters are unwrapped with 'toDependent' and passed, together
-- with the input, to 'linear''. The result shape is computed at the type
-- level from the input shape via 'MatMul' and 'Broadcast'. These are the
-- same constraints the 'HasForward' instance for 'Linear' requires; they
-- are spelled out here instead of being left to a partial signature, so
-- callers see them directly.
--
-- NOTE(review): presumably computes @input * weight^T + bias@ as torch's
-- @linear@ does; confirm against the definition of 'linear''.
linearForward ::
  ( shape'' ~ MatMul shape '[inputFeatures, outputFeatures],
    shape' ~ Broadcast shape'' shape''
  ) =>
  Linear inputFeatures outputFeatures dtype device ->
  Tensor device dtype shape ->
  Tensor device dtype shape'
linearForward Linear {..} input =
  linear' (toDependent weight) (toDependent bias) input
-- | Forward pass for 'Linear'.
--
-- The output shape is determined by 'MatMul' of the input shape with
-- @'[inputFeatures, outputFeatures]@, followed by 'Broadcast'.
-- 'forwardStoch' performs no sampling: it is the deterministic 'forward'
-- lifted into 'IO' with 'pure'.
instance
  ( shape'' ~ MatMul shape '[inputFeatures, outputFeatures],
    shape' ~ Broadcast shape'' shape''
  ) =>
  HasForward
    (Linear inputFeatures outputFeatures dtype device)
    (Tensor device dtype shape)
    (Tensor device dtype shape')
  where
  forward = linearForward
  forwardStoch model input = pure (forward model input)
-- | Sample a fresh 'Linear' layer from its (data-free) 'LinearSpec'.
--
-- The weight and the bias are each drawn with 'randn' at the shape the
-- 'Linear' constructor demands. Each is then wrapped as a 'Parameter' with
-- 'makeIndependent'.
--
-- NOTE(review): both tensors come straight from 'randn' with no fan-in
-- scaling, so this is not the initialisation scheme many frameworks use for
-- linear layers. Confirm this is intended before relying on it for training.
instance
  ( KnownNat inputFeatures,
    KnownNat outputFeatures,
    KnownDType dtype,
    KnownDevice device,
    RandDTypeIsValid device dtype
  ) =>
  Randomizable
    (LinearSpec inputFeatures outputFeatures dtype device)
    (Linear inputFeatures outputFeatures dtype device)
  where
  sample LinearSpec =
    Linear
      <$> (makeIndependent =<< randn)
      <*> (makeIndependent =<< randn)