{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

module Torch.GraduallyTyped.Tensor.IndexingSlicingJoining where

import Control.Exception (Exception (..))
import Control.Monad.Catch (MonadThrow (throwM))
import Data.Bifunctor (bimap)
import Data.Coerce (coerce)
import Data.Kind (Type)
import Data.Singletons (SingI (..), SingKind (..), fromSing)
import Data.Type.Bool (If, type (&&))
import Data.Typeable (Typeable)
import Foreign.ForeignPtr (ForeignPtr)
import GHC.TypeLits (ErrorMessage, Nat, Symbol, TypeError, type (-))
import System.IO.Unsafe (unsafePerformIO)
import Torch.GraduallyTyped.DType (DType (..), DataType (..))
import Torch.GraduallyTyped.Device (Device (..), DeviceType (..))
import Torch.GraduallyTyped.Index.Class (InRangeF)
import Torch.GraduallyTyped.Index.Type (DemotedIndex (..), SIndex)
import Torch.GraduallyTyped.Layout (Layout (..), LayoutType (..))
import Torch.GraduallyTyped.Prelude (Catch, FromMaybe, MapMaybe, MaybeF, PrependMaybe, When, forgetIsChecked)
import Torch.GraduallyTyped.RequiresGradient (Gradient, RequiresGradient (..))
import Torch.GraduallyTyped.Shape.Class (AddDimF, BroadcastShapesF, GetDimF, GetDimImplF, GetIndexByNameF, InsertDimImplF, NumelF, RemoveDimF, ReplaceDimF, ReplaceDimImplF, sGetDimFromShape)
import Torch.GraduallyTyped.Shape.Type (By (..), Dim (..), Name (..), SSelectDim, SShape, SelectDim (..), Shape (..), Size (..), dimSize)
import Torch.GraduallyTyped.Tensor.Type (SGetShape (sGetShape), Tensor)
import Torch.GraduallyTyped.Unify (UnifyCheck, type (<+>), type (<|>))
import Torch.HList (HList)
import qualified Torch.Internal.Cast as ATen (cast1, cast2, cast3, cast4)
import qualified Torch.Internal.Class as ATen (Castable)
import Torch.Internal.GC (unsafeThrowableIO)
import qualified Torch.Internal.Managed.Native as ATen
import qualified Torch.Internal.Managed.Type.Tensor as ATen
import qualified Torch.Internal.Type as ATen
import Type.Errors.Pretty (ToErrorMessage, type (%), type (<>))

-- $setup
-- >>> import Torch.GraduallyTyped.Prelude.List (SList (..))
-- >>> import Torch.GraduallyTyped

-- | Concatenation of a container @c@ of tensors along the dimension selected by @selectDim@.
--
-- The container's kind @k@ is a class parameter so that both plain lists of tensors
-- (@c ~ []@, @k ~ Type@) and heterogeneous lists (@c ~ HList@, @k ~ [Type]@)
-- can be concatenated.
class HasCat (selectDim :: SelectDim (By Symbol Nat)) k (c :: k -> Type) (a :: k) where
  -- | Type of the tensor obtained by concatenating @c a@ along @selectDim@.
  type CatF selectDim a c :: Type

  -- | Concatenates the given sequence of tensors in the given dimension.
  -- All tensors must either have the same shape (except in the concatenating dimension) or be empty.
  --
  -- >>> t <- ones @('Gradient 'WithGradient) @('Layout 'Dense) @('Device 'CPU) @('DataType 'Float) @('Shape '[ 'Dim ('Name "batch") ('Size 32), 'Dim ('Name "feature") ('Size 8)])
  -- >>> :type cat @('SelectDim ('ByName "feature")) [t]
  -- cat @('SelectDim ('ByName "feature")) [t]
  --   :: MonadThrow m =>
  --      m (Tensor
  --           ('Gradient 'WithGradient)
  --           ('Layout 'Dense)
  --           ('Device 'CPU)
  --           ('DataType 'Float)
  --           ('Shape
  --              '[ 'Dim ('Name "batch") ('Size 32),
  --                 'Dim 'UncheckedName 'UncheckedSize]))
  -- >>> :type cat @('SelectDim ( 'ByIndex 0)) [t]
  -- cat @('SelectDim ( 'ByIndex 0)) [t]
  --   :: MonadThrow m =>
  --      m (Tensor
  --           ('Gradient 'WithGradient)
  --           ('Layout 'Dense)
  --           ('Device 'CPU)
  --           ('DataType 'Float)
  --           ('Shape
  --              '[ 'Dim 'UncheckedName 'UncheckedSize,
  --                 'Dim ('Name "feature") ('Size 8)]))
  -- >>> :type sCat (SUncheckedSelectDim (ByIndex 0)) [t]
  -- sCat (SUncheckedSelectDim (ByIndex 0)) [t]
  --   :: MonadThrow m =>
  --      m (Tensor
  --           ('Gradient 'WithGradient)
  --           ('Layout 'Dense)
  --           ('Device 'CPU)
  --           ('DataType 'Float)
  --           'UncheckedShape)
  sCat :: forall m. MonadThrow m => SSelectDim selectDim -> c a -> m (CatF selectDim a c)

  -- | Like 'sCat', but the dimension is taken from the type-level @selectDim@,
  -- which is supplied with a type application, e.g. @cat \@('SelectDim ('ByName "feature"))@.
  cat :: forall m. (SingI selectDim, MonadThrow m) => c a -> m (CatF selectDim a c)
  cat = sCat (sing @selectDim)

-- | Tensor type produced by concatenating a list of @tensor@s along @selectDim@.
--
-- With a checked selector and a checked shape, the selected dimension becomes
-- @'Dim 'UncheckedName 'UncheckedSize@. The result is 'Nothing when
-- 'ReplaceDimImplF' cannot locate that dimension. An unchecked selector or an
-- unchecked shape yields a tensor with @'UncheckedShape@.
type CatListImplF :: SelectDim (By Symbol Nat) -> Type -> Maybe Type
type family CatListImplF selectDim tensor where
  CatListImplF 'UncheckedSelectDim (Tensor g l d dt _) =
    'Just (Tensor g l d dt 'UncheckedShape)
  CatListImplF ('SelectDim _) (Tensor g l d dt 'UncheckedShape) =
    'Just (Tensor g l d dt 'UncheckedShape)
  CatListImplF ('SelectDim sel) (Tensor g l d dt ('Shape ds)) =
    MapMaybe
      (Tensor g l d dt)
      (MapMaybe 'Shape (ReplaceDimImplF sel ds ('Dim 'UncheckedName 'UncheckedSize)))

-- | Hint appended to compile-time errors about dimensions that cannot be selected.
type CheckSpellingMessage :: Symbol
type CheckSpellingMessage = "Check the spelling of named dimensions, and make sure the number of dimensions is correct."

-- | Unwraps the result of 'CatListImplF'. A 'Nothing result becomes a
-- compile-time error naming the selected dimension and the tensor's shape.
type CatListCheckF :: SelectDim (By Symbol Nat) -> Type -> Maybe Type -> Type
type family CatListCheckF selectDim tensor result where
  CatListCheckF _ _ ('Just ok) = ok
  CatListCheckF sel (Tensor _ _ _ _ shp) 'Nothing =
    TypeError
      ( "Cannot concatenate the dimension"
          % ""
          % "    " <> sel
          % ""
          % "for tensors of shape"
          % ""
          % "    " <> shp <> "."
          % ""
          % CheckSpellingMessage
      )

-- | Type of the tensor obtained by concatenating a list of @tensor@s along the
-- selected dimension. It is a compile-time error if the dimension cannot be selected.
type CatListF :: SelectDim (By Symbol Nat) -> Type -> Type
type CatListF sel t = CatListCheckF sel t (CatListImplF sel t)

-- | Concatenation of a homogeneous list of tensors.
--
-- A list's length is not tracked at the type level. With a checked selector
-- and shape, the concatenated dimension therefore becomes
-- @'Dim 'UncheckedName 'UncheckedSize@ (see 'CatListImplF').
instance
  ATen.Castable (CatListF selectDim (Tensor gradient layout device dataType shape)) (ForeignPtr ATen.Tensor) =>
  HasCat selectDim Type [] (Tensor gradient layout device dataType shape)
  where
  type CatF selectDim (Tensor gradient layout device dataType shape) [] = CatListF selectDim (Tensor gradient layout device dataType shape)
  sCat selectDim tensors = do
    -- Demote the singleton selector to a runtime name or index.
    let by :: By String Integer
        by = forgetIsChecked . fromSing $ selectDim
    unsafeThrowableIO $ case by of
      ByName name -> ATen.cast2 ATen.cat_ln tensors name
      ByIndex index -> ATen.cast2 ATen.cat_ll tensors (fromInteger index :: Int)

-- | Folds over a type-level list of tensor types and accumulates the gradient,
-- layout, device, data type and shape of their concatenation.
--
-- The accumulator starts as 'Nothing (see 'CatHListF'). After the first
-- tensor it becomes @'Just '(gradient, layout, device, dataType, shape)@.
-- Each further tensor is merged into it: gradients with '<|>', the other
-- components with '<+>'. The selected dimension is replaced by
-- @AddDimF@ of the two selected dimensions (presumably the sum of their sizes;
-- confirm in "Torch.GraduallyTyped.Shape.Class").
type CatHListImplF ::
  SelectDim (By Symbol Nat) ->
  [Type] ->
  Maybe (Gradient RequiresGradient, Layout LayoutType, Device (DeviceType Nat), DataType DType, Shape [Dim (Name Symbol) (Size Nat)]) ->
  Type
type family CatHListImplF selectDim tensors acc where
  -- Nothing was ever accumulated: the list was empty.
  CatHListImplF _ '[] 'Nothing = TypeError (ToErrorMessage "Cannot concatenate an empty list of tensors.")
  -- All elements processed: build the result tensor type from the accumulator.
  CatHListImplF _ '[] ('Just '(gradient, layout, device, dataType, shape)) = Tensor gradient layout device dataType shape
  -- The first tensor initialises the accumulator.
  CatHListImplF selectDim (Tensor gradient layout device dataType shape ': tensors) 'Nothing =
    CatHListImplF selectDim tensors ('Just '(gradient, layout, device, dataType, shape))
  -- Every later tensor is merged into the accumulator. Before unifying the shapes,
  -- the selected dimension of the accumulated shape' is overwritten with the one
  -- of shape, so the two shapes only need to agree outside that dimension.
  -- The selected dimension is then replaced by AddDimF of both.
  CatHListImplF selectDim (Tensor gradient layout device dataType shape ': tensors) ('Just '(gradient', layout', device', dataType', shape')) =
    CatHListImplF
      selectDim
      tensors
      ( 'Just
          '( gradient <|> gradient',
             layout <+> layout',
             device <+> device',
             dataType <+> dataType',
             ReplaceDimF
               selectDim
               (shape <+> ReplaceDimF selectDim shape' (GetDimF selectDim shape))
               (AddDimF (GetDimF selectDim shape) (GetDimF selectDim shape'))
           )
      )
  -- Any element that did not match a 'Tensor' equation above is rejected.
  CatHListImplF _ (x ': _) _ =
    TypeError
      ( "Cannot concatenate because"
          % ""
          % "    '" <> x <> "'"
          % ""
          % "is not a tensor type."
      )

-- | Type of the tensor obtained by concatenating the heterogeneous list of
-- tensor types @ts@. The fold in 'CatHListImplF' starts with an empty accumulator.
type CatHListF :: SelectDim (By Symbol Nat) -> [Type] -> Type
type CatHListF sel ts = CatHListImplF sel ts 'Nothing

-- | Concatenation of a heterogeneous list of tensors.
--
-- The result type is computed from all element types by 'CatHListF'.
instance
  ( ATen.Castable (CatHListF selectDim tensors) (ForeignPtr ATen.Tensor),
    ATen.Castable (HList tensors) (ForeignPtr ATen.TensorList)
  ) =>
  HasCat selectDim [Type] HList tensors
  where
  type CatF selectDim tensors HList = CatHListF selectDim tensors
  sCat selectDim tensors = do
    -- Demote the singleton selector to a runtime name or index.
    let by :: By String Integer
        by = forgetIsChecked . fromSing $ selectDim
    unsafeThrowableIO $ case by of
      ByName name -> ATen.cast2 ATen.cat_ln tensors name
      ByIndex index -> ATen.cast2 ATen.cat_ll tensors (fromInteger index :: Int)

-- | Compile-time error used by 'ReshapeImplF' when both element counts are known
-- but differ. Parameters are the element count of the original shape (@numel@),
-- that of the requested shape (@numel'@), then the original and requested shapes.
type ReshapeNumelMismatchMessage ::
  Nat ->
  Nat ->
  Shape [Dim (Name Symbol) (Size Nat)] ->
  Shape [Dim (Name Symbol) (Size Nat)] ->
  ErrorMessage

type ReshapeNumelMismatchMessage numel numel' shape shape' =
  "Cannot reshape the tensor. The original shape,"
    % ""
    % "    '" <> shape <> "',"
    % ""
    % "and the new shape,"
    % ""
    % "    '" <> shape' <> "',"
    % ""
    % "have different total numbers of elements,"
    % ""
    % "    '" <> numel <> "' versus '" <> numel' <> "',"
    % ""
    % "respectively."

-- | Decides the result shape of a reshape. It is driven by the element counts of
-- the original shape (@numel@) and of the requested shape (@numel'@).
--
-- Equations are tried in order:
--
--   * both counts known and equal: the requested shape is used as is;
--   * both counts known and different: compile-time error;
--   * either count unknown ('Nothing), or either shape unchecked: @'UncheckedShape@.
type ReshapeImplF ::
  Maybe Nat ->
  Maybe Nat ->
  Shape [Dim (Name Symbol) (Size Nat)] ->
  Shape [Dim (Name Symbol) (Size Nat)] ->
  Shape [Dim (Name Symbol) (Size Nat)]
type family ReshapeImplF numel numel' shape shape' where
  ReshapeImplF ('Just numel) ('Just numel) _ shape' = shape'
  ReshapeImplF ('Just numel) ('Just numel') shape shape' = TypeError (ReshapeNumelMismatchMessage numel numel' shape shape')
  ReshapeImplF 'Nothing _ _ _ = 'UncheckedShape
  ReshapeImplF _ 'Nothing _ _ = 'UncheckedShape
  -- NOTE(review): the next two equations look unreachable in practice. Either an
  -- earlier equation matches, or a non-reduced count blocks reduction at the first
  -- equation. Confirm before relying on them.
  ReshapeImplF _ _ 'UncheckedShape _ = 'UncheckedShape
  ReshapeImplF _ _ _ 'UncheckedShape = 'UncheckedShape

-- | Result shape of reshaping @shape@ into @shape'@. 'ReshapeImplF' compares the
-- element counts ('NumelF') of both shapes and gives @shape'@ when they agree.
-- It is a compile-time error when both are known and differ, and @'UncheckedShape@
-- when a count is unknown.
type ReshapeF ::
  Shape [Dim (Name Symbol) (Size Nat)] ->
  Shape [Dim (Name Symbol) (Size Nat)] ->
  Shape [Dim (Name Symbol) (Size Nat)]
type family ReshapeF shape shape' where
  ReshapeF from to = ReshapeImplF (NumelF from) (NumelF to) from to

-- | Returns a tensor with the same data and number of elements as the input tensor,
-- but with the specified shape:
--
-- >>> g <- sMkGenerator (SDevice SCPU) 0
-- >>> (input, _) <- sRandn (TensorSpec (SGradient SWithGradient) (SLayout SDense) (SDevice SCPU) (SDataType SFloat) (SShape $ SName @"*" :&: SSize @4 :|: SNil)) g
-- >>> output <- sReshape (SShape $ SName @"*" :&: SSize @2 :|: SName @"*" :&: SSize @2 :|: SNil) input
-- >>> :type output
-- output
--   :: Tensor
--        ('Gradient 'WithGradient)
--        ('Layout 'Dense)
--        ('Device 'CPU)
--        ('DataType 'Float)
--        ('Shape '[ 'Dim ('Name "*") ('Size 2), 'Dim ('Name "*") ('Size 2)])
--
-- At the value level, a single dimension may be '-1',
-- in which case it is inferred from the remaining dimensions and the number of elements in the input:
--
-- >>> output' <- sReshape (SShape $ SUncheckedName "*" :&: SUncheckedSize (-1) :|: SNil) output
-- >>> :type output'
-- output'
--   :: Tensor
--        ('Gradient 'WithGradient)
--        ('Layout 'Dense)
--        ('Device 'CPU)
--        ('DataType 'Float)
--        'UncheckedShape
-- >>> getDims output'
-- [Dim {dimName = "*", dimSize = 4}]
--
-- 'sSetShape' is an alias of 'sReshape'.
sReshape,
  sSetShape ::
    forall m shape' gradient layout device dataType shape shape''.
    MonadThrow m =>
    (shape'' ~ ReshapeF shape shape', Catch shape'') =>
    SShape shape' ->
    Tensor gradient layout device dataType shape ->
    m (Tensor gradient layout device dataType shape'')
sReshape shape' input = unsafeThrowableIO $ do
  -- Runtime list of the requested dimensions (names and sizes).
  let dims = forgetIsChecked . fromSing $ shape'
  -- Reshape using the sizes only ...
  t :: ForeignPtr ATen.Tensor <- ATen.cast2 ATen.reshape_tl input (forgetIsChecked . dimSize <$> dims)
  -- ... then attach the requested dimension names to the reshaped tensor.
  ATen.cast2 ATen.tensor_refine_names_N t (forgetIsChecked . dimName <$> dims)
sSetShape = sReshape

-- | Whether every dimension size of @shape@ is known at the type level.
--
-- This is @'False@ for an unchecked shape, and as soon as one dimension has an
-- unchecked size. Dimension names are not inspected. Without the
-- @'UncheckedSize@ equation, the family would get stuck on shapes that mix checked
-- and unchecked sizes, and constraints built on it (see 'reshape') could never be
-- discharged.
type AllDimSizesChecked :: Shape [Dim (Name Symbol) (Size Nat)] -> Bool
type family AllDimSizesChecked shape where
  AllDimSizesChecked 'UncheckedShape = 'False
  AllDimSizesChecked ('Shape '[]) = 'True
  AllDimSizesChecked ('Shape ('Dim _ ('Size _) ': xs)) = AllDimSizesChecked ('Shape xs)
  AllDimSizesChecked ('Shape ('Dim _ 'UncheckedSize ': _)) = 'False

-- | Like 'sReshape', but the target shape @shape'@ is taken from the type level
-- and supplied with a type application.
--
-- When every dimension size of the input shape is known ('AllDimSizesChecked'),
-- the 'When' constraint additionally imposes @shape' ~ shape''@.
reshape ::
  forall m shape' gradient layout device dataType shape shape''.
  ( shape'' ~ ReshapeF shape shape',
    Catch shape'',
    When (AllDimSizesChecked shape) (shape' ~ shape''),
    SingI shape',
    MonadThrow m
  ) =>
  Tensor gradient layout device dataType shape ->
  m (Tensor gradient layout device dataType shape'')
reshape = sReshape (sing @shape')

-- | Compile-time error used by 'TransposeF' and 'TransposeIndexIndexDimsF' when
-- the source dimension (the first selector, @by0@) cannot be found in @dims@.
type TransposeBy0Message :: By Symbol Nat -> [Dim (Name Symbol) (Size Nat)] -> ErrorMessage

type TransposeBy0Message by0 dims =
  "Cannot transpose the tensor with the dimensions"
    % ""
    % "    '" <> dims <> "'"
    % ""
    % "because the specified source dimension"
    % ""
    % "    '" <> by0 <> "'"
    % ""
    % "could not be found."

-- | Compile-time error used by 'TransposeF' and 'TransposeIndexIndexDimsF' when
-- the target dimension (the second selector, @by1@) cannot be found in @dims@.
type TransposeBy1Message :: By Symbol Nat -> [Dim (Name Symbol) (Size Nat)] -> ErrorMessage

type TransposeBy1Message by1 dims =
  "Cannot transpose the tensor with the dimensions"
    % ""
    % "    '" <> dims <> "'"
    % ""
    % "because the specified target dimension"
    % ""
    % "    '" <> by1 <> "'"
    % ""
    % "could not be found."

-- | Compute transposed shapes: the dimensions selected by @selectDim0@ and
-- @selectDim1@ swap places.
--
-- Named selectors are first resolved to indices with 'GetIndexByNameF'. Both
-- selectors must use the same kind of selection (both by name or both by index);
-- mixed selectors on a checked shape are a compile-time error. An unchecked shape
-- or an unchecked selector yields @'UncheckedShape@, and that equation is tried
-- first. So mixed selectors on an unchecked shape are only detected at runtime,
-- by 'sTranspose'.
--
-- >>> type SelectBatch = 'SelectDim ('ByName "batch" :: By Symbol Nat)
-- >>> type SelectFeature = 'SelectDim ('ByName "feature" :: By Symbol Nat)
-- >>> type Dims = '[ 'Dim ('Name "batch") ('Size 10), 'Dim ('Name "feature") ('Size 8), 'Dim ('Name "anotherFeature") ('Size 12)]
-- >>> :kind! TransposeF SelectBatch SelectFeature ('Shape Dims)
-- TransposeF SelectBatch SelectFeature ('Shape Dims) :: Shape
--                                                         [Dim (Name Symbol) (Size Natural)]
-- = 'Shape
--     '[ 'Dim ('Name "feature") ('Size 8),
--        'Dim ('Name "batch") ('Size 10),
--        'Dim ('Name "anotherFeature") ('Size 12)]
-- >>> :kind! TransposeF SelectFeature SelectBatch ('Shape Dims)
-- TransposeF SelectFeature SelectBatch ('Shape Dims) :: Shape
--                                                         [Dim (Name Symbol) (Size Natural)]
-- = 'Shape
--     '[ 'Dim ('Name "feature") ('Size 8),
--        'Dim ('Name "batch") ('Size 10),
--        'Dim ('Name "anotherFeature") ('Size 12)]
type TransposeF ::
  SelectDim (By Symbol Nat) ->
  SelectDim (By Symbol Nat) ->
  Shape [Dim (Name Symbol) (Size Nat)] ->
  Shape [Dim (Name Symbol) (Size Nat)]
type family TransposeF selectDim0 selectDim1 shape where
  -- Anything unchecked makes the result unchecked (equation order matters here).
  TransposeF _ _ 'UncheckedShape = 'UncheckedShape
  TransposeF _ 'UncheckedSelectDim _ = 'UncheckedShape
  TransposeF 'UncheckedSelectDim _ _ = 'UncheckedShape
  -- Both by name: look up both indices, reporting which name is missing.
  TransposeF ('SelectDim ('ByName name0)) ('SelectDim ('ByName name1)) ('Shape dims) =
    'Shape
      ( TransposeIndexIndexDimsF
          ( FromMaybe
              (TypeError (TransposeBy0Message ('ByName name0) dims))
              (GetIndexByNameF name0 dims)
          )
          ( FromMaybe
              (TypeError (TransposeBy1Message ('ByName name1) dims))
              (GetIndexByNameF name1 dims)
          )
          dims
      )
  -- Both by index: swap directly.
  TransposeF ('SelectDim ('ByIndex index0)) ('SelectDim ('ByIndex index1)) ('Shape dims) = 'Shape (TransposeIndexIndexDimsF index0 index1 dims)
  -- Remaining checked case: one selector by name, the other by index.
  TransposeF ('SelectDim by0) ('SelectDim by1) _ =
    TypeError
      ( "Cannot transpose the tensor. "
          % ""
          % "The source and target dimensions must be selected either both by name or both by index, "
          % "but mixed selectors were found: "
          % ""
          % "    '" <> 'SelectDim by0 <> "' and '" <> 'SelectDim by1 <> "'."
          % ""
      )

-- | Swaps the dimensions at positions @index0@ and @index1@ of @dims@.
--
-- Reading inside out:
--
--   1. the dimension at @index1@ is written into position @index0@;
--   2. the original dimension at @index0@ is then written into position @index1@.
--
-- Each lookup ('GetDimImplF') or replacement ('ReplaceDimImplF') that returns
-- 'Nothing becomes 'TransposeBy0Message' or 'TransposeBy1Message', depending on
-- which index was involved.
type TransposeIndexIndexDimsF :: Nat -> Nat -> [Dim (Name Symbol) (Size Nat)] -> [Dim (Name Symbol) (Size Nat)]
type family TransposeIndexIndexDimsF index0 index1 dims where
  TransposeIndexIndexDimsF index0 index1 dims =
    FromMaybe
      (TypeError (TransposeBy1Message ('ByIndex index1) dims))
      ( ReplaceDimImplF
          ('ByIndex index1)
          ( FromMaybe
              (TypeError (TransposeBy0Message ('ByIndex index0) dims))
              ( ReplaceDimImplF
                  ('ByIndex index0)
                  dims
                  ( FromMaybe
                      (TypeError (TransposeBy1Message ('ByIndex index1) dims))
                      (GetDimImplF ('ByIndex index1) dims)
                  )
              )
          )
          ( FromMaybe
              (TypeError (TransposeBy0Message ('ByIndex index0) dims))
              (GetDimImplF ('ByIndex index0) dims)
          )
      )

-- | Returns a tensor that is a transposed version of @input@.
-- The selected dimensions @selectDim0@ and @selectDim1@ are swapped.
--
-- Both dimensions must be selected the same way (both by name or both by index).
-- When this cannot be ruled out at the type level, for example with an unchecked
-- selector or shape, mixed selectors raise a 'TransposeMixedSelectorsError' at runtime.
--
-- >>> g <- sMkGenerator (SDevice SCPU) 0
-- >>> (input, _) <- sRandn (TensorSpec (SGradient SWithGradient) (SLayout SDense) (SDevice SCPU) (SDataType SFloat) (SShape $ SName @"batch" :&: SSize @10 :|: SName @"feature" :&: SSize @5 :|: SNil)) g
-- >>> output <- sTranspose (SSelectDim (SByName @"batch")) (SSelectDim (SByName @"feature")) input
-- >>> :type output
-- output
--   :: Tensor
--        ('Gradient 'WithGradient)
--        ('Layout 'Dense)
--        ('Device 'CPU)
--        ('DataType 'Float)
--        ('Shape
--           '[ 'Dim ('Name "feature") ('Size 5),
--              'Dim ('Name "batch") ('Size 10)])
-- >>> output <- sTranspose (SUncheckedSelectDim (ByIndex 0)) (SSelectDim (SByIndex @1)) input
-- >>> :type output
-- output
--   :: Tensor
--        ('Gradient 'WithGradient)
--        ('Layout 'Dense)
--        ('Device 'CPU)
--        ('DataType 'Float)
--        'UncheckedShape
-- >>> getDims output
-- [Dim {dimName = "feature", dimSize = 5},Dim {dimName = "batch", dimSize = 10}]
sTranspose ::
  forall selectDim0 selectDim1 gradient layout device dataType shape shape' m.
  ( shape' ~ TransposeF selectDim0 selectDim1 shape,
    Catch shape',
    MonadThrow m
  ) =>
  SSelectDim selectDim0 ->
  SSelectDim selectDim1 ->
  Tensor gradient layout device dataType shape ->
  m (Tensor gradient layout device dataType shape')
sTranspose selectDim0 selectDim1 input = do
  -- Demote both singleton selectors to runtime names or indices.
  let by0, by1 :: By String Integer
      by0 = forgetIsChecked . fromSing $ selectDim0
      by1 = forgetIsChecked . fromSing $ selectDim1
  unsafeThrowableIO $ case (by0, by1) of
    (ByName name0, ByName name1) -> ATen.cast3 ATen.transpose_tnn input name0 name1
    (ByIndex index0, ByIndex index1) -> ATen.cast3 ATen.transpose_tll input (fromInteger index0 :: Int) (fromInteger index1 :: Int)
    -- 'TransposeF' only rejects mixed selectors when the shape and both selectors
    -- are checked; every other mixed combination ends up here.
    _ -> throwM $ TransposeMixedSelectorsError by0 by1

-- | Runtime error raised by 'sTranspose' when one dimension is selected by name
-- and the other by index.
data TransposeError = TransposeMixedSelectorsError
  { -- | Selector of the source (first) dimension.
    teBy0 :: By String Integer,
    -- | Selector of the target (second) dimension.
    teBy1 :: By String Integer
  }
  deriving stock (Show, Typeable)

-- | The rendered message mirrors the compile-time error that 'TransposeF' reports
-- for mixed selectors.
instance Exception TransposeError where
  displayException TransposeMixedSelectorsError {..} =
    "Cannot transpose the tensor. "
      <> "The source and target dimensions must be selected either both by name or both by index, "
      <> "but mixed selectors were found: '"
      <> show teBy0
      <> "' and '"
      <> show teBy1
      <> "'."

-- | Like 'sTranspose', but both dimensions are taken from the type level and
-- supplied with type applications, e.g.
-- @transpose \@('SelectDim ('ByName "batch")) \@('SelectDim ('ByName "feature"))@.
transpose ::
  forall selectDim0 selectDim1 gradient layout device dataType shape shape' m.
  ( shape' ~ TransposeF selectDim0 selectDim1 shape,
    Catch shape',
    SingI selectDim0,
    SingI selectDim1,
    MonadThrow m
  ) =>
  Tensor gradient layout device dataType shape ->
  m (Tensor gradient layout device dataType shape')
transpose = sTranspose (sing @selectDim0) (sing @selectDim1)

-- | Compile-time error used by 'UnsqueezeF' and 'UnsqueezeIndexDimsF' when the
-- dimension selected for unsqueezing (@by@) cannot be used with @dims@.
-- Unsqueeze selects a single dimension, so the message does not call it a
-- "source" dimension (that wording belongs to transpose).
type UnsqueezeByMessage :: By Symbol Nat -> [Dim (Name Symbol) (Size Nat)] -> ErrorMessage
type UnsqueezeByMessage by dims =
  "Cannot unsqueeze the tensor with the dimensions"
    % ""
    % "    '" <> dims <> "'"
    % ""
    % "because the specified dimension"
    % ""
    % "    '" <> by <> "'"
    % ""
    % "could not be found."

-- | Shape after unsqueezing at the position given by @selectDim@; the insertion
-- itself is done by 'UnsqueezeIndexDimsF'.
--
-- A named selector is resolved to the current index of that dimension with
-- 'GetIndexByNameF'. An unchecked shape or selector yields @'UncheckedShape@.
type UnsqueezeF ::
  SelectDim (By Symbol Nat) ->
  Shape [Dim (Name Symbol) (Size Nat)] ->
  Shape [Dim (Name Symbol) (Size Nat)]
type family UnsqueezeF selectDim shape where
  UnsqueezeF _ 'UncheckedShape = 'UncheckedShape
  UnsqueezeF 'UncheckedSelectDim _ = 'UncheckedShape
  UnsqueezeF ('SelectDim ('ByIndex idx)) ('Shape ds) =
    'Shape (UnsqueezeIndexDimsF idx ds)
  UnsqueezeF ('SelectDim ('ByName nm)) ('Shape ds) =
    'Shape (UnsqueezeIndexDimsF (FromMaybe (TypeError (UnsqueezeByMessage ('ByName nm) ds)) (GetIndexByNameF nm ds)) ds)

-- | Inserts an anonymous dimension of size 1 at @position@ in @inputDims@.
--
-- If 'InsertDimImplF' cannot insert at that position (it returns 'Nothing'),
-- this reduces to the 'UnsqueezeByMessage' type error.
type UnsqueezeIndexDimsF :: Nat -> [Dim (Name Symbol) (Size Nat)] -> [Dim (Name Symbol) (Size Nat)]
type family UnsqueezeIndexDimsF position inputDims where
  UnsqueezeIndexDimsF position inputDims =
    FromMaybe
      (TypeError (UnsqueezeByMessage ('ByIndex position) inputDims))
      (InsertDimImplF ('ByIndex position) inputDims ('Dim ('Name "*") ('Size 1)))

-- | Unsqueezes a tensor with the specified dimension.
--
-- A dimension of size 1 is inserted at the selected position; the
-- type-level result is computed by 'UnsqueezeF'.
--
-- At runtime only selection by index is supported: the underlying
-- @ATen.unsqueeze_tl@ binding takes an integer position. This function has
-- no access to the input's runtime shape, so a dimension selected by name
-- cannot be turned into a position. In that case an error is thrown in @m@.
-- The previous behaviour was to crash with 'undefined'.
sUnsqueeze ::
  forall selectDim gradient layout device dataType shape shape' m.
  ( shape' ~ UnsqueezeF selectDim shape,
    Catch shape',
    MonadThrow m
  ) =>
  SSelectDim selectDim ->
  Tensor gradient layout device dataType shape ->
  m (Tensor gradient layout device dataType shape')
sUnsqueeze selectDim input =
  let by :: By String Integer
      by = forgetIsChecked . fromSing $ selectDim
   in case by of
        ByName name ->
          -- Report the unsupported selection through MonadThrow instead of
          -- an uninformative (and possibly lazily triggered) 'undefined'.
          throwM . userError $
            "sUnsqueeze: selecting the dimension by name ("
              <> show name
              <> ") is not supported; select it by index instead."
        ByIndex index ->
          unsafeThrowableIO $ ATen.cast2 ATen.unsqueeze_tl input (fromIntegral index :: Int)

-- | Unsqueezes a tensor with the specified dimension.
--
-- This is 'sUnsqueeze' with the dimension taken from the type-level
-- @selectDim@ (e.g. supplied via a type application) through its 'SingI'
-- instance, rather than from a singleton argument.
unsqueeze ::
  forall selectDim gradient layout device dataType shape shape' m.
  ( shape' ~ UnsqueezeF selectDim shape,
    Catch shape',
    SingI selectDim,
    MonadThrow m
  ) =>
  Tensor gradient layout device dataType shape ->
  m (Tensor gradient layout device dataType shape')
unsqueeze input = sUnsqueeze (sing @selectDim) input

-- | Output shape of 'squeezeAll'.
--
-- An unchecked input shape stays unchecked. For a checked shape, the
-- dimensions computed by 'SqueezeAllDimsF' are wrapped in 'Shape', with
-- 'MaybeF' falling back to 'UncheckedShape' when that computation yields
-- 'Nothing'.
type SqueezeAllShapeF :: Shape [Dim (Name Symbol) (Size Nat)] -> Shape [Dim (Name Symbol) (Size Nat)]
type family SqueezeAllShapeF inputShape where
  SqueezeAllShapeF 'UncheckedShape = 'UncheckedShape
  SqueezeAllShapeF ('Shape inputDims) = MaybeF 'UncheckedShape 'Shape (SqueezeAllDimsF inputDims)

-- | Drops every dimension whose size is statically @'Size 1@.
--
-- Encountering a dimension with 'UncheckedSize' yields 'Nothing'. Other
-- dimensions are kept in order. The equations overlap, so their order
-- matters.
type SqueezeAllDimsF :: [Dim (Name Symbol) (Size Nat)] -> Maybe [Dim (Name Symbol) (Size Nat)]
type family SqueezeAllDimsF inputDims where
  SqueezeAllDimsF '[] = 'Just '[]
  SqueezeAllDimsF ('Dim _ 'UncheckedSize ': _) = 'Nothing
  SqueezeAllDimsF ('Dim _ ('Size 1) ': rest) = SqueezeAllDimsF rest
  SqueezeAllDimsF (kept ': rest) = PrependMaybe ('Just kept) (SqueezeAllDimsF rest)

-- | Squeezes the input tensor with @ATen.squeeze_t@, i.e. without selecting
-- a particular dimension.
--
-- The output type is given by 'SqueezeAllShapeF', which removes every
-- dimension statically known to have size 1.
squeezeAll ::
  forall gradient layout device dataType shape.
  -- | input
  Tensor gradient layout device dataType shape ->
  -- | output
  Tensor gradient layout device dataType (SqueezeAllShapeF shape)
squeezeAll input = unsafePerformIO squeezed
  where
    squeezed = ATen.cast1 ATen.squeeze_t input

-- | Removes the dimension at position @dimIndex@ from @dims@.
--
-- The removed dimension must unify with @'Dim ('Name "*") ('Size 1)@.
-- Otherwise the result is 'Nothing', as it is when the list runs out before
-- @dimIndex@ is reached.
type SqueezeDimByIndexF :: Nat -> [Dim (Name Symbol) (Size Nat)] -> Maybe [Dim (Name Symbol) (Size Nat)]
type family SqueezeDimByIndexF dimIndex dims where
  SqueezeDimByIndexF 0 (target ': rest) =
    If
      (UnifyCheck (Dim (Name Symbol) (Size Nat)) target ('Dim ('Name "*") ('Size 1)))
      ('Just rest)
      'Nothing
  SqueezeDimByIndexF n (kept ': rest) = PrependMaybe ('Just kept) (SqueezeDimByIndexF (n - 1) rest)
  SqueezeDimByIndexF _ _ = 'Nothing

-- | Squeezes @dims@ at the selected dimension.
--
-- Selection by index goes straight to 'SqueezeDimByIndexF'. Selection by
-- name is first resolved with 'GetIndexByNameF'.
type SqueezeDimImplF :: By Symbol Nat -> [Dim (Name Symbol) (Size Nat)] -> Maybe [Dim (Name Symbol) (Size Nat)]
type family SqueezeDimImplF selector dims where
  SqueezeDimImplF ('ByIndex position) dims = SqueezeDimByIndexF position dims
  SqueezeDimImplF ('ByName name) dims = SqueezeDimByNameF (GetIndexByNameF name dims) dims

-- | Squeezes @dims@ at a name-resolved position.
--
-- 'Nothing' means the name was not found, and it propagates as 'Nothing'.
type SqueezeDimByNameF :: Maybe Nat -> [Dim (Name Symbol) (Size Nat)] -> Maybe [Dim (Name Symbol) (Size Nat)]
type family SqueezeDimByNameF resolvedIndex dims where
  SqueezeDimByNameF ('Just position) dims = SqueezeDimByIndexF position dims
  SqueezeDimByNameF 'Nothing _ = 'Nothing

-- | Custom type error emitted by 'SqueezeDimCheckF' when 'SqueezeDimImplF'
-- cannot squeeze @dims@ at the dimension selected by @by@.
type SqueezeDimMessage :: By Symbol Nat -> [Dim (Name Symbol) (Size Nat)] -> ErrorMessage

type SqueezeDimMessage by dims =
  "Cannot squeeze the tensor with the dimensions"
    % ""
    % "    '" <> dims <> "'"
    % ""
    % "at the dimension"
    % ""
    % "    '" <> by <> "'."
    % ""

-- | Unwraps the result of 'SqueezeDimImplF'.
--
-- A 'Nothing' result is turned into the 'SqueezeDimMessage' type error.
type SqueezeDimCheckF :: By Symbol Nat -> [Dim (Name Symbol) (Size Nat)] -> Maybe [Dim (Name Symbol) (Size Nat)] -> [Dim (Name Symbol) (Size Nat)]
type family SqueezeDimCheckF selector dims result where
  SqueezeDimCheckF _ _ ('Just squeezed) = squeezed
  SqueezeDimCheckF selector dims 'Nothing = TypeError (SqueezeDimMessage selector dims)

-- | Output shape of 'sSqueezeDim': the selected dimension is removed.
--
-- An unchecked selector or an unchecked input shape yields
-- 'UncheckedShape'. A selection that cannot be squeezed is a type error,
-- reported through 'SqueezeDimCheckF'.
--
-- >>> :kind! SqueezeDimF ('SelectDim ('ByIndex 1)) ('Shape '[ 'Dim ('Name "*") ('Size 2), 'Dim ('Name "*") ('Size 1), 'Dim ('Name "*") ('Size 2)])
-- ...
-- = 'Shape '[ 'Dim ('Name "*") ('Size 2), 'Dim ('Name "*") ('Size 2)]
type SqueezeDimF :: SelectDim (By Symbol Nat) -> Shape [Dim (Name Symbol) (Size Nat)] -> Shape [Dim (Name Symbol) (Size Nat)]
type family SqueezeDimF selector inputShape where
  SqueezeDimF 'UncheckedSelectDim _ = 'UncheckedShape
  SqueezeDimF _ 'UncheckedShape = 'UncheckedShape
  SqueezeDimF ('SelectDim by) ('Shape inputDims) =
    'Shape (SqueezeDimCheckF by inputDims (SqueezeDimImplF by inputDims))

-- | Squeeze a particular dimension.
--
-- A dimension selected by name is squeezed with @ATen.squeeze_tn@, one
-- selected by index with @ATen.squeeze_tl@. Errors raised by the native call
-- are rethrown in @m@ via 'unsafeThrowableIO'. The output type comes from
-- 'SqueezeDimF'.
--
-- >>> t <- sOnes $ TensorSpec (SGradient SWithoutGradient) (SLayout SDense) (SDevice SCPU) (SDataType SFloat) (SShape $ SNoName :&: SSize @2 :|: SNoName :&: SSize @1 :|: SNoName :&: SSize @2 :|: SNoName :&: SSize @1 :|: SNoName :&: SSize @2 :|: SNil)
-- >>> result <- sSqueezeDim (SSelectDim $ SByIndex @1) t
-- >>> :t result
-- result
--   :: Tensor
--        ('Gradient 'WithoutGradient)
--        ('Layout 'Dense)
--        ('Device 'CPU)
--        ('DataType 'Float)
--        ('Shape
--           '[ 'Dim ('Name "*") ('Size 2), 'Dim ('Name "*") ('Size 2),
--              'Dim ('Name "*") ('Size 1), 'Dim ('Name "*") ('Size 2)])
-- >>> result
-- Tensor Float [2,2,1,2] [[[[ 1.0000   ,  1.0000   ]],
--                          [[ 1.0000   ,  1.0000   ]]],
--                         [[[ 1.0000   ,  1.0000   ]],
--                          [[ 1.0000   ,  1.0000   ]]]]
sSqueezeDim ::
  forall selectDim gradient layout device dataType shape shape' m.
  (MonadThrow m, shape' ~ SqueezeDimF selectDim shape, Catch shape') =>
  SSelectDim selectDim ->
  Tensor gradient layout device dataType shape ->
  m (Tensor gradient layout device dataType shape')
sSqueezeDim selectDim input =
  unsafeThrowableIO $ case demotedSelector of
    ByName dimName -> ATen.cast2 ATen.squeeze_tn input dimName
    ByIndex dimIndex -> ATen.cast2 ATen.squeeze_tl input (fromIntegral dimIndex :: Int)
  where
    demotedSelector :: By String Integer
    demotedSelector = forgetIsChecked (fromSing selectDim)

-- | Expands a tensor to the specified shape.
--
-- The output shape is 'BroadcastShapesF' of the input shape and the
-- requested shape. The requested sizes are demoted from the singleton and
-- passed, together with the input, to @ATen.tensor_expand_lb@.
sExpand ::
  forall shape' shape'' gradient layout device dataType shape.
  (shape'' ~ BroadcastShapesF shape shape', Catch shape'') =>
  -- | new shape
  SShape shape' ->
  -- | input tensor
  Tensor gradient layout device dataType shape ->
  -- | output tensor
  Tensor gradient layout device dataType shape''
sExpand newShape input =
  unsafePerformIO $
    -- NOTE(review): the trailing 'True' is marshalled to a 'CBool';
    -- presumably ATen's @implicit@ flag of @expand@ — confirm.
    ATen.cast3 ATen.tensor_expand_lb input requestedSizes True
  where
    requestedSizes :: [Integer]
    requestedSizes =
      [forgetIsChecked size | Dim _ size <- forgetIsChecked (fromSing newShape)]

-- | Expands a tensor to the specified shape.
--
-- This is 'sExpand' with the target shape taken from the type-level
-- @shape'@ through its 'SingI' instance, rather than from a singleton
-- argument.
expand ::
  forall shape' shape'' gradient layout device dataType shape.
  (SingI shape', shape'' ~ BroadcastShapesF shape shape', Catch shape'') =>
  Tensor gradient layout device dataType shape ->
  Tensor gradient layout device dataType shape''
expand input = sExpand (sing @shape') input

-- | Selects the slice at the given index along the selected dimension.
--
-- The result is a view of the input with that dimension removed. The index
-- is checked against the runtime size of the selected dimension first. If
-- the index is not smaller than that size, an 'IndexOutOfBoundError' is
-- thrown in @m@. Otherwise @ATen.tensor_select_nl@ (dimension by name) or
-- @ATen.tensor_select_ll@ (dimension by index) is called.
--
-- >>> nats <- sArangeNaturals (SGradient SWithoutGradient) (SLayout SDense) (SDevice SCPU) (SDataType SInt32) (SSize @8)
-- >>> input <- sReshape (SShape $ SName @"*" :&: SSize @4 :|: SName @"*" :&: SSize @2 :|: SNil) nats
-- >>> input
-- Tensor Int32 [4,2] [[ 0,  1],
--                     [ 2,  3],
--                     [ 4,  5],
--                     [ 6,  7]]
--
-- 'index' can be provided at compile-time:
--
-- >>> sSelect (SSelectDim (SByIndex @0)) (SIndex @1) input
-- Tensor Int32 [2] [ 2,  3]
--
-- 'index' can also be provided at runtime:
--
-- >>> sSelect (SSelectDim (SByIndex @0)) (SUncheckedIndex 1) input
-- Tensor Int32 [2] [ 2,  3]
--
-- It produces a runtime error if the 'index' is too large:
--
-- >>> sSelect (SSelectDim (SByIndex @0)) (SUncheckedIndex 10) input
-- *** Exception: IndexOutOfBoundError {ioobeIndex = 10, ioobeDim = Dim {dimName = "*", dimSize = 4}}
sSelect ::
  forall selectDim index gradient layout device dataType shapeIn shapeOut m.
  ( index `InRangeF` GetDimF selectDim shapeIn,
    shapeOut ~ RemoveDimF selectDim shapeIn,
    Catch shapeOut,
    SGetShape shapeIn,
    MonadThrow m
  ) =>
  SSelectDim selectDim ->
  SIndex index ->
  Tensor gradient layout device dataType shapeIn ->
  m (Tensor gradient layout device dataType shapeOut)
sSelect sSelectDim sIndex input = do
  -- Look up the selected dimension in the input's shape so that the index
  -- can be bounds-checked before calling into ATen.
  sDim <- sGetDimFromShape sSelectDim (sGetShape input)
  let selectedDim :: Dim String Integer
      selectedDim = bimap forgetIsChecked forgetIsChecked (fromSing sDim)
      position :: Integer
      position = coerce (forgetIsChecked (fromSing sIndex))
      selector :: By String Integer
      selector = forgetIsChecked (fromSing sSelectDim)
      runSelect = case selector of
        ByName dimName ->
          ATen.cast3 ATen.tensor_select_nl input dimName (fromIntegral position :: Int)
        ByIndex dimIndex ->
          ATen.cast3 ATen.tensor_select_ll input (fromIntegral dimIndex :: Int) (fromIntegral position :: Int)
  if position >= dimSize selectedDim
    then throwM $ IndexOutOfBoundError position selectedDim
    else unsafeThrowableIO runSelect

-- | Thrown by 'sSelect' when the requested index is not smaller than the
-- size of the selected dimension.
data IndexOutOfBoundError = IndexOutOfBoundError
  { -- | The offending index.
    ioobeIndex :: Integer,
    -- | The demoted dimension that was indexed into.
    ioobeDim :: Dim String Integer
  }
  deriving stock (Show, Typeable)

-- | Uses the default exception conversions; only the human-readable
-- rendering is customised.
instance Exception IndexOutOfBoundError where
  displayException (IndexOutOfBoundError offendingIndex offendingDim) =
    concat
      [ "Index `",
        show offendingIndex,
        "` is out of bounds for dimension `",
        show offendingDim,
        "`."
      ]

-- | Like 'sSelect', but both the dimension and the index come from the
-- type-level @selectDim@ and @index@ through their 'SingI' instances.
select ::
  forall selectDim index gradient layout device dataType shapeIn shapeOut m.
  ( SingI selectDim,
    SingI index,
    index `InRangeF` GetDimF selectDim shapeIn,
    shapeOut ~ RemoveDimF selectDim shapeIn,
    Catch shapeOut,
    SGetShape shapeIn,
    MonadThrow m
  ) =>
  Tensor gradient layout device dataType shapeIn ->
  m (Tensor gradient layout device dataType shapeOut)
select input = sSelect (sing @selectDim) (sing @index) input

-- | Type-level worker behind 'GatherDimF' and hence 'sGatherDim'.
--
-- The first list holds the index dimensions and the second the input
-- dimensions. The result is the output dimensions, or 'Nothing' when the
-- two lists are incompatible for a gather along the selected dimension. A
-- dimension selected by name is resolved in both lists with
-- 'GetIndexByNameF' (see 'GatherDimByNameF').
--
-- >>> :kind! GatherDimImplF ('ByIndex 1) '[ 'Dim ('Name "batch") ('Size 2), 'Dim ('Name "sequence") ('Size 4), 'Dim ('Name "feature") ('Size 1)] '[ 'Dim ('Name "batch") ('Size 2), 'Dim ('Name "sequence") ('Size 1), 'Dim ('Name "feature") ('Size 1)]
-- ...
-- = 'Just
--     '[ 'Dim ('Name "batch") ('Size 2),
--        'Dim ('Name "sequence") ('Size 4),
--        'Dim ('Name "feature") ('Size 1)]
-- >>> :kind! GatherDimImplF ('ByIndex 1) '[ 'Dim ('Name "batch") ('Size 2), 'Dim ('Name "sequence") ('Size 4), 'Dim ('Name "feature") ('Size 1)] '[ 'Dim ('Name "*") ('Size 2), 'Dim ('Name "sequence") ('Size 1), 'Dim ('Name "*") ('Size 1)]
-- ...
-- = 'Just
--     '[ 'Dim ('Name "batch") ('Size 2),
--        'Dim ('Name "sequence") ('Size 4),
--        'Dim ('Name "feature") ('Size 1)]
-- >>> :kind! GatherDimImplF ('ByIndex 1) '[ 'Dim ('Name "batch") ('Size 2), 'Dim ('Name "sequence") ('Size 4), 'Dim ('Name "feature") ('Size 2)] '[ 'Dim ('Name "batch") ('Size 2), 'Dim ('Name "sequence") ('Size 1), 'Dim ('Name "feature") ('Size 1)]
-- ...
-- = 'Nothing
-- >>> :kind! GatherDimImplF ('ByIndex 1) '[ 'Dim ('Name "batch") ('Size 2), 'Dim ('Name "sequence") ('Size 4), 'Dim ('Name "feature") ('Size 1)] '[ 'Dim ('Name "batch") ('Size 2), 'Dim ('Name "boo") ('Size 1), 'Dim ('Name "feature") ('Size 1)]
-- ...
-- = 'Nothing
-- >>> :kind! GatherDimImplF ('ByIndex 1) '[ 'Dim ('Name "batch") ('Size 2), 'Dim ('Name "sequence") ('Size 0), 'Dim ('Name "feature") ('Size 1)] '[ 'Dim ('Name "batch") ('Size 2), 'Dim ('Name "sequence") ('Size 2), 'Dim ('Name "feature") ('Size 1)]
-- ...
-- = 'Nothing
-- >>> :kind! GatherDimImplF ('ByIndex 1) '[ 'Dim ('Name "batch") ('Size 2), 'Dim ('Name "sequence") ('Size 1)] '[ 'Dim ('Name "batch") ('Size 2), 'Dim ('Name "sequence") ('Size 1), 'Dim ('Name "feature") ('Size 1)]
-- ...
-- = 'Nothing
-- >>> :kind! GatherDimImplF ('ByIndex 2) '[ 'Dim ('Name "batch") ('Size 2), 'Dim ('Name "sequence") ('Size 1), 'Dim ('Name "feature") ('Size 3)] '[ 'Dim ('Name "batch") ('Size 2), 'Dim ('Name "sequence") ('Size 1), 'Dim ('Name "feature") ('Size 1)]
-- ...
-- = 'Just
--     '[ 'Dim ('Name "batch") ('Size 2),
--        'Dim ('Name "sequence") ('Size 1),
--        'Dim ('Name "feature") ('Size 3)]
-- >>> :kind! GatherDimImplF ('ByName "feature") '[ 'Dim ('Name "batch") ('Size 2), 'Dim ('Name "sequence") ('Size 1), 'Dim ('Name "feature") ('Size 3)] '[ 'Dim ('Name "batch") ('Size 2), 'Dim ('Name "sequence") ('Size 1), 'Dim ('Name "*") ('Size 1)]
-- ...
-- = 'Just
--     '[ 'Dim ('Name "batch") ('Size 2),
--        'Dim ('Name "sequence") ('Size 1),
--        'Dim ('Name "feature") ('Size 3)]
type GatherDimImplF :: By Symbol Nat -> [Dim (Name Symbol) (Size Nat)] -> [Dim (Name Symbol) (Size Nat)] -> Maybe [Dim (Name Symbol) (Size Nat)]
type family GatherDimImplF selector idxDims inDims where
  GatherDimImplF ('ByIndex position) idxDims inDims =
    GatherDimByIndexF position idxDims inDims
  GatherDimImplF ('ByName name) idxDims inDims =
    GatherDimByNameF (GetIndexByNameF name idxDims) (GetIndexByNameF name inDims) idxDims inDims

-- | Index-based worker for 'GatherDimImplF'.
--
-- Walks @indexDims@ and @inputDims@ in lockstep:
--
-- * before position @dimIndex@, each pair of dimensions must unify and the
--   unified dimension is kept;
-- * at position @dimIndex@ the names must unify and the index tensor's size
--   is taken; all remaining trailing dimensions must unify as whole lists.
--
-- The result is 'Nothing' when any unification check fails, when the lists
-- run out before @dimIndex@ is reached, or when the index dimension at
-- @dimIndex@ has static size 0. The equations overlap, so their order
-- matters.
type GatherDimByIndexF :: Nat -> [Dim (Name Symbol) (Size Nat)] -> [Dim (Name Symbol) (Size Nat)] -> Maybe [Dim (Name Symbol) (Size Nat)]
type family GatherDimByIndexF dimIndex indexDims inputDims where
  -- An index dimension of static size 0 along the gather dimension is
  -- rejected; compare the corresponding 'GatherDimImplF' example.
  GatherDimByIndexF 0 ('Dim _ ('Size 0) ': _) _ = 'Nothing
  -- Gather dimension: the output size comes from the index, while the name
  -- and the trailing dimensions come from unifying index and input.
  GatherDimByIndexF 0 ('Dim indexDimName indexDimSize ': indexDims) ('Dim inputDimName _ ': inputDims) =
    If
      (UnifyCheck (Name Symbol) indexDimName inputDimName && UnifyCheck [Dim (Name Symbol) (Size Nat)] indexDims inputDims)
      ('Just ('Dim (indexDimName <+> inputDimName) indexDimSize ': (indexDims <+> inputDims)))
      'Nothing
  -- Leading dimension (before the gather dimension): it must unify and is kept.
  GatherDimByIndexF dimIndex (indexDim ': indexDims) (inputDim ': inputDims) =
    If
      (UnifyCheck (Dim (Name Symbol) (Size Nat)) indexDim inputDim)
      (PrependMaybe ('Just (indexDim <+> inputDim)) (GatherDimByIndexF (dimIndex - 1) indexDims inputDims))
      'Nothing
  -- One of the lists ran out before the gather dimension was reached.
  GatherDimByIndexF _ _ _ = 'Nothing

-- | Name-based worker for 'GatherDimImplF'.
--
-- Its first two arguments are the name's position in the index dimensions
-- and in the input dimensions. The name must be found in at least one of
-- the two lists; if it is found in both, the positions must agree.
-- Otherwise the result is 'Nothing'.
type GatherDimByNameF :: Maybe Nat -> Maybe Nat -> [Dim (Name Symbol) (Size Nat)] -> [Dim (Name Symbol) (Size Nat)] -> Maybe [Dim (Name Symbol) (Size Nat)]
type family GatherDimByNameF idxPosition inPosition idxDims inDims where
  GatherDimByNameF 'Nothing ('Just position) idxDims inDims = GatherDimByIndexF position idxDims inDims
  GatherDimByNameF ('Just position) 'Nothing idxDims inDims = GatherDimByIndexF position idxDims inDims
  GatherDimByNameF ('Just position) ('Just position) idxDims inDims = GatherDimByIndexF position idxDims inDims
  GatherDimByNameF _ _ _ _ = 'Nothing

-- | Custom type error emitted by 'GatherDimCheckF' when 'GatherDimImplF'
-- finds the index dimensions @indexDims@ incompatible with the input
-- dimensions @inputDims@ for a gather at @by@.
type GatherDimMessage :: By Symbol Nat -> [Dim (Name Symbol) (Size Nat)] -> [Dim (Name Symbol) (Size Nat)] -> ErrorMessage

type GatherDimMessage by indexDims inputDims =
  "Cannot gather the tensor with the dimensions"
    % ""
    % "    '" <> inputDims <> "'"
    % ""
    % "at the dimension"
    % ""
    % "    '" <> by <> "'"
    % ""
    % "using an index of shape"
    % ""
    % "    '" <> indexDims <> "'."
    % ""

-- | Unwraps the result of 'GatherDimImplF'.
--
-- A 'Nothing' result is turned into the 'GatherDimMessage' type error.
type GatherDimCheckF ::
  By Symbol Nat ->
  [Dim (Name Symbol) (Size Nat)] ->
  [Dim (Name Symbol) (Size Nat)] ->
  Maybe [Dim (Name Symbol) (Size Nat)] ->
  [Dim (Name Symbol) (Size Nat)]
type family GatherDimCheckF selector idxDims inDims result where
  GatherDimCheckF _ _ _ ('Just gathered) = gathered
  GatherDimCheckF selector idxDims inDims 'Nothing = TypeError (GatherDimMessage selector idxDims inDims)

-- | Output shape of a gather with an index of shape @indexShape@ from an
-- input of shape @inputShape@ along the selected dimension.
--
-- If the selector or either shape is unchecked, the result is
-- 'UncheckedShape'. Incompatible shapes are a type error, reported through
-- 'GatherDimCheckF'.
--
-- >>> :kind! GatherDimF ('SelectDim ('ByIndex 2)) ('Shape '[ 'Dim ('Name "*") ('Size 2), 'Dim ('Name "*") ('Size 1), 'Dim ('Name "*") ('Size 3)]) ('Shape '[ 'Dim ('Name "*") ('Size 2), 'Dim ('Name "*") ('Size 1), 'Dim ('Name "*") ('Size 1)])
-- ...
-- = 'Shape
--     '[ 'Dim ('Name "*") ('Size 2), 'Dim ('Name "*") ('Size 1),
--        'Dim ('Name "*") ('Size 3)]
type GatherDimF :: SelectDim (By Symbol Nat) -> Shape [Dim (Name Symbol) (Size Nat)] -> Shape [Dim (Name Symbol) (Size Nat)] -> Shape [Dim (Name Symbol) (Size Nat)]
type family GatherDimF selector indexShape inputShape where
  GatherDimF 'UncheckedSelectDim _ _ = 'UncheckedShape
  GatherDimF _ 'UncheckedShape _ = 'UncheckedShape
  GatherDimF _ _ 'UncheckedShape = 'UncheckedShape
  GatherDimF ('SelectDim by) ('Shape idxDims) ('Shape inDims) =
    'Shape (GatherDimCheckF by idxDims inDims (GatherDimImplF by idxDims inDims))

-- | Gather values along an axis for a specified dimension.
--
-- The dimension may be selected by name or by index.
-- The output shape is computed at the type level by 'GatherDimF' from the
-- index and input shapes.
-- The index tensor's data type must unify with @'DataType 'Int64@.
--
-- >>> sToTensor' = sToTensor (SGradient SWithoutGradient) (SLayout SDense) (SDevice SCPU)
-- >>> t <- sToTensor' [[1 :: Float, 2], [3, 4]]
-- >>> idx <- sToTensor' [[0 :: Int, 0], [1, 0]]
-- >>> result <- sGatherDim (SSelectDim $ SByIndex @1) idx t
-- >>> :t result
-- result
--   :: Tensor
--        ('Gradient 'WithoutGradient)
--        ('Layout 'Dense)
--        ('Device 'CPU)
--        ('DataType 'Float)
--        ('Shape
--           '[ 'Dim ('Name "*") 'UncheckedSize,
--              'Dim ('Name "*") 'UncheckedSize])
-- >>> result
-- Tensor Float [2,2] [[ 1.0000   ,  1.0000   ],
--                     [ 4.0000   ,  3.0000   ]]
-- >>> shape = SShape $ SNoName :&: SSize @2 :|: SNoName :&: SSize @2 :|: SNil
-- >>> t' <- sCheckedShape shape t
-- >>> idx' <- sCheckedShape shape idx
-- >>> result <- sGatherDim (SSelectDim $ SByIndex @1) idx' t'
-- >>> :t result
-- result
--   :: Tensor
--        ('Gradient 'WithoutGradient)
--        ('Layout 'Dense)
--        ('Device 'CPU)
--        ('DataType 'Float)
--        ('Shape '[ 'Dim ('Name "*") ('Size 2), 'Dim ('Name "*") ('Size 2)])
-- >>> result
-- Tensor Float [2,2] [[ 1.0000   ,  1.0000   ],
--                     [ 4.0000   ,  3.0000   ]]
sGatherDim ::
  forall selectDim indexGradient inputGradient indexLayout inputLayout indexDevice inputDevice indexDataType inputDataType indexShape inputShape outputShape m.
  (MonadThrow m, outputShape ~ GatherDimF selectDim indexShape inputShape, Catch outputShape, Catch (indexDataType <+> 'DataType 'Int64)) =>
  -- | the dimension along which to gather (by name or by index)
  SSelectDim selectDim ->
  -- | the indices of elements to gather
  Tensor indexGradient indexLayout indexDevice indexDataType indexShape ->
  -- | input
  Tensor inputGradient inputLayout inputDevice inputDataType inputShape ->
  -- | output
  m
    ( Tensor
        (indexGradient <|> inputGradient)
        (indexLayout <+> inputLayout)
        (indexDevice <+> inputDevice)
        inputDataType
        outputShape
    )
sGatherDim selectDim index input =
  -- Demote the singleton to a runtime 'By' value; checked and unchecked
  -- selections are handled identically at runtime.
  let by = forgetIsChecked . fromSing $ selectDim
   in case by of
        -- Named dimension: use the Dimname overload of ATen's gather.
        ByName dimName ->
          unsafeThrowableIO $
            ATen.cast4 ATen.gather_tntb input dimName index False
        -- Positional dimension: the demoted index is an 'Integer' and is
        -- narrowed to 'Int' for the Int64 overload. The trailing 'False' is
        -- passed in gather's @sparse_grad@ position.
        ByIndex dimIndex ->
          unsafeThrowableIO $
            ATen.cast4 ATen.gather_tltb input (fromIntegral dimIndex :: Int) index False