{-# LANGUAGE DataKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module Torch.GraduallyTyped.NN.Functional.Linear where
import GHC.TypeLits (Nat, Symbol, TypeError)
import System.IO.Unsafe (unsafePerformIO)
import Torch.GraduallyTyped.DType (DType (..), DataType (..))
import Torch.GraduallyTyped.Device (Device (..), DeviceType (..))
import Torch.GraduallyTyped.Layout (Layout (..), LayoutType (..))
import Torch.GraduallyTyped.Prelude (Reverse, Seq)
import Torch.GraduallyTyped.RequiresGradient (Gradient (..), RequiresGradient (..))
import Torch.GraduallyTyped.Shape.Type (Dim (..), Name (..), Shape (..), Size (..))
import Torch.GraduallyTyped.Tensor.Type (Tensor)
import Torch.GraduallyTyped.Unify (type (<+>), type (<|>))
import Torch.Internal.Cast (cast2, cast3)
import qualified Torch.Internal.Managed.Native as ATen
import Type.Errors.Pretty (type (%), type (<>))
-- | Shape of the result of 'linearWithBias', computed from the shapes of
-- the weight, the bias and the input (in that order).
--
-- Checked shapes with the wrong number of dimensions are rejected with a
-- custom 'TypeError':
--
-- * the weight needs exactly two dimensions;
-- * the bias needs exactly one;
-- * the input needs at least one.
--
-- When all three shapes are checked and well-formed, the input dimensions
-- are reversed and handed to 'LinearWithBiasDimsF' (which rewrites the
-- input's last dimension), then reversed back. Otherwise an
-- 'UncheckedShape' argument makes the result 'UncheckedShape'.
--
-- This is a closed type family, so equations are matched top to bottom:
-- the error equations must stay ahead of the general one.
type family
  LinearWithBiasF
    (weightShape :: Shape [Dim (Name Symbol) (Size Nat)])
    (biasShape :: Shape [Dim (Name Symbol) (Size Nat)])
    (inputShape :: Shape [Dim (Name Symbol) (Size Nat)]) ::
    Shape [Dim (Name Symbol) (Size Nat)]
  where
  -- weight with fewer than two dimensions
  LinearWithBiasF ('Shape '[]) _ _ =
    TypeError (LinearWeightDimsErrorMessage '[])
  LinearWithBiasF ('Shape '[w]) _ _ =
    TypeError (LinearWeightDimsErrorMessage '[w])
  -- weight with more than two dimensions
  LinearWithBiasF ('Shape (w ': w' ': w'' ': ws)) _ _ =
    TypeError (LinearWeightDimsErrorMessage (w ': w' ': w'' ': ws))
  -- bias without exactly one dimension
  LinearWithBiasF _ ('Shape '[]) _ =
    TypeError (LinearBiasDimsErrorMessage '[])
  LinearWithBiasF _ ('Shape (b ': b' ': bs)) _ =
    TypeError (LinearBiasDimsErrorMessage (b ': b' ': bs))
  -- input without any dimension
  LinearWithBiasF _ _ ('Shape '[]) =
    TypeError LinearInputDimsErrorMessage
  -- every shape checked and well-formed
  LinearWithBiasF ('Shape ws) ('Shape bs) ('Shape xs) =
    'Shape (Reverse (LinearWithBiasDimsF ws bs (Reverse xs)))
  -- at least one unchecked shape
  LinearWithBiasF 'UncheckedShape _ _ = 'UncheckedShape
  LinearWithBiasF _ 'UncheckedShape _ = 'UncheckedShape
  LinearWithBiasF _ _ 'UncheckedShape = 'UncheckedShape
-- | Dimension-level helper for 'LinearWithBiasF'.
--
-- Arguments:
--
-- * the weight dimensions @[outputDim, inputDim]@;
-- * the bias dimensions @[outputDim']@;
-- * the input dimensions in reverse order, so that the input's last
--   dimension comes first.
--
-- That last input dimension is unified with the weight's @inputDim@. It is
-- then replaced by the unification of @outputDim@ and @outputDim'@. The
-- remaining input dimensions pass through unchanged.
--
-- NOTE(review): 'Seq' presumably forces the input-dimension unification so
-- that a mismatch is still reported although its result is not returned.
-- Confirm against "Torch.GraduallyTyped.Prelude".
--
-- There is no other equation. 'LinearWithBiasF' rejects every other
-- dimension count before applying this family.
type family
  LinearWithBiasDimsF
    (weightDims :: [Dim (Name Symbol) (Size Nat)])
    (biasDims :: [Dim (Name Symbol) (Size Nat)])
    (reversedInputDims :: [Dim (Name Symbol) (Size Nat)]) ::
    [Dim (Name Symbol) (Size Nat)]
  where
  LinearWithBiasDimsF '[outDim, inDim] '[outDim'] (inDim' ': rest) =
    Seq (inDim <+> inDim') ((outDim <+> outDim') ': rest)
-- | Type error reported by 'LinearWithBiasF' and 'LinearWithoutBiasF' when
-- the input shape is checked and has no dimensions at all.
type LinearInputDimsErrorMessage =
  "Cannot apply the linear transformation."
    % "The input tensor does not have the minimum required number of dimensions."
    % "At least one dimension is needed, but none were found."
-- | Type error reported by 'LinearWithBiasF' when the bias shape is checked
-- but does not have exactly one dimension. The offending dimension list is
-- included in the message.
type LinearBiasDimsErrorMessage (dims :: [Dim (Name Symbol) (Size Nat)]) =
  "Cannot apply the linear transformation."
    % "The bias tensor must have exactly one dimension,"
    % "but the following dimensions were found:"
    % ""
    % " " <> dims <> "."
    % ""
-- | Type error reported by 'LinearWithBiasF' and 'LinearWithoutBiasF' when
-- the weight shape is checked but does not have exactly two dimensions. The
-- offending dimension list is included in the message.
type LinearWeightDimsErrorMessage (dims :: [Dim (Name Symbol) (Size Nat)]) =
  "Cannot apply the linear transformation."
    % "The weight tensor must have exactly two dimensions,"
    % "but the following dimensions were found:"
    % ""
    % " " <> dims <> "."
    % ""
-- | Linear transformation with a bias term.
--
-- Arguments are the weight, the bias and the input, in that order. The work
-- is delegated to 'ATen.linear_ttt', which receives them as input, weight,
-- bias.
--
-- At the type level ('LinearWithBiasF'):
--
-- * the weight must have two dimensions @[outputDim, inputDim]@;
-- * the bias must have one dimension @[outputDim']@;
-- * the input's last dimension must unify with @inputDim@.
--
-- In the result, that last input dimension is replaced by the unification
-- of @outputDim@ and @outputDim'@.
--
-- The gradient, layout, device and data type of the result combine those of
-- all three arguments. This includes the weight: its gradient used to be
-- left out of the result's gradient.
linearWithBias ::
  forall gradient layout device dataType shape gradient' layout' device' dataType' shape' gradient'' layout'' device'' dataType'' shape''.
  -- | weight
  Tensor gradient layout device dataType shape ->
  -- | bias
  Tensor gradient' layout' device' dataType' shape' ->
  -- | input
  Tensor gradient'' layout'' device'' dataType'' shape'' ->
  -- | output
  Tensor
    (gradient <|> (gradient' <|> gradient''))
    (layout <+> (layout' <+> layout''))
    (device <+> (device' <+> device''))
    (dataType <+> (dataType' <+> dataType''))
    (LinearWithBiasF shape shape' shape'')
linearWithBias weight bias input =
  -- ATen expects (input, weight, bias), not this function's argument order.
  -- NOTE(review): 'unsafePerformIO' assumes the ATen call has no observable
  -- side effects beyond allocating the result tensor; confirm.
  unsafePerformIO $ cast3 ATen.linear_ttt input weight bias
-- | Shape of the result of 'linearWithoutBias', computed from the shapes of
-- the weight and the input (in that order).
--
-- Checked shapes with the wrong number of dimensions are rejected with a
-- custom 'TypeError':
--
-- * the weight needs exactly two dimensions;
-- * the input needs at least one.
--
-- When both shapes are checked and well-formed, the input dimensions are
-- reversed and handed to 'LinearWithoutBiasDimsF' (which rewrites the
-- input's last dimension), then reversed back. Otherwise an
-- 'UncheckedShape' argument makes the result 'UncheckedShape'.
--
-- This is a closed type family, so equation order matters: the error
-- equations must stay ahead of the general one.
type family
  LinearWithoutBiasF
    (weightShape :: Shape [Dim (Name Symbol) (Size Nat)])
    (inputShape :: Shape [Dim (Name Symbol) (Size Nat)]) ::
    Shape [Dim (Name Symbol) (Size Nat)]
  where
  -- weight with fewer than two dimensions
  LinearWithoutBiasF ('Shape '[]) _ =
    TypeError (LinearWeightDimsErrorMessage '[])
  LinearWithoutBiasF ('Shape '[w]) _ =
    TypeError (LinearWeightDimsErrorMessage '[w])
  -- weight with more than two dimensions
  LinearWithoutBiasF ('Shape (w ': w' ': w'' ': ws)) _ =
    TypeError (LinearWeightDimsErrorMessage (w ': w' ': w'' ': ws))
  -- input without any dimension
  LinearWithoutBiasF _ ('Shape '[]) =
    TypeError LinearInputDimsErrorMessage
  -- both shapes checked and well-formed
  LinearWithoutBiasF ('Shape ws) ('Shape xs) =
    'Shape (Reverse (LinearWithoutBiasDimsF ws (Reverse xs)))
  -- at least one unchecked shape
  LinearWithoutBiasF 'UncheckedShape _ = 'UncheckedShape
  LinearWithoutBiasF _ 'UncheckedShape = 'UncheckedShape
-- | Dimension-level helper for 'LinearWithoutBiasF'.
--
-- Arguments are the weight dimensions @[outputDim, inputDim]@ and the
-- input dimensions in reverse order, so that the input's last dimension
-- comes first. That last input dimension is unified with @inputDim@ and is
-- then replaced by @outputDim@. The remaining input dimensions pass through
-- unchanged.
--
-- NOTE(review): 'Seq' presumably forces the input-dimension unification so
-- that a mismatch is still reported although its result is not returned.
-- Confirm against "Torch.GraduallyTyped.Prelude".
--
-- There is no other equation. 'LinearWithoutBiasF' rejects every other
-- dimension count before applying this family.
type family
  LinearWithoutBiasDimsF
    (weightDims :: [Dim (Name Symbol) (Size Nat)])
    (reversedInputDims :: [Dim (Name Symbol) (Size Nat)]) ::
    [Dim (Name Symbol) (Size Nat)]
  where
  LinearWithoutBiasDimsF '[outDim, inDim] (inDim' ': rest) =
    Seq (inDim <+> inDim') (outDim ': rest)
-- | Linear transformation without a bias term.
--
-- Arguments are the weight and the input, in that order. The work is
-- delegated to 'ATen.linear_tt', which receives them as input, weight.
--
-- At the type level ('LinearWithoutBiasF'):
--
-- * the weight must have two dimensions @[outputDim, inputDim]@;
-- * the input's last dimension must unify with @inputDim@.
--
-- In the result, that last input dimension is replaced by @outputDim@. The
-- gradient, layout, device and data type of the result combine those of
-- both arguments.
linearWithoutBias ::
  forall gradient layout device dataType shape gradient' layout' device' dataType' shape'.
  -- | weight
  Tensor gradient layout device dataType shape ->
  -- | input
  Tensor gradient' layout' device' dataType' shape' ->
  -- | output
  Tensor
    (gradient <|> gradient')
    (layout <+> layout')
    (device <+> device')
    (dataType <+> dataType')
    (LinearWithoutBiasF shape shape')
linearWithoutBias weight input =
  -- ATen expects (input, weight), not this function's argument order.
  -- NOTE(review): 'unsafePerformIO' assumes the ATen call has no observable
  -- side effects beyond allocating the result tensor; confirm.
  unsafePerformIO $ cast2 ATen.linear_tt input weight
-- | Compile-time check of the result type of 'linearWithoutBias'.
--
-- The weight and the input differ in device and gradient, to exercise
-- unification:
--
-- * weight: @[output 2, input 1]@, CPU device, with gradient;
-- * input: @[input 1]@, unchecked device, without gradient.
--
-- The signature asserts that the result has shape @[output 2]@, an
-- unchecked device and a gradient.
--
-- Both arguments are 'undefined' placeholders. The test is that this
-- definition type-checks; evaluating the value will fail.
testLinearWithoutBias ::
  Tensor
    ('Gradient 'WithGradient)
    ('Layout 'Dense)
    'UncheckedDevice
    ('DataType 'Float)
    ('Shape '[ 'Dim ('Name "output") ('Size 2)])
testLinearWithoutBias =
  let weight ::
        Tensor
          ('Gradient 'WithGradient)
          ('Layout 'Dense)
          ('Device 'CPU)
          ('DataType 'Float)
          ('Shape '[ 'Dim ('Name "output") ('Size 2), 'Dim ('Name "input") ('Size 1)])
      weight = undefined
      input ::
        Tensor
          ('Gradient 'WithoutGradient)
          ('Layout 'Dense)
          'UncheckedDevice
          ('DataType 'Float)
          ('Shape '[ 'Dim ('Name "input") ('Size 1)])
      input = undefined
   in linearWithoutBias weight input