hasktorch-gradually-typed-0.2.0.0: experimental project for hasktorch
Safe HaskellSafe-Inferred
LanguageHaskell2010

Torch.GraduallyTyped.Tensor.IndexingSlicingJoining

Synopsis

Documentation

>>> import Torch.GraduallyTyped.Prelude.List (SList (..))
>>> import Torch.GraduallyTyped

class HasCat (selectDim :: SelectDim (By Symbol Nat)) k (c :: k -> Type) (a :: k) where Source #

Minimal complete definition

sCat

Associated Types

type CatF selectDim a c :: Type Source #

Methods

sCat :: forall m. MonadThrow m => SSelectDim selectDim -> c a -> m (CatF selectDim a c) Source #

Concatenates the given sequence of tensors along the given dimension. All tensors must either have the same shape (except in the concatenating dimension) or be empty.

>>> t <- ones @('Gradient 'WithGradient) @('Layout 'Dense) @('Device 'CPU) @('DataType 'Float) @('Shape '[ 'Dim ('Name "batch") ('Size 32), 'Dim ('Name "feature") ('Size 8)])
>>> :type cat @('SelectDim ('ByName "feature")) [t]
cat @('SelectDim ('ByName "feature")) [t]
  :: MonadThrow m =>
     m (Tensor
          ('Gradient 'WithGradient)
          ('Layout 'Dense)
          ('Device 'CPU)
          ('DataType 'Float)
          ('Shape
             '[ 'Dim ('Name "batch") ('Size 32),
                'Dim 'UncheckedName 'UncheckedSize]))
>>> :type cat @('SelectDim ( 'ByIndex 0)) [t]
cat @('SelectDim ( 'ByIndex 0)) [t]
  :: MonadThrow m =>
     m (Tensor
          ('Gradient 'WithGradient)
          ('Layout 'Dense)
          ('Device 'CPU)
          ('DataType 'Float)
          ('Shape
             '[ 'Dim 'UncheckedName 'UncheckedSize,
                'Dim ('Name "feature") ('Size 8)]))
>>> :type sCat (SUncheckedSelectDim (ByIndex 0)) [t]
sCat (SUncheckedSelectDim (ByIndex 0)) [t]
  :: MonadThrow m =>
     m (Tensor
          ('Gradient 'WithGradient)
          ('Layout 'Dense)
          ('Device 'CPU)
          ('DataType 'Float)
          'UncheckedShape)

cat :: forall m. (SingI selectDim, MonadThrow m) => c a -> m (CatF selectDim a c) Source #

Instances

Instances details
Castable (CatListF selectDim (Tensor gradient layout device dataType shape)) (ForeignPtr Tensor) => HasCat selectDim Type [] (Tensor gradient layout device dataType shape :: TYPE LiftedRep) Source # 
Instance details

Defined in Torch.GraduallyTyped.Tensor.IndexingSlicingJoining

Associated Types

type CatF selectDim (Tensor gradient layout device dataType shape) [] Source #

Methods

sCat :: MonadThrow m => SSelectDim selectDim -> [Tensor gradient layout device dataType shape] -> m (CatF selectDim (Tensor gradient layout device dataType shape) []) Source #

cat :: (SingI selectDim, MonadThrow m) => [Tensor gradient layout device dataType shape] -> m (CatF selectDim (Tensor gradient layout device dataType shape) []) Source #

(Castable (CatHListF selectDim tensors) (ForeignPtr Tensor), Castable (HList tensors) (ForeignPtr TensorList)) => HasCat selectDim [Type] (HList :: [Type] -> Type) (tensors :: [Type]) Source # 
Instance details

Defined in Torch.GraduallyTyped.Tensor.IndexingSlicingJoining

Associated Types

type CatF selectDim tensors HList Source #

Methods

sCat :: MonadThrow m => SSelectDim selectDim -> HList tensors -> m (CatF selectDim tensors HList) Source #

cat :: (SingI selectDim, MonadThrow m) => HList tensors -> m (CatF selectDim tensors HList) Source #

type family CatListImplF (selectDim :: SelectDim (By Symbol Nat)) (tensor :: Type) :: Maybe Type where ... Source #

Equations

CatListImplF 'UncheckedSelectDim (Tensor gradient layout device dataType _) = 'Just (Tensor gradient layout device dataType 'UncheckedShape) 
CatListImplF ('SelectDim _) (Tensor gradient layout device dataType 'UncheckedShape) = 'Just (Tensor gradient layout device dataType 'UncheckedShape) 
CatListImplF ('SelectDim by) (Tensor gradient layout device dataType ('Shape dims)) = MapMaybe (Tensor gradient layout device dataType) (MapMaybe 'Shape (ReplaceDimImplF by dims ('Dim 'UncheckedName 'UncheckedSize))) 

type CheckSpellingMessage = "Check the spelling of named dimensions, and make sure the number of dimensions is correct." Source #

type family CatListCheckF (selectDim :: SelectDim (By Symbol Nat)) (tensor :: Type) (result :: Maybe Type) :: Type where ... Source #

Equations

CatListCheckF selectDim (Tensor _ _ _ _ shape) 'Nothing = TypeError ("Cannot concatenate the dimension" % ("" % ((" " <> selectDim) % ("" % ("for tensors of shape" % ("" % (((" " <> shape) <> ".") % ("" % CheckSpellingMessage)))))))) 
CatListCheckF _ _ ('Just result) = result 

type CatListF selectDim tensor = CatListCheckF selectDim tensor (CatListImplF selectDim tensor) Source #

type family CatHListImplF selectDim tensors acc where ... Source #

Equations

CatHListImplF _ '[] 'Nothing = TypeError (ToErrorMessage "Cannot concatenate an empty list of tensors.") 
CatHListImplF _ '[] ('Just '(gradient, layout, device, dataType, shape)) = Tensor gradient layout device dataType shape 
CatHListImplF selectDim (Tensor gradient layout device dataType shape ': tensors) 'Nothing = CatHListImplF selectDim tensors ('Just '(gradient, layout, device, dataType, shape)) 
CatHListImplF selectDim (Tensor gradient layout device dataType shape ': tensors) ('Just '(gradient', layout', device', dataType', shape')) = CatHListImplF selectDim tensors ('Just '(gradient <|> gradient', layout <+> layout', device <+> device', dataType <+> dataType', ReplaceDimF selectDim (shape <+> ReplaceDimF selectDim shape' (GetDimF selectDim shape)) (AddDimF (GetDimF selectDim shape) (GetDimF selectDim shape')))) 
CatHListImplF _ (x ': _) _ = TypeError ("Cannot concatenate because" % ("" % (((" '" <> x) <> "'") % ("" % "is not a tensor type.")))) 

type CatHListF selectDim tensors = CatHListImplF selectDim tensors 'Nothing Source #

type ReshapeNumelMismatchMessage numel numel' shape shape' = "Cannot reshape the tensor. The original shape," % ("" % (((" '" <> shape) <> "',") % ("" % ("and the new shape," % ("" % (((" '" <> shape') <> "',") % ("" % ("have different total numbers of elements," % ("" % (((((" '" <> numel) <> "' versus '") <> numel') <> "',") % ("" % "respectively."))))))))))) Source #

type family ReshapeImplF numel numel' shape shape' where ... Source #

Equations

ReshapeImplF ('Just numel) ('Just numel) _ shape' = shape' 
ReshapeImplF ('Just numel) ('Just numel') shape shape' = TypeError (ReshapeNumelMismatchMessage numel numel' shape shape') 
ReshapeImplF 'Nothing _ _ _ = 'UncheckedShape 
ReshapeImplF _ 'Nothing _ _ = 'UncheckedShape 
ReshapeImplF _ _ 'UncheckedShape _ = 'UncheckedShape 
ReshapeImplF _ _ _ 'UncheckedShape = 'UncheckedShape 

type family ReshapeF shape shape' where ... Source #

Equations

ReshapeF shape shape' = ReshapeImplF (NumelF shape) (NumelF shape') shape shape' 

sReshape :: forall m shape' gradient layout device dataType shape shape''. MonadThrow m => (shape'' ~ ReshapeF shape shape', Catch shape'') => SShape shape' -> Tensor gradient layout device dataType shape -> m (Tensor gradient layout device dataType shape'') Source #

Returns a tensor with the same data and number of elements as the input tensor, but with the specified shape:

>>> g <- sMkGenerator (SDevice SCPU) 0
>>> (input, _) <- sRandn (TensorSpec (SGradient SWithGradient) (SLayout SDense) (SDevice SCPU) (SDataType SFloat) (SShape $ SName @"*" :&: SSize @4 :|: SNil)) g
>>> output <- sReshape (SShape $ SName @"*" :&: SSize @2 :|: SName @"*" :&: SSize @2 :|: SNil) input
>>> :type output
output
  :: Tensor
       ('Gradient 'WithGradient)
       ('Layout 'Dense)
       ('Device 'CPU)
       ('DataType 'Float)
       ('Shape '[ 'Dim ('Name "*") ('Size 2), 'Dim ('Name "*") ('Size 2)])

At the value level, a single dimension may be '-1', in which case it is inferred from the remaining dimensions and the number of elements in the input:

>>> output' <- sReshape (SShape $ SUncheckedName "*" :&: SUncheckedSize (-1) :|: SNil) output
>>> :type output'
output'
  :: Tensor
       ('Gradient 'WithGradient)
       ('Layout 'Dense)
       ('Device 'CPU)
       ('DataType 'Float)
       'UncheckedShape
>>> getDims output'
[Dim {dimName = "*", dimSize = 4}]

sSetShape :: forall m shape' gradient layout device dataType shape shape''. MonadThrow m => (shape'' ~ ReshapeF shape shape', Catch shape'') => SShape shape' -> Tensor gradient layout device dataType shape -> m (Tensor gradient layout device dataType shape'') Source #

Returns a tensor with the same data and number of elements as the input tensor, but with the specified shape:

>>> g <- sMkGenerator (SDevice SCPU) 0
>>> (input, _) <- sRandn (TensorSpec (SGradient SWithGradient) (SLayout SDense) (SDevice SCPU) (SDataType SFloat) (SShape $ SName @"*" :&: SSize @4 :|: SNil)) g
>>> output <- sReshape (SShape $ SName @"*" :&: SSize @2 :|: SName @"*" :&: SSize @2 :|: SNil) input
>>> :type output
output
  :: Tensor
       ('Gradient 'WithGradient)
       ('Layout 'Dense)
       ('Device 'CPU)
       ('DataType 'Float)
       ('Shape '[ 'Dim ('Name "*") ('Size 2), 'Dim ('Name "*") ('Size 2)])

At the value level, a single dimension may be '-1', in which case it is inferred from the remaining dimensions and the number of elements in the input:

>>> output' <- sReshape (SShape $ SUncheckedName "*" :&: SUncheckedSize (-1) :|: SNil) output
>>> :type output'
output'
  :: Tensor
       ('Gradient 'WithGradient)
       ('Layout 'Dense)
       ('Device 'CPU)
       ('DataType 'Float)
       'UncheckedShape
>>> getDims output'
[Dim {dimName = "*", dimSize = 4}]

type family AllDimSizesChecked shape where ... Source #

reshape :: forall m shape' gradient layout device dataType shape shape''. (shape'' ~ ReshapeF shape shape', Catch shape'', When (AllDimSizesChecked shape) (shape' ~ shape''), SingI shape', MonadThrow m) => Tensor gradient layout device dataType shape -> m (Tensor gradient layout device dataType shape'') Source #

type TransposeBy0Message by0 dims = "Cannot transpose the tensor with the dimensions" % ("" % (((" '" <> dims) <> "'") % ("" % ("because the specified source dimension" % ("" % (((" '" <> by0) <> "'") % ("" % "could not be found."))))))) Source #

type TransposeBy1Message by1 dims = "Cannot transpose the tensor with the dimensions" % ("" % (((" '" <> dims) <> "'") % ("" % ("because the specified target dimension" % ("" % (((" '" <> by1) <> "'") % ("" % "could not be found."))))))) Source #

type family TransposeF selectDim0 selectDim1 shape where ... Source #

Compute transposed shapes.

>>> type SelectBatch = 'SelectDim ('ByName "batch" :: By Symbol Nat)
>>> type SelectFeature = 'SelectDim ('ByName "feature" :: By Symbol Nat)
>>> type Dims = '[ 'Dim ('Name "batch") ('Size 10), 'Dim ('Name "feature") ('Size 8), 'Dim ('Name "anotherFeature") ('Size 12)]
>>> :kind! TransposeF SelectBatch SelectFeature ('Shape Dims)
TransposeF SelectBatch SelectFeature ('Shape Dims) :: Shape
                                                        [Dim (Name Symbol) (Size Natural)]
= 'Shape
    '[ 'Dim ('Name "feature") ('Size 8),
       'Dim ('Name "batch") ('Size 10),
       'Dim ('Name "anotherFeature") ('Size 12)]
>>> :kind! TransposeF SelectFeature SelectBatch ('Shape Dims)
TransposeF SelectFeature SelectBatch ('Shape Dims) :: Shape
                                                        [Dim (Name Symbol) (Size Natural)]
= 'Shape
    '[ 'Dim ('Name "feature") ('Size 8),
       'Dim ('Name "batch") ('Size 10),
       'Dim ('Name "anotherFeature") ('Size 12)]

Equations

TransposeF _ _ 'UncheckedShape = 'UncheckedShape 
TransposeF _ 'UncheckedSelectDim _ = 'UncheckedShape 
TransposeF 'UncheckedSelectDim _ _ = 'UncheckedShape 
TransposeF ('SelectDim ('ByName name0)) ('SelectDim ('ByName name1)) ('Shape dims) = 'Shape (TransposeIndexIndexDimsF (FromMaybe (TypeError (TransposeBy0Message ('ByName name0) dims)) (GetIndexByNameF name0 dims)) (FromMaybe (TypeError (TransposeBy1Message ('ByName name1) dims)) (GetIndexByNameF name1 dims)) dims) 
TransposeF ('SelectDim ('ByIndex index0)) ('SelectDim ('ByIndex index1)) ('Shape dims) = 'Shape (TransposeIndexIndexDimsF index0 index1 dims) 
TransposeF ('SelectDim by0) ('SelectDim by1) _ = TypeError ("Cannot transpose the tensor. " % ("" % ("The source and target dimensions must be selected either both by name or both by index, " % ("but mixed selectors were found: " % ("" % (((((" '" <> 'SelectDim by0) <> "' and '") <> 'SelectDim by1) <> "'.") % "")))))) 

type family TransposeIndexIndexDimsF index0 index1 dims where ... Source #

Equations

TransposeIndexIndexDimsF index0 index1 dims = FromMaybe (TypeError (TransposeBy1Message ('ByIndex index1) dims)) (ReplaceDimImplF ('ByIndex index1) (FromMaybe (TypeError (TransposeBy0Message ('ByIndex index0) dims)) (ReplaceDimImplF ('ByIndex index0) dims (FromMaybe (TypeError (TransposeBy1Message ('ByIndex index1) dims)) (GetDimImplF ('ByIndex index1) dims)))) (FromMaybe (TypeError (TransposeBy0Message ('ByIndex index0) dims)) (GetDimImplF ('ByIndex index0) dims))) 

sTranspose :: forall selectDim0 selectDim1 gradient layout device dataType shape shape' m. (shape' ~ TransposeF selectDim0 selectDim1 shape, Catch shape', MonadThrow m) => SSelectDim selectDim0 -> SSelectDim selectDim1 -> Tensor gradient layout device dataType shape -> m (Tensor gradient layout device dataType shape') Source #

Returns a tensor that is a transposed version of the input tensor. The selected dimensions selectDim0 and selectDim1 are swapped.

>>> g <- sMkGenerator (SDevice SCPU) 0
>>> (input, _) <- sRandn (TensorSpec (SGradient SWithGradient) (SLayout SDense) (SDevice SCPU) (SDataType SFloat) (SShape $ SName @"batch" :&: SSize @10 :|: SName @"feature" :&: SSize @5 :|: SNil)) g
>>> output <- sTranspose (SSelectDim (SByName @"batch")) (SSelectDim (SByName @"feature")) input
>>> :type output
output
  :: Tensor
       ('Gradient 'WithGradient)
       ('Layout 'Dense)
       ('Device 'CPU)
       ('DataType 'Float)
       ('Shape
          '[ 'Dim ('Name "feature") ('Size 5),
             'Dim ('Name "batch") ('Size 10)])
>>> output <- sTranspose (SUncheckedSelectDim (ByIndex 0)) (SSelectDim (SByIndex @1)) input
>>> :type output
output
  :: Tensor
       ('Gradient 'WithGradient)
       ('Layout 'Dense)
       ('Device 'CPU)
       ('DataType 'Float)
       'UncheckedShape
>>> getDims output
[Dim {dimName = "feature", dimSize = 5},Dim {dimName = "batch", dimSize = 10}]

transpose :: forall selectDim0 selectDim1 gradient layout device dataType shape shape' m. (shape' ~ TransposeF selectDim0 selectDim1 shape, Catch shape', SingI selectDim0, SingI selectDim1, MonadThrow m) => Tensor gradient layout device dataType shape -> m (Tensor gradient layout device dataType shape') Source #

type UnsqueezeByMessage (by :: By Symbol Nat) (dims :: [Dim (Name Symbol) (Size Nat)]) = "Cannot unsqueeze the tensor with the dimensions" % ("" % (((" '" <> dims) <> "'") % ("" % ("because the specified source dimension" % ("" % (((" '" <> by) <> "'") % ("" % "could not be found."))))))) Source #

type family UnsqueezeF (selectDim :: SelectDim (By Symbol Nat)) (shape :: Shape [Dim (Name Symbol) (Size Nat)]) :: Shape [Dim (Name Symbol) (Size Nat)] where ... Source #

type family UnsqueezeIndexDimsF (index :: Nat) (dims :: [Dim (Name Symbol) (Size Nat)]) :: [Dim (Name Symbol) (Size Nat)] where ... Source #

Equations

UnsqueezeIndexDimsF index dims = FromMaybe (TypeError (UnsqueezeByMessage ('ByIndex index) dims)) (InsertDimImplF ('ByIndex index) dims ('Dim ('Name "*") ('Size 1))) 

sUnsqueeze :: forall selectDim gradient layout device dataType shape shape' m. (shape' ~ UnsqueezeF selectDim shape, Catch shape', MonadThrow m) => SSelectDim selectDim -> Tensor gradient layout device dataType shape -> m (Tensor gradient layout device dataType shape') Source #

Unsqueezes a tensor at the specified dimension, that is, inserts a new dimension of size 1 at the selected position.

unsqueeze :: forall selectDim gradient layout device dataType shape shape' m. (shape' ~ UnsqueezeF selectDim shape, Catch shape', SingI selectDim, MonadThrow m) => Tensor gradient layout device dataType shape -> m (Tensor gradient layout device dataType shape') Source #

Unsqueezes a tensor at the specified dimension, that is, inserts a new dimension of size 1 at the selected position.

type family SqueezeAllDimsF dims where ... Source #

Equations

SqueezeAllDimsF '[] = 'Just '[] 
SqueezeAllDimsF ('Dim _ 'UncheckedSize ': dims) = 'Nothing 
SqueezeAllDimsF ('Dim _ ('Size 1) ': dims) = SqueezeAllDimsF dims 
SqueezeAllDimsF (dim ': dims) = PrependMaybe ('Just dim) (SqueezeAllDimsF dims) 

squeezeAll Source #

Arguments

:: forall gradient layout device dataType shape. Tensor gradient layout device dataType shape

input

-> Tensor gradient layout device dataType (SqueezeAllShapeF shape)

output

type family SqueezeDimByIndexF dimIndex dims where ... Source #

Equations

SqueezeDimByIndexF 0 (x ': xs) = If (UnifyCheck (Dim (Name Symbol) (Size Nat)) x ('Dim ('Name "*") ('Size 1))) ('Just xs) 'Nothing 
SqueezeDimByIndexF dimIndex (x ': xs) = PrependMaybe ('Just x) (SqueezeDimByIndexF (dimIndex - 1) xs) 
SqueezeDimByIndexF _ _ = 'Nothing 

type family SqueezeDimImplF by dims where ... Source #

Equations

SqueezeDimImplF ('ByName dimName) dims = SqueezeDimByNameF (GetIndexByNameF dimName dims) dims 
SqueezeDimImplF ('ByIndex dimIndex) dims = SqueezeDimByIndexF dimIndex dims 

type family SqueezeDimByNameF dimIndex dims where ... Source #

Equations

SqueezeDimByNameF 'Nothing dims = 'Nothing 
SqueezeDimByNameF ('Just dimIndex) dims = SqueezeDimByIndexF dimIndex dims 

type SqueezeDimMessage by dims = "Cannot squeeze the tensor with the dimensions" % ("" % (((" '" <> dims) <> "'") % ("" % ("at the dimension" % ("" % (((" '" <> by) <> "'.") % "")))))) Source #

type family SqueezeDimCheckF by dims result where ... Source #

Equations

SqueezeDimCheckF by dims 'Nothing = TypeError (SqueezeDimMessage by dims) 
SqueezeDimCheckF _ _ ('Just dims) = dims 

type family SqueezeDimF selectDim shape where ... Source #

Calculate the output shape of a squeeze along a given dimension:

>>> :kind! SqueezeDimF ('SelectDim ('ByIndex 1)) ('Shape '[ 'Dim ('Name "*") ('Size 2), 'Dim ('Name "*") ('Size 1), 'Dim ('Name "*") ('Size 2)])
...
= 'Shape '[ 'Dim ('Name "*") ('Size 2), 'Dim ('Name "*") ('Size 2)]

sSqueezeDim :: forall selectDim gradient layout device dataType shape shape' m. (MonadThrow m, shape' ~ SqueezeDimF selectDim shape, Catch shape') => SSelectDim selectDim -> Tensor gradient layout device dataType shape -> m (Tensor gradient layout device dataType shape') Source #

Squeeze a particular dimension.

>>> t <- sOnes $ TensorSpec (SGradient SWithoutGradient) (SLayout SDense) (SDevice SCPU) (SDataType SFloat) (SShape $ SNoName :&: SSize @2 :|: SNoName :&: SSize @1 :|: SNoName :&: SSize @2 :|: SNoName :&: SSize @1 :|: SNoName :&: SSize @2 :|: SNil)
>>> result <- sSqueezeDim (SSelectDim $ SByIndex @1) t
>>> :t result
result
  :: Tensor
       ('Gradient 'WithoutGradient)
       ('Layout 'Dense)
       ('Device 'CPU)
       ('DataType 'Float)
       ('Shape
          '[ 'Dim ('Name "*") ('Size 2), 'Dim ('Name "*") ('Size 2),
             'Dim ('Name "*") ('Size 1), 'Dim ('Name "*") ('Size 2)])
>>> result
Tensor Float [2,2,1,2] [[[[ 1.0000   ,  1.0000   ]],
                         [[ 1.0000   ,  1.0000   ]]],
                        [[[ 1.0000   ,  1.0000   ]],
                         [[ 1.0000   ,  1.0000   ]]]]

sExpand Source #

Arguments

:: forall shape' shape'' gradient layout device dataType shape. (shape'' ~ BroadcastShapesF shape shape', Catch shape'') 
=> SShape shape'

new shape

-> Tensor gradient layout device dataType shape

input tensor

-> Tensor gradient layout device dataType shape''

output tensor

Expands a tensor to the specified shape.

expand :: forall shape' shape'' gradient layout device dataType shape. (SingI shape', shape'' ~ BroadcastShapesF shape shape', Catch shape'') => Tensor gradient layout device dataType shape -> Tensor gradient layout device dataType shape'' Source #

Expands a tensor to the specified shape.

sSelect :: forall selectDim index gradient layout device dataType shapeIn shapeOut m. (index `InRangeF` GetDimF selectDim shapeIn, shapeOut ~ RemoveDimF selectDim shapeIn, Catch shapeOut, SGetShape shapeIn, MonadThrow m) => SSelectDim selectDim -> SIndex index -> Tensor gradient layout device dataType shapeIn -> m (Tensor gradient layout device dataType shapeOut) Source #

Slices the input tensor along the selected dimension at the given index. This function returns a view of the original tensor with the given dimension removed.

>>> nats <- sArangeNaturals (SGradient SWithoutGradient) (SLayout SDense) (SDevice SCPU) (SDataType SInt32) (SSize @8)
>>> input <- sReshape (SShape $ SName @"*" :&: SSize @4 :|: SName @"*" :&: SSize @2 :|: SNil) nats
>>> input
Tensor Int32 [4,2] [[ 0,  1],
                    [ 2,  3],
                    [ 4,  5],
                    [ 6,  7]]

The index can be provided at compile time:

>>> sSelect (SSelectDim (SByIndex @0)) (SIndex @1) input
Tensor Int32 [2] [ 2,  3]

The index can also be provided at runtime:

>>> sSelect (SSelectDim (SByIndex @0)) (SUncheckedIndex 1) input
Tensor Int32 [2] [ 2,  3]

It produces a runtime error if the index is too large:

>>> sSelect (SSelectDim (SByIndex @0)) (SUncheckedIndex 10) input
*** Exception: IndexOutOfBoundError {ioobeIndex = 10, ioobeDim = Dim {dimName = "*", dimSize = 4}}

select :: forall selectDim index gradient layout device dataType shapeIn shapeOut m. (SingI selectDim, SingI index, index `InRangeF` GetDimF selectDim shapeIn, shapeOut ~ RemoveDimF selectDim shapeIn, Catch shapeOut, SGetShape shapeIn, MonadThrow m) => Tensor gradient layout device dataType shapeIn -> m (Tensor gradient layout device dataType shapeOut) Source #

type family GatherDimImplF by indexDims inputDims where ... Source #

GatherDimImplF is a type-level helper function for sGatherDim.

>>> :kind! GatherDimImplF ('ByIndex 1) '[ 'Dim ('Name "batch") ('Size 2), 'Dim ('Name "sequence") ('Size 4), 'Dim ('Name "feature") ('Size 1)] '[ 'Dim ('Name "batch") ('Size 2), 'Dim ('Name "sequence") ('Size 1), 'Dim ('Name "feature") ('Size 1)]
...
= 'Just
    '[ 'Dim ('Name "batch") ('Size 2),
       'Dim ('Name "sequence") ('Size 4),
       'Dim ('Name "feature") ('Size 1)]
>>> :kind! GatherDimImplF ('ByIndex 1) '[ 'Dim ('Name "batch") ('Size 2), 'Dim ('Name "sequence") ('Size 4), 'Dim ('Name "feature") ('Size 1)] '[ 'Dim ('Name "*") ('Size 2), 'Dim ('Name "sequence") ('Size 1), 'Dim ('Name "*") ('Size 1)]
...
= 'Just
    '[ 'Dim ('Name "batch") ('Size 2),
       'Dim ('Name "sequence") ('Size 4),
       'Dim ('Name "feature") ('Size 1)]
>>> :kind! GatherDimImplF ('ByIndex 1) '[ 'Dim ('Name "batch") ('Size 2), 'Dim ('Name "sequence") ('Size 4), 'Dim ('Name "feature") ('Size 2)] '[ 'Dim ('Name "batch") ('Size 2), 'Dim ('Name "sequence") ('Size 1), 'Dim ('Name "feature") ('Size 1)]
...
= 'Nothing
>>> :kind! GatherDimImplF ('ByIndex 1) '[ 'Dim ('Name "batch") ('Size 2), 'Dim ('Name "sequence") ('Size 4), 'Dim ('Name "feature") ('Size 1)] '[ 'Dim ('Name "batch") ('Size 2), 'Dim ('Name "boo") ('Size 1), 'Dim ('Name "feature") ('Size 1)]
...
= 'Nothing
>>> :kind! GatherDimImplF ('ByIndex 1) '[ 'Dim ('Name "batch") ('Size 2), 'Dim ('Name "sequence") ('Size 0), 'Dim ('Name "feature") ('Size 1)] '[ 'Dim ('Name "batch") ('Size 2), 'Dim ('Name "sequence") ('Size 2), 'Dim ('Name "feature") ('Size 1)]
...
= 'Nothing
>>> :kind! GatherDimImplF ('ByIndex 1) '[ 'Dim ('Name "batch") ('Size 2), 'Dim ('Name "sequence") ('Size 1)] '[ 'Dim ('Name "batch") ('Size 2), 'Dim ('Name "sequence") ('Size 1), 'Dim ('Name "feature") ('Size 1)]
...
= 'Nothing
>>> :kind! GatherDimImplF ('ByIndex 2) '[ 'Dim ('Name "batch") ('Size 2), 'Dim ('Name "sequence") ('Size 1), 'Dim ('Name "feature") ('Size 3)] '[ 'Dim ('Name "batch") ('Size 2), 'Dim ('Name "sequence") ('Size 1), 'Dim ('Name "feature") ('Size 1)]
...
= 'Just
    '[ 'Dim ('Name "batch") ('Size 2),
       'Dim ('Name "sequence") ('Size 1),
       'Dim ('Name "feature") ('Size 3)]
>>> :kind! GatherDimImplF ('ByName "feature") '[ 'Dim ('Name "batch") ('Size 2), 'Dim ('Name "sequence") ('Size 1), 'Dim ('Name "feature") ('Size 3)] '[ 'Dim ('Name "batch") ('Size 2), 'Dim ('Name "sequence") ('Size 1), 'Dim ('Name "*") ('Size 1)]
...
= 'Just
    '[ 'Dim ('Name "batch") ('Size 2),
       'Dim ('Name "sequence") ('Size 1),
       'Dim ('Name "feature") ('Size 3)]

Equations

GatherDimImplF ('ByName dimName) indexDims inputDims = GatherDimByNameF (GetIndexByNameF dimName indexDims) (GetIndexByNameF dimName inputDims) indexDims inputDims 
GatherDimImplF ('ByIndex dimIndex) indexDims inputDims = GatherDimByIndexF dimIndex indexDims inputDims 

type family GatherDimByIndexF dimIndex indexDims inputDims where ... Source #

Equations

GatherDimByIndexF 0 ('Dim _ ('Size 0) ': _) _ = 'Nothing 
GatherDimByIndexF 0 ('Dim indexDimName indexDimSize ': indexDims) ('Dim inputDimName _ ': inputDims) = If (UnifyCheck (Name Symbol) indexDimName inputDimName && UnifyCheck [Dim (Name Symbol) (Size Nat)] indexDims inputDims) ('Just ('Dim (indexDimName <+> inputDimName) indexDimSize ': (indexDims <+> inputDims))) 'Nothing 
GatherDimByIndexF dimIndex (indexDim ': indexDims) (inputDim ': inputDims) = If (UnifyCheck (Dim (Name Symbol) (Size Nat)) indexDim inputDim) (PrependMaybe ('Just (indexDim <+> inputDim)) (GatherDimByIndexF (dimIndex - 1) indexDims inputDims)) 'Nothing 
GatherDimByIndexF _ _ _ = 'Nothing 

type family GatherDimByNameF dimIndex dimIndex' indexDims inputDims where ... Source #

Equations

GatherDimByNameF 'Nothing ('Just dimIndex') indexDims inputDims = GatherDimByIndexF dimIndex' indexDims inputDims 
GatherDimByNameF ('Just dimIndex) 'Nothing indexDims inputDims = GatherDimByIndexF dimIndex indexDims inputDims 
GatherDimByNameF ('Just dimIndex) ('Just dimIndex) indexDims inputDims = GatherDimByIndexF dimIndex indexDims inputDims 
GatherDimByNameF _ _ _ _ = 'Nothing 

type GatherDimMessage by indexDims inputDims = "Cannot gather the tensor with the dimensions" % ("" % (((" '" <> inputDims) <> "'") % ("" % ("at the dimension" % ("" % (((" '" <> by) <> "'") % ("" % ("using an index of shape" % ("" % (((" '" <> indexDims) <> "'.") % "")))))))))) Source #

type family GatherDimCheckF by indexDims inputDims result where ... Source #

Equations

GatherDimCheckF by indexDims inputDims 'Nothing = TypeError (GatherDimMessage by indexDims inputDims) 
GatherDimCheckF _ _ _ ('Just dims) = dims 

type family GatherDimF selectDim indexShape inputShape where ... Source #

Calculate the output shape of a gather operation for a given index shape along a given axis.

>>> :kind! GatherDimF ('SelectDim ('ByIndex 2)) ('Shape '[ 'Dim ('Name "*") ('Size 2), 'Dim ('Name "*") ('Size 1), 'Dim ('Name "*") ('Size 3)]) ('Shape '[ 'Dim ('Name "*") ('Size 2), 'Dim ('Name "*") ('Size 1), 'Dim ('Name "*") ('Size 1)])
...
= 'Shape
    '[ 'Dim ('Name "*") ('Size 2), 'Dim ('Name "*") ('Size 1),
       'Dim ('Name "*") ('Size 3)]

Equations

GatherDimF 'UncheckedSelectDim _ _ = 'UncheckedShape 
GatherDimF _ 'UncheckedShape _ = 'UncheckedShape 
GatherDimF _ _ 'UncheckedShape = 'UncheckedShape 
GatherDimF ('SelectDim by) ('Shape indexDims) ('Shape inputDims) = 'Shape (GatherDimCheckF by indexDims inputDims (GatherDimImplF by indexDims inputDims)) 

sGatherDim Source #

Arguments

:: forall selectDim indexGradient inputGradient indexLayout inputLayout indexDevice inputDevice indexDataType inputDataType indexShape inputShape outputShape m. (MonadThrow m, outputShape ~ GatherDimF selectDim indexShape inputShape, Catch outputShape, Catch (indexDataType <+> 'DataType 'Int64)) 
=> SSelectDim selectDim 
-> Tensor indexGradient indexLayout indexDevice indexDataType indexShape

the indices of elements to gather

-> Tensor inputGradient inputLayout inputDevice inputDataType inputShape

input

-> m (Tensor (indexGradient <|> inputGradient) (indexLayout <+> inputLayout) (indexDevice <+> inputDevice) inputDataType outputShape)

output

Gathers values from the input tensor along the selected dimension, at the positions given by the index tensor.

>>> sToTensor' = sToTensor (SGradient SWithoutGradient) (SLayout SDense) (SDevice SCPU)
>>> t <- sToTensor' [[1 :: Float, 2], [3, 4]]
>>> idx <- sToTensor' [[0 :: Int, 0], [1, 0]]
>>> result <- sGatherDim (SSelectDim $ SByIndex @1) idx t
>>> :t result
result
  :: Tensor
       ('Gradient 'WithoutGradient)
       ('Layout 'Dense)
       ('Device 'CPU)
       ('DataType 'Float)
       ('Shape
          '[ 'Dim ('Name "*") 'UncheckedSize,
             'Dim ('Name "*") 'UncheckedSize])
>>> result
Tensor Float [2,2] [[ 1.0000   ,  1.0000   ],
                    [ 4.0000   ,  3.0000   ]]
>>> shape = SShape $ SNoName :&: SSize @2 :|: SNoName :&: SSize @2 :|: SNil
>>> t' <- sCheckedShape shape t
>>> idx' <- sCheckedShape shape idx
>>> result <- sGatherDim (SSelectDim $ SByIndex @1) idx' t'
>>> :t result
result
  :: Tensor
       ('Gradient 'WithoutGradient)
       ('Layout 'Dense)
       ('Device 'CPU)
       ('DataType 'Float)
       ('Shape '[ 'Dim ('Name "*") ('Size 2), 'Dim ('Name "*") ('Size 2)])
>>> result
Tensor Float [2,2] [[ 1.0000   ,  1.0000   ],
                    [ 4.0000   ,  3.0000   ]]