{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module Torch.GraduallyTyped.Tensor.IndexingSlicingJoining where
import Control.Exception (Exception (..))
import Control.Monad.Catch (MonadThrow (throwM))
import Data.Bifunctor (bimap)
import Data.Coerce (coerce)
import Data.Kind (Type)
import Data.Singletons (SingI (..), SingKind (..), fromSing)
import Data.Type.Bool (If, type (&&))
import Data.Typeable (Typeable)
import Foreign.ForeignPtr (ForeignPtr)
import GHC.TypeLits (ErrorMessage, Nat, Symbol, TypeError, type (-))
import System.IO.Unsafe (unsafePerformIO)
import Torch.GraduallyTyped.DType (DType (..), DataType (..))
import Torch.GraduallyTyped.Device (Device (..), DeviceType (..))
import Torch.GraduallyTyped.Index.Class (InRangeF)
import Torch.GraduallyTyped.Index.Type (DemotedIndex (..), SIndex)
import Torch.GraduallyTyped.Layout (Layout (..), LayoutType (..))
import Torch.GraduallyTyped.Prelude (Catch, FromMaybe, MapMaybe, MaybeF, PrependMaybe, When, forgetIsChecked)
import Torch.GraduallyTyped.RequiresGradient (Gradient, RequiresGradient (..))
import Torch.GraduallyTyped.Shape.Class (AddDimF, BroadcastShapesF, GetDimF, GetDimImplF, GetIndexByNameF, InsertDimImplF, NumelF, RemoveDimF, ReplaceDimF, ReplaceDimImplF, sGetDimFromShape)
import Torch.GraduallyTyped.Shape.Type (By (..), Dim (..), Name (..), SSelectDim, SShape, SelectDim (..), Shape (..), Size (..), dimSize)
import Torch.GraduallyTyped.Tensor.Type (SGetShape (sGetShape), Tensor)
import Torch.GraduallyTyped.Unify (UnifyCheck, type (<+>), type (<|>))
import Torch.HList (HList)
import qualified Torch.Internal.Cast as ATen (cast1, cast2, cast3, cast4)
import qualified Torch.Internal.Class as ATen (Castable)
import Torch.Internal.GC (unsafeThrowableIO)
import qualified Torch.Internal.Managed.Native as ATen
import qualified Torch.Internal.Managed.Type.Tensor as ATen
import qualified Torch.Internal.Type as ATen
import Type.Errors.Pretty (ToErrorMessage, type (%), type (<>))
-- | Concatenation of a container of tensors along a selected dimension.
--
-- @selectDim@ chooses the dimension, @c@ is the container (instances below
-- cover @[]@ and 'HList') and @a@ is what the container holds. The type of
-- the concatenated tensor is computed by the associated family 'CatF'.
class HasCat (selectDim :: SelectDim (By Symbol Nat)) k (c :: k -> Type) (a :: k) where
  -- | Result type of concatenating @c a@ along @selectDim@.
  type CatF selectDim a c :: Type

  -- | Concatenate the tensors of the container along the dimension
  -- described by the singleton argument.
  sCat ::
    forall m.
    MonadThrow m =>
    SSelectDim selectDim ->
    c a ->
    m (CatF selectDim a c)

  -- | Same as 'sCat', but the dimension comes from the type-level
  -- @selectDim@ (which must have a 'SingI' instance).
  cat ::
    forall m.
    (SingI selectDim, MonadThrow m) =>
    c a ->
    m (CatF selectDim a c)
  cat container = sCat (sing @selectDim) container
-- | Type-level result of concatenating a list of tensors of type @tensor@
-- along @selectDim@; 'Nothing' when 'ReplaceDimImplF' cannot replace the
-- selected dimension in the shape.
type family CatListImplF (selectDim :: SelectDim (By Symbol Nat)) (tensor :: Type) :: Maybe Type where
  -- An unchecked selector forgets the whole shape.
  CatListImplF 'UncheckedSelectDim (Tensor g l d dt _) =
    'Just (Tensor g l d dt 'UncheckedShape)
  -- An unchecked shape stays unchecked.
  CatListImplF ('SelectDim _) (Tensor g l d dt 'UncheckedShape) =
    'Just (Tensor g l d dt 'UncheckedShape)
  -- The selected dimension is replaced by a fully unchecked one
  -- (presumably because the list length is not known statically).
  CatListImplF ('SelectDim by) (Tensor g l d dt ('Shape dims)) =
    MapMaybe
      (Tensor g l d dt)
      (MapMaybe 'Shape (ReplaceDimImplF by dims ('Dim 'UncheckedName 'UncheckedSize)))
-- | Hint appended to type errors about dimensions that could not be found.
type CheckSpellingMessage = "Check the spelling of named dimensions, and make sure the number of dimensions is correct."

-- | Unwraps the result of 'CatListImplF', turning 'Nothing' into a
-- readable type error that mentions the selector and the offending shape.
type family CatListCheckF (selectDim :: SelectDim (By Symbol Nat)) (tensor :: Type) (result :: Maybe Type) :: Type where
  CatListCheckF sel (Tensor _ _ _ _ s) 'Nothing =
    TypeError
      ( "Cannot concatenate the dimension"
          % ""
          % " " <> sel
          % ""
          % "for tensors of shape"
          % ""
          % " " <> s <> "."
          % ""
          % CheckSpellingMessage
      )
  CatListCheckF _ _ ('Just out) = out

-- | Type of the tensor obtained by concatenating a list of @tensor@s
-- along @selectDim@.
type CatListF selectDim tensor = CatListCheckF selectDim tensor (CatListImplF selectDim tensor)
-- | Concatenation of a homogeneous list of tensors. The runtime selector is
-- demoted and dispatched to @cat_ln@ (by name) or @cat_ll@ (by index).
instance
  ATen.Castable (CatListF selectDim (Tensor gradient layout device dataType shape)) (ForeignPtr ATen.Tensor) =>
  HasCat selectDim Type [] (Tensor gradient layout device dataType shape)
  where
  type CatF selectDim (Tensor gradient layout device dataType shape) [] =
    CatListF selectDim (Tensor gradient layout device dataType shape)

  sCat selectDim tensors =
    unsafeThrowableIO $
      case forgetIsChecked (fromSing selectDim) of
        ByName dimLabel ->
          ATen.cast2 ATen.cat_ln tensors dimLabel
        ByIndex dimPos ->
          ATen.cast2 ATen.cat_ll tensors (fromInteger dimPos :: Int)
-- | Folds a type-level list of tensor types into the type of their
-- concatenation along @selectDim@.
--
-- The accumulator holds the merged gradient/layout/device/dataType/shape
-- seen so far ('Nothing' before the first element). Gradients are merged
-- with '<|>', the other properties are unified with '<+>', and the selected
-- dimension of the shape is replaced by the sum ('AddDimF') of the two
-- selected dimensions.
type CatHListImplF ::
  SelectDim (By Symbol Nat) ->
  [Type] ->
  Maybe (Gradient RequiresGradient, Layout LayoutType, Device (DeviceType Nat), DataType DType, Shape [Dim (Name Symbol) (Size Nat)]) ->
  Type
type family CatHListImplF selectDim tensors acc where
  CatHListImplF _ '[] 'Nothing =
    TypeError (ToErrorMessage "Cannot concatenate an empty list of tensors.")
  CatHListImplF _ '[] ('Just '(g, l, d, dt, s)) =
    Tensor g l d dt s
  -- First tensor: seed the accumulator.
  CatHListImplF selectDim (Tensor g l d dt s ': rest) 'Nothing =
    CatHListImplF selectDim rest ('Just '(g, l, d, dt, s))
  -- Subsequent tensors: merge into the accumulator.
  CatHListImplF selectDim (Tensor g l d dt s ': rest) ('Just '(accG, accL, accD, accDt, accS)) =
    CatHListImplF
      selectDim
      rest
      ( 'Just
          '( g <|> accG,
             l <+> accL,
             d <+> accD,
             dt <+> accDt,
             ReplaceDimF
               selectDim
               (s <+> ReplaceDimF selectDim accS (GetDimF selectDim s))
               (AddDimF (GetDimF selectDim s) (GetDimF selectDim accS))
           )
      )
  CatHListImplF _ (notATensor ': _) _ =
    TypeError
      ( "Cannot concatenate because"
          % ""
          % " '" <> notATensor <> "'"
          % ""
          % "is not a tensor type."
      )

-- | Type of the tensor obtained by concatenating an 'HList' of tensors.
type CatHListF selectDim tensors = CatHListImplF selectDim tensors 'Nothing
-- | Concatenation of a heterogeneous list of tensors. The result type is
-- computed by 'CatHListF'; at runtime the 'HList' is passed to ATen as a
-- tensor list.
instance
  ( ATen.Castable (CatHListF selectDim tensors) (ForeignPtr ATen.Tensor),
    ATen.Castable (HList tensors) (ForeignPtr ATen.TensorList)
  ) =>
  HasCat selectDim [Type] HList tensors
  where
  type CatF selectDim tensors HList = CatHListF selectDim tensors

  sCat selectDim tensors =
    let selector = forgetIsChecked . fromSing $ selectDim
     in unsafeThrowableIO $
          case selector of
            ByName dimLabel ->
              ATen.cast2 ATen.cat_ln tensors dimLabel
            ByIndex dimPos ->
              ATen.cast2 ATen.cat_ll tensors (fromInteger dimPos :: Int)
-- | Type error shown when reshaping would change the total number of
-- elements: the first two arguments are the element counts of the
-- original and the new shape, the last two the shapes themselves.
type ReshapeNumelMismatchMessage ::
  Nat ->
  Nat ->
  Shape [Dim (Name Symbol) (Size Nat)] ->
  Shape [Dim (Name Symbol) (Size Nat)] ->
  ErrorMessage
type ReshapeNumelMismatchMessage oldNumel newNumel oldShape newShape =
  "Cannot reshape the tensor. The original shape,"
    % ""
    % " '" <> oldShape <> "',"
    % ""
    % "and the new shape,"
    % ""
    % " '" <> newShape <> "',"
    % ""
    % "have different total numbers of elements,"
    % ""
    % " '" <> oldNumel <> "' versus '" <> newNumel <> "',"
    % ""
    % "respectively."
-- | Result shape of a reshape, given the (possibly unknown) element counts
-- of the source and target shapes. Equal known counts yield the target
-- shape, different known counts are a type error, and anything unknown
-- degrades to 'UncheckedShape'.
type ReshapeImplF ::
  Maybe Nat ->
  Maybe Nat ->
  Shape [Dim (Name Symbol) (Size Nat)] ->
  Shape [Dim (Name Symbol) (Size Nat)] ->
  Shape [Dim (Name Symbol) (Size Nat)]
type family ReshapeImplF srcNumel dstNumel src dst where
  ReshapeImplF ('Just n) ('Just n) _ dst = dst
  ReshapeImplF ('Just n) ('Just m) src dst = TypeError (ReshapeNumelMismatchMessage n m src dst)
  ReshapeImplF 'Nothing _ _ _ = 'UncheckedShape
  ReshapeImplF _ 'Nothing _ _ = 'UncheckedShape
  ReshapeImplF _ _ 'UncheckedShape _ = 'UncheckedShape
  ReshapeImplF _ _ _ 'UncheckedShape = 'UncheckedShape

-- | Shape obtained by reshaping a tensor of shape @shape@ into @shape'@.
type ReshapeF :: Shape [Dim (Name Symbol) (Size Nat)] -> Shape [Dim (Name Symbol) (Size Nat)] -> Shape [Dim (Name Symbol) (Size Nat)]
type family ReshapeF shape shape' where
  ReshapeF src dst = ReshapeImplF (NumelF src) (NumelF dst) src dst
-- | Reshape a tensor to the shape described by the singleton @shape'@.
--
-- The sizes of @shape'@ are passed to ATen's @reshape@, and afterwards the
-- dimension names of @shape'@ are applied with @refine_names@. ATen errors
-- go through 'unsafeThrowableIO' (presumably re-raised in @m@).
-- 'sSetShape' is an alias of 'sReshape'.
sReshape,
  sSetShape ::
    forall m shape' gradient layout device dataType shape shape''.
    MonadThrow m =>
    (shape'' ~ ReshapeF shape shape', Catch shape'') =>
    SShape shape' ->
    Tensor gradient layout device dataType shape ->
    m (Tensor gradient layout device dataType shape'')
sReshape targetShape input = unsafeThrowableIO $ do
  let targetDims = forgetIsChecked . fromSing $ targetShape
      targetSizes = forgetIsChecked . dimSize <$> targetDims
      targetNames = forgetIsChecked . dimName <$> targetDims
  reshaped :: ForeignPtr ATen.Tensor <- ATen.cast2 ATen.reshape_tl input targetSizes
  ATen.cast2 ATen.tensor_refine_names_N reshaped targetNames
sSetShape = sReshape
-- | Whether every dimension of a shape has a statically known ('Size') size.
--
-- 'False for an unchecked shape and as soon as any dimension has an
-- 'UncheckedSize'; 'True otherwise (including the empty shape).
type AllDimSizesChecked :: Shape [Dim (Name Symbol) (Size Nat)] -> Bool
type family AllDimSizesChecked shape where
  AllDimSizesChecked 'UncheckedShape = 'False
  AllDimSizesChecked ('Shape '[]) = 'True
  AllDimSizesChecked ('Shape ('Dim _ ('Size _) ': dims)) = AllDimSizesChecked ('Shape dims)
  -- Without this equation the family is stuck on shapes that mix checked
  -- and unchecked sizes, so a constraint such as
  -- @When (AllDimSizesChecked shape) ...@ can never be discharged.
  AllDimSizesChecked ('Shape ('Dim _ 'UncheckedSize ': _)) = 'False
-- | Reshape a tensor to the type-level shape @shape'@, which is obtained
-- via 'SingI' and forwarded to 'sReshape'. When every input dimension has a
-- checked size, @shape'@ is additionally required to equal the result shape.
reshape ::
  forall m shape' gradient layout device dataType shape shape''.
  ( shape'' ~ ReshapeF shape shape',
    Catch shape'',
    When (AllDimSizesChecked shape) (shape' ~ shape''),
    SingI shape',
    MonadThrow m
  ) =>
  Tensor gradient layout device dataType shape ->
  m (Tensor gradient layout device dataType shape'')
reshape input = sReshape (sing @shape') input
-- | Type error for a transpose whose source dimension is not in @dims@.
type TransposeBy0Message :: By Symbol Nat -> [Dim (Name Symbol) (Size Nat)] -> ErrorMessage
type TransposeBy0Message src ds =
  "Cannot transpose the tensor with the dimensions"
    % ""
    % " '" <> ds <> "'"
    % ""
    % "because the specified source dimension"
    % ""
    % " '" <> src <> "'"
    % ""
    % "could not be found."

-- | Type error for a transpose whose target dimension is not in @dims@.
type TransposeBy1Message :: By Symbol Nat -> [Dim (Name Symbol) (Size Nat)] -> ErrorMessage
type TransposeBy1Message tgt ds =
  "Cannot transpose the tensor with the dimensions"
    % ""
    % " '" <> ds <> "'"
    % ""
    % "because the specified target dimension"
    % ""
    % " '" <> tgt <> "'"
    % ""
    % "could not be found."
-- | Shape after swapping two dimensions. Both selectors must use the same
-- style (both by name or both by index). Names are resolved to indices
-- with 'GetIndexByNameF' before swapping, and any unchecked input yields
-- 'UncheckedShape'.
type TransposeF ::
  SelectDim (By Symbol Nat) ->
  SelectDim (By Symbol Nat) ->
  Shape [Dim (Name Symbol) (Size Nat)] ->
  Shape [Dim (Name Symbol) (Size Nat)]
type family TransposeF selectDim0 selectDim1 shape where
  TransposeF _ _ 'UncheckedShape = 'UncheckedShape
  TransposeF _ 'UncheckedSelectDim _ = 'UncheckedShape
  TransposeF 'UncheckedSelectDim _ _ = 'UncheckedShape
  TransposeF ('SelectDim ('ByName srcName)) ('SelectDim ('ByName tgtName)) ('Shape ds) =
    'Shape
      ( TransposeIndexIndexDimsF
          ( FromMaybe
              (TypeError (TransposeBy0Message ('ByName srcName) ds))
              (GetIndexByNameF srcName ds)
          )
          ( FromMaybe
              (TypeError (TransposeBy1Message ('ByName tgtName) ds))
              (GetIndexByNameF tgtName ds)
          )
          ds
      )
  TransposeF ('SelectDim ('ByIndex srcIndex)) ('SelectDim ('ByIndex tgtIndex)) ('Shape ds) =
    'Shape (TransposeIndexIndexDimsF srcIndex tgtIndex ds)
  TransposeF ('SelectDim srcBy) ('SelectDim tgtBy) _ =
    TypeError
      ( "Cannot transpose the tensor. "
          % ""
          % "The source and target dimensions must be selected either both by name or both by index, "
          % "but mixed selectors were found: "
          % ""
          % " '" <> 'SelectDim srcBy <> "' and '" <> 'SelectDim tgtBy <> "'."
          % ""
      )
-- | Swap the dimensions at positions @index0@ and @index1@. The dimension
-- at @index1@ is first written into position @index0@, and the original
-- dimension at @index0@ is then written into position @index1@. A missing
-- index raises the matching 'TransposeBy0Message' / 'TransposeBy1Message'.
type TransposeIndexIndexDimsF :: Nat -> Nat -> [Dim (Name Symbol) (Size Nat)] -> [Dim (Name Symbol) (Size Nat)]
type family TransposeIndexIndexDimsF index0 index1 dims where
  TransposeIndexIndexDimsF i j ds =
    FromMaybe
      (TypeError (TransposeBy1Message ('ByIndex j) ds))
      ( ReplaceDimImplF
          ('ByIndex j)
          ( FromMaybe
              (TypeError (TransposeBy0Message ('ByIndex i) ds))
              ( ReplaceDimImplF
                  ('ByIndex i)
                  ds
                  ( FromMaybe
                      (TypeError (TransposeBy1Message ('ByIndex j) ds))
                      (GetDimImplF ('ByIndex j) ds)
                  )
              )
          )
          ( FromMaybe
              (TypeError (TransposeBy0Message ('ByIndex i) ds))
              (GetDimImplF ('ByIndex i) ds)
          )
      )
-- | Swap two dimensions of a tensor, selected by singleton selectors.
--
-- Name/name pairs call ATen's @transpose_tnn@ and index/index pairs call
-- @transpose_tll@. Any other combination fails with
-- 'TransposeMixedSelectorsError' through 'throwM'.
sTranspose ::
  forall selectDim0 selectDim1 gradient layout device dataType shape shape' m.
  ( shape' ~ TransposeF selectDim0 selectDim1 shape,
    Catch shape',
    MonadThrow m
  ) =>
  SSelectDim selectDim0 ->
  SSelectDim selectDim1 ->
  Tensor gradient layout device dataType shape ->
  m (Tensor gradient layout device dataType shape')
sTranspose selectDim0 selectDim1 input =
  let srcBy = forgetIsChecked . fromSing $ selectDim0
      tgtBy = forgetIsChecked . fromSing $ selectDim1
   in unsafeThrowableIO $
        case (srcBy, tgtBy) of
          (ByName srcName, ByName tgtName) ->
            ATen.cast3 ATen.transpose_tnn input srcName tgtName
          (ByIndex srcIndex, ByIndex tgtIndex) ->
            ATen.cast3
              ATen.transpose_tll
              input
              (fromIntegral srcIndex :: Int)
              (fromIntegral tgtIndex :: Int)
          _ -> throwM $ TransposeMixedSelectorsError srcBy tgtBy
-- | Runtime failure of 'sTranspose' when one dimension is selected by name
-- and the other by index.
data TransposeError = TransposeMixedSelectorsError
  { -- | selector of the source dimension
    teBy0 :: By String Integer,
    -- | selector of the target dimension
    teBy1 :: By String Integer
  }
  deriving stock (Show, Typeable)

instance Exception TransposeError where
  displayException TransposeMixedSelectorsError {..} =
    concat
      [ "Cannot transpose the tensor. ",
        "The source and target dimensions must be selected either both by name or both by index, ",
        "but mixed selectors were found: '",
        show teBy0,
        "' and '",
        show teBy1,
        "'."
      ]
-- | Swap the two dimensions given by the type-level selectors
-- @selectDim0@ and @selectDim1@. This is 'sTranspose' with both
-- singletons obtained from 'SingI'.
transpose ::
  forall selectDim0 selectDim1 gradient layout device dataType shape shape' m.
  ( shape' ~ TransposeF selectDim0 selectDim1 shape,
    Catch shape',
    SingI selectDim0,
    SingI selectDim1,
    MonadThrow m
  ) =>
  Tensor gradient layout device dataType shape ->
  m (Tensor gradient layout device dataType shape')
transpose input = sTranspose (sing @selectDim0) (sing @selectDim1) input
-- | Type error for an unsqueeze at a dimension that is not in @dims@.
type UnsqueezeByMessage (by :: By Symbol Nat) (dims :: [Dim (Name Symbol) (Size Nat)]) =
  "Cannot unsqueeze the tensor with the dimensions"
    % ""
    % " '" <> dims <> "'"
    % ""
    % "because the specified source dimension"
    % ""
    % " '" <> by <> "'"
    % ""
    % "could not be found."

-- | Shape after inserting a new size-1 dimension (named @"*"@) at the
-- selected position. A selection by name is resolved to the index of that
-- name, and unchecked inputs yield 'UncheckedShape'.
type family UnsqueezeF (selectDim :: SelectDim (By Symbol Nat)) (shape :: Shape [Dim (Name Symbol) (Size Nat)]) :: Shape [Dim (Name Symbol) (Size Nat)] where
  UnsqueezeF _ 'UncheckedShape = 'UncheckedShape
  UnsqueezeF 'UncheckedSelectDim _ = 'UncheckedShape
  UnsqueezeF ('SelectDim ('ByName n)) ('Shape ds) =
    'Shape
      ( UnsqueezeIndexDimsF
          ( FromMaybe
              (TypeError (UnsqueezeByMessage ('ByName n) ds))
              (GetIndexByNameF n ds)
          )
          ds
      )
  UnsqueezeF ('SelectDim ('ByIndex i)) ('Shape ds) =
    'Shape (UnsqueezeIndexDimsF i ds)

-- | Insert @'Dim ('Name "*") ('Size 1)@ at position @index@ of @dims@.
type family UnsqueezeIndexDimsF (index :: Nat) (dims :: [Dim (Name Symbol) (Size Nat)]) :: [Dim (Name Symbol) (Size Nat)] where
  UnsqueezeIndexDimsF i ds =
    FromMaybe
      (TypeError (UnsqueezeByMessage ('ByIndex i) ds))
      ( InsertDimImplF
          ('ByIndex i)
          ds
          ('Dim ('Name "*") ('Size 1))
      )
-- | Runtime failure of 'sUnsqueeze'.
--
-- 'UnsqueezeF' accepts a dimension selected by name at the type level.
-- At runtime, however, only an index can be passed to ATen's @unsqueeze@,
-- so a by-name selector is reported through this error.
data UnsqueezeError = UnsqueezeByNameUnsupportedError
  { -- | the dimension name that was requested
    ueName :: String
  }
  deriving stock (Show)

instance Exception UnsqueezeError where
  displayException UnsqueezeByNameUnsupportedError {..} =
    "Cannot unsqueeze the tensor at the dimension named '"
      <> ueName
      <> "'. Selecting the dimension by name is not supported at runtime; select it by index instead."

-- | Insert a size-1 dimension at the position given by the singleton
-- selector.
--
-- By-index selectors call ATen's @unsqueeze_tl@. By-name selectors fail
-- with 'UnsqueezeByNameUnsupportedError' via 'throwM'. (Previously this case
-- was @undefined@, which crashed with a pure bottom instead of failing in
-- @m@.)
sUnsqueeze ::
  forall selectDim gradient layout device dataType shape shape' m.
  ( shape' ~ UnsqueezeF selectDim shape,
    Catch shape',
    MonadThrow m
  ) =>
  SSelectDim selectDim ->
  Tensor gradient layout device dataType shape ->
  m (Tensor gradient layout device dataType shape')
sUnsqueeze selectDim input =
  case forgetIsChecked . fromSing $ selectDim of
    ByName name ->
      throwM $ UnsqueezeByNameUnsupportedError name
    ByIndex index ->
      unsafeThrowableIO $
        ATen.cast2 ATen.unsqueeze_tl input (fromIntegral index :: Int)
-- | Insert a size-1 dimension at the type-level selector @selectDim@.
-- This is 'sUnsqueeze' with the singleton obtained from 'SingI'.
unsqueeze ::
  forall selectDim gradient layout device dataType shape shape' m.
  ( shape' ~ UnsqueezeF selectDim shape,
    Catch shape',
    SingI selectDim,
    MonadThrow m
  ) =>
  Tensor gradient layout device dataType shape ->
  m (Tensor gradient layout device dataType shape')
unsqueeze input = sUnsqueeze (sing @selectDim) input
-- | Shape after removing every size-1 dimension. If any dimension has an
-- unchecked size, the result is 'UncheckedShape'.
type SqueezeAllShapeF :: Shape [Dim (Name Symbol) (Size Nat)] -> Shape [Dim (Name Symbol) (Size Nat)]
type family SqueezeAllShapeF shape where
  SqueezeAllShapeF 'UncheckedShape = 'UncheckedShape
  SqueezeAllShapeF ('Shape ds) = MaybeF 'UncheckedShape 'Shape (SqueezeAllDimsF ds)

-- | Drop all dimensions of size 1; 'Nothing' as soon as an
-- 'UncheckedSize' is met.
type SqueezeAllDimsF :: [Dim (Name Symbol) (Size Nat)] -> Maybe [Dim (Name Symbol) (Size Nat)]
type family SqueezeAllDimsF dims where
  SqueezeAllDimsF '[] = 'Just '[]
  SqueezeAllDimsF ('Dim _ 'UncheckedSize ': rest) = 'Nothing
  SqueezeAllDimsF ('Dim _ ('Size 1) ': rest) = SqueezeAllDimsF rest
  SqueezeAllDimsF (keep ': rest) = PrependMaybe ('Just keep) (SqueezeAllDimsF rest)
-- | Remove every size-1 dimension of the tensor (ATen @squeeze@ without a
-- dimension argument).
--
-- NOTE(review): unlike the other functions here, this uses
-- 'unsafePerformIO' rather than 'unsafeThrowableIO', so an ATen failure
-- would surface as an imprecise exception when the result is forced.
-- Confirm that this is intended.
squeezeAll ::
  forall gradient layout device dataType shape.
  Tensor gradient layout device dataType shape ->
  Tensor gradient layout device dataType (SqueezeAllShapeF shape)
squeezeAll input = unsafePerformIO (ATen.cast1 ATen.squeeze_t input)
-- | Remove the dimension at position @dimIndex@, but only if it unifies
-- with a size-1 dimension named @"*"@; otherwise (or if the index is out of
-- range) the result is 'Nothing'.
type SqueezeDimByIndexF :: Nat -> [Dim (Name Symbol) (Size Nat)] -> Maybe [Dim (Name Symbol) (Size Nat)]
type family SqueezeDimByIndexF dimIndex dims where
  SqueezeDimByIndexF 0 (d ': ds) =
    If
      (UnifyCheck (Dim (Name Symbol) (Size Nat)) d ('Dim ('Name "*") ('Size 1)))
      ('Just ds)
      'Nothing
  SqueezeDimByIndexF i (d ': ds) = PrependMaybe ('Just d) (SqueezeDimByIndexF (i - 1) ds)
  SqueezeDimByIndexF _ _ = 'Nothing

-- | Squeeze the selected dimension; a name is first resolved to its index.
type SqueezeDimImplF :: By Symbol Nat -> [Dim (Name Symbol) (Size Nat)] -> Maybe [Dim (Name Symbol) (Size Nat)]
type family SqueezeDimImplF by dims where
  SqueezeDimImplF ('ByName n) ds = SqueezeDimByNameF (GetIndexByNameF n ds) ds
  SqueezeDimImplF ('ByIndex i) ds = SqueezeDimByIndexF i ds

-- | Squeeze at an index obtained from a name lookup; 'Nothing' when the
-- name was not found.
type SqueezeDimByNameF :: Maybe Nat -> [Dim (Name Symbol) (Size Nat)] -> Maybe [Dim (Name Symbol) (Size Nat)]
type family SqueezeDimByNameF dimIndex dims where
  SqueezeDimByNameF 'Nothing ds = 'Nothing
  SqueezeDimByNameF ('Just i) ds = SqueezeDimByIndexF i ds
-- | Type error for a squeeze that cannot be performed at the given dimension.
type SqueezeDimMessage :: By Symbol Nat -> [Dim (Name Symbol) (Size Nat)] -> ErrorMessage
type SqueezeDimMessage sel ds =
  "Cannot squeeze the tensor with the dimensions"
    % ""
    % " '" <> ds <> "'"
    % ""
    % "at the dimension"
    % ""
    % " '" <> sel <> "'."
    % ""

-- | Unwrap the result of 'SqueezeDimImplF', raising 'SqueezeDimMessage'
-- on 'Nothing'.
type SqueezeDimCheckF :: By Symbol Nat -> [Dim (Name Symbol) (Size Nat)] -> Maybe [Dim (Name Symbol) (Size Nat)] -> [Dim (Name Symbol) (Size Nat)]
type family SqueezeDimCheckF by dims result where
  SqueezeDimCheckF sel ds 'Nothing = TypeError (SqueezeDimMessage sel ds)
  SqueezeDimCheckF _ _ ('Just squeezed) = squeezed

-- | Shape after squeezing the selected dimension; unchecked selectors or
-- shapes give 'UncheckedShape'.
type SqueezeDimF :: SelectDim (By Symbol Nat) -> Shape [Dim (Name Symbol) (Size Nat)] -> Shape [Dim (Name Symbol) (Size Nat)]
type family SqueezeDimF selectDim shape where
  SqueezeDimF 'UncheckedSelectDim _ = 'UncheckedShape
  SqueezeDimF _ 'UncheckedShape = 'UncheckedShape
  SqueezeDimF ('SelectDim sel) ('Shape ds) = 'Shape (SqueezeDimCheckF sel ds (SqueezeDimImplF sel ds))
-- | Squeeze the single dimension chosen by the singleton selector. Names go
-- to ATen's @squeeze_tn@ and indices to @squeeze_tl@; ATen failures go
-- through 'unsafeThrowableIO'.
sSqueezeDim ::
  forall selectDim gradient layout device dataType shape shape' m.
  (MonadThrow m, shape' ~ SqueezeDimF selectDim shape, Catch shape') =>
  SSelectDim selectDim ->
  Tensor gradient layout device dataType shape ->
  m (Tensor gradient layout device dataType shape')
sSqueezeDim selectDim input =
  case forgetIsChecked (fromSing selectDim) of
    ByName label ->
      unsafeThrowableIO $ ATen.cast2 ATen.squeeze_tn input label
    ByIndex pos ->
      unsafeThrowableIO $
        ATen.cast2 ATen.squeeze_tl input (fromIntegral pos :: Int)
-- | Expand (broadcast) @input@ to the shape given by the singleton @shape'@.
--
-- The output shape at the type level is @BroadcastShapesF shape shape'@.
-- At runtime, the sizes of every dimension of @shape'@ are extracted and
-- passed to ATen's @expand@.
--
-- The call is run with 'unsafePerformIO', so the result is pure. An ATen
-- failure is not reported through a monad here; it would presumably surface
-- as an exception when the result is forced.
sExpand ::
  forall shape' shape'' gradient layout device dataType shape.
  (shape'' ~ BroadcastShapesF shape shape', Catch shape'') =>
  -- | target shape
  SShape shape' ->
  -- | input tensor
  Tensor gradient layout device dataType shape ->
  -- | expanded tensor
  Tensor gradient layout device dataType shape''
sExpand shape' input =
  let sizes' :: [Integer]
      sizes' =
        fmap (\(Dim _ size) -> forgetIsChecked size)
          . forgetIsChecked
          $ fromSing shape'
   in -- NOTE(review): the trailing Bool is forwarded to ATen as a CBool;
      -- presumably ATen's @implicit@ flag for expand — confirm True is intended.
      unsafePerformIO $ ATen.cast3 ATen.tensor_expand_lb input sizes' True
-- | Type-directed variant of 'sExpand'.
--
-- The target shape is the type argument @shape'@ (supply it with
-- @TypeApplications@). Its singleton is obtained with 'sing'.
expand ::
  forall shape' shape'' gradient layout device dataType shape.
  (SingI shape', shape'' ~ BroadcastShapesF shape shape', Catch shape'') =>
  -- | input tensor
  Tensor gradient layout device dataType shape ->
  -- | expanded tensor
  Tensor gradient layout device dataType shape''
expand = sExpand (sing @shape')
-- | Select the slice at position @index@ along the dimension chosen by
-- @selectDim@. The selected dimension is removed from the result shape.
--
-- Before calling into ATen, the index is checked against the runtime size of
-- the selected dimension. The accepted range is @[-size, size)@, which
-- follows PyTorch's convention that negative indices count from the end.
-- An index outside that range is reported as an 'IndexOutOfBoundError'
-- thrown in @m@. The ATen call itself is lifted with 'unsafeThrowableIO'.
sSelect ::
  forall selectDim index gradient layout device dataType shapeIn shapeOut m.
  ( index `InRangeF` GetDimF selectDim shapeIn,
    shapeOut ~ RemoveDimF selectDim shapeIn,
    Catch shapeOut,
    SGetShape shapeIn,
    MonadThrow m
  ) =>
  -- | dimension to select from
  SSelectDim selectDim ->
  -- | position along that dimension
  SIndex index ->
  -- | input tensor
  Tensor gradient layout device dataType shapeIn ->
  -- | selected slice
  m (Tensor gradient layout device dataType shapeOut)
sSelect sSelectDim sIndex input = do
  sDim <- sGetDimFromShape sSelectDim (sGetShape input)
  let dim :: Dim String Integer
      dim = bimap forgetIsChecked forgetIsChecked . fromSing $ sDim
      size :: Integer
      size = dimSize dim
      index :: Integer
      index = coerce . forgetIsChecked . fromSing $ sIndex
      selectDim :: By String Integer
      selectDim = forgetIsChecked . fromSing $ sSelectDim
      -- The index is an Integer and may be negative (e.g. an unchecked
      -- index). Indices below -size must be rejected here too; otherwise
      -- they are handed to ATen without the IndexOutOfBoundError path.
      -- Indices in [-size, -1] are still forwarded unchanged.
      inBounds :: Bool
      inBounds = negate size <= index && index < size
  if inBounds
    then unsafeThrowableIO $ case selectDim of
      ByName dimName ->
        ATen.cast3 ATen.tensor_select_nl input dimName (fromIntegral index :: Int)
      ByIndex dimIndex ->
        ATen.cast3
          ATen.tensor_select_ll
          input
          (fromIntegral dimIndex :: Int)
          (fromIntegral index :: Int)
    else throwM $ IndexOutOfBoundError index dim
-- | Exception thrown by 'sSelect' when the requested index lies outside the
-- runtime bounds of the selected dimension.
data IndexOutOfBoundError = IndexOutOfBoundError
  { -- | the offending index
    ioobeIndex :: Integer,
    -- | the dimension (name and size) that was indexed
    ioobeDim :: Dim String Integer
  }
  deriving stock (Show, Typeable)
-- | Human-readable rendering of an 'IndexOutOfBoundError'. It shows the
-- offending index and the 'show'n dimension it was checked against.
instance Exception IndexOutOfBoundError where
  displayException IndexOutOfBoundError {..} =
    "Index `"
      <> show ioobeIndex
      <> "` is out of bounds for dimension `"
      <> show ioobeDim
      <> "`."
-- | Type-directed variant of 'sSelect'.
--
-- Both the dimension selector and the index are taken from the type
-- arguments @selectDim@ and @index@ (supply them with @TypeApplications@).
-- Their singletons are obtained with 'sing'.
select ::
  forall selectDim index gradient layout device dataType shapeIn shapeOut m.
  ( SingI selectDim,
    SingI index,
    index `InRangeF` GetDimF selectDim shapeIn,
    shapeOut ~ RemoveDimF selectDim shapeIn,
    Catch shapeOut,
    SGetShape shapeIn,
    MonadThrow m
  ) =>
  -- | input tensor
  Tensor gradient layout device dataType shapeIn ->
  -- | selected slice
  m (Tensor gradient layout device dataType shapeOut)
select = sSelect (sing @selectDim) (sing @index)
-- | Compute the output dimensions of a gather, given how the gathered
-- dimension is selected.
--
-- A by-name selector is looked up in both the index dimensions and the input
-- dimensions, and the two positions are resolved by 'GatherDimByNameF'. A
-- by-index selector goes straight to 'GatherDimByIndexF'.
type GatherDimImplF ::
  By Symbol Nat ->
  [Dim (Name Symbol) (Size Nat)] ->
  [Dim (Name Symbol) (Size Nat)] ->
  Maybe [Dim (Name Symbol) (Size Nat)]
type family GatherDimImplF by indexDims inputDims where
  GatherDimImplF ('ByName name) idxDs inDs =
    GatherDimByNameF
      (GetIndexByNameF name idxDs)
      (GetIndexByNameF name inDs)
      idxDs
      inDs
  GatherDimImplF ('ByIndex position) idxDs inDs =
    GatherDimByIndexF position idxDs inDs
-- | Compute the output dimensions of a gather along the positional
-- dimension @dimIndex@. The index and input dimension lists are walked in
-- lockstep:
--
-- * Before the gathered position, each index dimension must unify with the
--   matching input dimension. The unified dimension is kept.
-- * At the gathered position, an index dimension of size 0 is rejected.
--   Otherwise the two names must unify and all remaining dimensions must
--   unify. The output keeps the index tensor's size at that position.
-- * Every other case (e.g. running out of dimensions) yields 'Nothing'.
--
-- Equations are tried top to bottom, so their order matters.
type GatherDimByIndexF ::
  Nat ->
  [Dim (Name Symbol) (Size Nat)] ->
  [Dim (Name Symbol) (Size Nat)] ->
  Maybe [Dim (Name Symbol) (Size Nat)]
type family GatherDimByIndexF dimIndex indexDims inputDims where
  GatherDimByIndexF 0 ('Dim _ ('Size 0) ': _) _ = 'Nothing
  GatherDimByIndexF 0 ('Dim idxName idxSize ': idxRest) ('Dim inName _ ': inRest) =
    If
      ( UnifyCheck (Name Symbol) idxName inName
          && UnifyCheck [Dim (Name Symbol) (Size Nat)] idxRest inRest
      )
      ('Just ('Dim (idxName <+> inName) idxSize ': (idxRest <+> inRest)))
      'Nothing
  GatherDimByIndexF n (idxDim ': idxRest) (inDim ': inRest) =
    If
      (UnifyCheck (Dim (Name Symbol) (Size Nat)) idxDim inDim)
      ( PrependMaybe
          ('Just (idxDim <+> inDim))
          (GatherDimByIndexF (n - 1) idxRest inRest)
      )
      'Nothing
  GatherDimByIndexF _ _ _ = 'Nothing
-- | Resolve a gather-by-name request.
--
-- The inputs are the positions of the name in the index dimensions and in
-- the input dimensions, in that order.
--
-- * If the name occurs in only one of the two lists, that position is used.
-- * If it occurs in both, the two positions must be identical.
-- * A missing name, or disagreeing positions, yields 'Nothing'.
type GatherDimByNameF ::
  Maybe Nat ->
  Maybe Nat ->
  [Dim (Name Symbol) (Size Nat)] ->
  [Dim (Name Symbol) (Size Nat)] ->
  Maybe [Dim (Name Symbol) (Size Nat)]
type family GatherDimByNameF dimIndex dimIndex' indexDims inputDims where
  GatherDimByNameF 'Nothing ('Just j) idxDs inDs = GatherDimByIndexF j idxDs inDs
  GatherDimByNameF ('Just i) 'Nothing idxDs inDs = GatherDimByIndexF i idxDs inDs
  -- Non-linear pattern: both lookups must have produced the same position.
  GatherDimByNameF ('Just i) ('Just i) idxDs inDs = GatherDimByIndexF i idxDs inDs
  GatherDimByNameF _ _ _ _ = 'Nothing
-- | Compile-time error text used when a gather is impossible.
--
-- It reports the input dimensions, the dimension selector, and the dimensions
-- of the index tensor.
type GatherDimMessage ::
  By Symbol Nat ->
  [Dim (Name Symbol) (Size Nat)] ->
  [Dim (Name Symbol) (Size Nat)] ->
  ErrorMessage
type GatherDimMessage selector idxDims inDims =
  "Cannot gather the tensor with the dimensions"
    % ""
    % " '" <> inDims <> "'"
    % ""
    % "at the dimension"
    % ""
    % " '" <> selector <> "'"
    % ""
    % "using an index of shape"
    % ""
    % " '" <> idxDims <> "'."
    % ""
-- | Unwrap a gather result.
--
-- A 'Nothing' result becomes a custom type error built from
-- 'GatherDimMessage'; a 'Just' result yields the gathered dimensions.
type GatherDimCheckF ::
  By Symbol Nat ->
  [Dim (Name Symbol) (Size Nat)] ->
  [Dim (Name Symbol) (Size Nat)] ->
  Maybe [Dim (Name Symbol) (Size Nat)] ->
  [Dim (Name Symbol) (Size Nat)]
type family GatherDimCheckF by indexDims inputDims result where
  GatherDimCheckF selector idxDs inDs 'Nothing =
    TypeError (GatherDimMessage selector idxDs inDs)
  GatherDimCheckF _ _ _ ('Just gathered) = gathered
-- | Shape that results from gathering along @selectDim@ with an index tensor
-- of shape @indexShape@ from an input of shape @inputShape@.
--
-- If the selector or either shape is unchecked, the result is
-- 'UncheckedShape'. Otherwise the dimensions are computed by
-- 'GatherDimImplF' and validated by 'GatherDimCheckF'.
type GatherDimF ::
  SelectDim (By Symbol Nat) ->
  Shape [Dim (Name Symbol) (Size Nat)] ->
  Shape [Dim (Name Symbol) (Size Nat)] ->
  Shape [Dim (Name Symbol) (Size Nat)]
type family GatherDimF selectDim indexShape inputShape where
  GatherDimF 'UncheckedSelectDim _ _ = 'UncheckedShape
  GatherDimF _ 'UncheckedShape _ = 'UncheckedShape
  GatherDimF _ _ 'UncheckedShape = 'UncheckedShape
  GatherDimF ('SelectDim selector) ('Shape idxDs) ('Shape inDs) =
    'Shape (GatherDimCheckF selector idxDs inDs (GatherDimImplF selector idxDs inDs))
-- | Gather values from @input@ along the dimension chosen by the singleton
-- selector, using @index@ as the index tensor.
--
-- The output shape is @GatherDimF selectDim indexShape inputShape@. The
-- output data type is the input's. The index tensor's data type must unify
-- with @'DataType 'Int64@. The selector is demoted to a runtime 'By' value,
-- which picks ATen's name-based ('ATen.gather_tntb') or index-based
-- ('ATen.gather_tltb') gather. The FFI call is lifted into @m@ with
-- 'unsafeThrowableIO'.
sGatherDim ::
  forall selectDim indexGradient inputGradient indexLayout inputLayout indexDevice inputDevice indexDataType inputDataType indexShape inputShape outputShape m.
  ( MonadThrow m,
    outputShape ~ GatherDimF selectDim indexShape inputShape,
    Catch outputShape,
    Catch (indexDataType <+> 'DataType 'Int64)
  ) =>
  -- | dimension to gather along
  SSelectDim selectDim ->
  -- | index tensor
  Tensor indexGradient indexLayout indexDevice indexDataType indexShape ->
  -- | input tensor
  Tensor inputGradient inputLayout inputDevice inputDataType inputShape ->
  -- | gathered tensor
  m
    ( Tensor
        (indexGradient <|> inputGradient)
        (indexLayout <+> inputLayout)
        (indexDevice <+> inputDevice)
        inputDataType
        outputShape
    )
sGatherDim selectDim index input =
  let by :: By String Integer
      by = forgetIsChecked . fromSing $ selectDim
   in -- The trailing Bool is forwarded to ATen as a CBool (presumably
      -- gather's @sparse_grad@ flag; see torch.gather).
      unsafeThrowableIO $ case by of
        ByName dimName ->
          ATen.cast4 ATen.gather_tntb input dimName index False
        ByIndex dimIndex ->
          ATen.cast4 ATen.gather_tltb input (fromIntegral dimIndex :: Int) index False