Safe Haskell | Safe-Inferred |
---|---|
Language | Haskell2010 |
Synopsis
- class HasCat (selectDim :: SelectDim (By Symbol Nat)) k (c :: k -> Type) (a :: k) where
- type CatF selectDim a c :: Type
- sCat :: forall m. MonadThrow m => SSelectDim selectDim -> c a -> m (CatF selectDim a c)
- cat :: forall m. (SingI selectDim, MonadThrow m) => c a -> m (CatF selectDim a c)
- type family CatListImplF (selectDim :: SelectDim (By Symbol Nat)) (tensor :: Type) :: Maybe Type where ...
- type CheckSpellingMessage = "Check the spelling of named dimensions, and make sure the number of dimensions is correct."
- type family CatListCheckF (selectDim :: SelectDim (By Symbol Nat)) (tensor :: Type) (result :: Maybe Type) :: Type where ...
- type CatListF selectDim tensor = CatListCheckF selectDim tensor (CatListImplF selectDim tensor)
- type family CatHListImplF selectDim tensors acc where ...
- type CatHListF selectDim tensors = CatHListImplF selectDim tensors 'Nothing
- type ReshapeNumelMismatchMessage numel numel' shape shape' = "Cannot reshape the tensor. The original shape," % ("" % (((" '" <> shape) <> "',") % ("" % ("and the new shape," % ("" % (((" '" <> shape') <> "',") % ("" % ("have different total numbers of elements," % ("" % (((((" '" <> numel) <> "' versus '") <> numel') <> "',") % ("" % "respectively.")))))))))))
- type family ReshapeImplF numel numel' shape shape' where ...
- type family ReshapeF shape shape' where ...
- sReshape :: forall m shape' gradient layout device dataType shape shape''. MonadThrow m => (shape'' ~ ReshapeF shape shape', Catch shape'') => SShape shape' -> Tensor gradient layout device dataType shape -> m (Tensor gradient layout device dataType shape'')
- sSetShape :: forall m shape' gradient layout device dataType shape shape''. MonadThrow m => (shape'' ~ ReshapeF shape shape', Catch shape'') => SShape shape' -> Tensor gradient layout device dataType shape -> m (Tensor gradient layout device dataType shape'')
- type family AllDimSizesChecked shape where ...
- reshape :: forall m shape' gradient layout device dataType shape shape''. (shape'' ~ ReshapeF shape shape', Catch shape'', When (AllDimSizesChecked shape) (shape' ~ shape''), SingI shape', MonadThrow m) => Tensor gradient layout device dataType shape -> m (Tensor gradient layout device dataType shape'')
- type TransposeBy0Message by0 dims = "Cannot transpose the tensor with the dimensions" % ("" % (((" '" <> dims) <> "'") % ("" % ("because the specified source dimension" % ("" % (((" '" <> by0) <> "'") % ("" % "could not be found.")))))))
- type TransposeBy1Message by1 dims = "Cannot transpose the tensor with the dimensions" % ("" % (((" '" <> dims) <> "'") % ("" % ("because the specified target dimension" % ("" % (((" '" <> by1) <> "'") % ("" % "could not be found.")))))))
- type family TransposeF selectDim0 selectDim1 shape where ...
- type family TransposeIndexIndexDimsF index0 index1 dims where ...
- sTranspose :: forall selectDim0 selectDim1 gradient layout device dataType shape shape' m. (shape' ~ TransposeF selectDim0 selectDim1 shape, Catch shape', MonadThrow m) => SSelectDim selectDim0 -> SSelectDim selectDim1 -> Tensor gradient layout device dataType shape -> m (Tensor gradient layout device dataType shape')
- data TransposeError = TransposeMixedSelectorsError {}
- transpose :: forall selectDim0 selectDim1 gradient layout device dataType shape shape' m. (shape' ~ TransposeF selectDim0 selectDim1 shape, Catch shape', SingI selectDim0, SingI selectDim1, MonadThrow m) => Tensor gradient layout device dataType shape -> m (Tensor gradient layout device dataType shape')
- type UnsqueezeByMessage (by :: By Symbol Nat) (dims :: [Dim (Name Symbol) (Size Nat)]) = "Cannot unsqueeze the tensor with the dimensions" % ("" % (((" '" <> dims) <> "'") % ("" % ("because the specified source dimension" % ("" % (((" '" <> by) <> "'") % ("" % "could not be found.")))))))
- type family UnsqueezeF (selectDim :: SelectDim (By Symbol Nat)) (shape :: Shape [Dim (Name Symbol) (Size Nat)]) :: Shape [Dim (Name Symbol) (Size Nat)] where ...
- type family UnsqueezeIndexDimsF (index :: Nat) (dims :: [Dim (Name Symbol) (Size Nat)]) :: [Dim (Name Symbol) (Size Nat)] where ...
- sUnsqueeze :: forall selectDim gradient layout device dataType shape shape' m. (shape' ~ UnsqueezeF selectDim shape, Catch shape', MonadThrow m) => SSelectDim selectDim -> Tensor gradient layout device dataType shape -> m (Tensor gradient layout device dataType shape')
- unsqueeze :: forall selectDim gradient layout device dataType shape shape' m. (shape' ~ UnsqueezeF selectDim shape, Catch shape', SingI selectDim, MonadThrow m) => Tensor gradient layout device dataType shape -> m (Tensor gradient layout device dataType shape')
- type family SqueezeAllShapeF shape where ...
- type family SqueezeAllDimsF dims where ...
- squeezeAll :: forall gradient layout device dataType shape. Tensor gradient layout device dataType shape -> Tensor gradient layout device dataType (SqueezeAllShapeF shape)
- type family SqueezeDimByIndexF dimIndex dims where ...
- type family SqueezeDimImplF by dims where ...
- type family SqueezeDimByNameF dimIndex dims where ...
- type SqueezeDimMessage by dims = "Cannot squeeze the tensor with the dimensions" % ("" % (((" '" <> dims) <> "'") % ("" % ("at the dimension" % ("" % (((" '" <> by) <> "'.") % ""))))))
- type family SqueezeDimCheckF by dims result where ...
- type family SqueezeDimF selectDim shape where ...
- sSqueezeDim :: forall selectDim gradient layout device dataType shape shape' m. (MonadThrow m, shape' ~ SqueezeDimF selectDim shape, Catch shape') => SSelectDim selectDim -> Tensor gradient layout device dataType shape -> m (Tensor gradient layout device dataType shape')
- sExpand :: forall shape' shape'' gradient layout device dataType shape. (shape'' ~ BroadcastShapesF shape shape', Catch shape'') => SShape shape' -> Tensor gradient layout device dataType shape -> Tensor gradient layout device dataType shape''
- expand :: forall shape' shape'' gradient layout device dataType shape. (SingI shape', shape'' ~ BroadcastShapesF shape shape', Catch shape'') => Tensor gradient layout device dataType shape -> Tensor gradient layout device dataType shape''
- sSelect :: forall selectDim index gradient layout device dataType shapeIn shapeOut m. (index `InRangeF` GetDimF selectDim shapeIn, shapeOut ~ RemoveDimF selectDim shapeIn, Catch shapeOut, SGetShape shapeIn, MonadThrow m) => SSelectDim selectDim -> SIndex index -> Tensor gradient layout device dataType shapeIn -> m (Tensor gradient layout device dataType shapeOut)
- data IndexOutOfBoundError = IndexOutOfBoundError {}
- select :: forall selectDim index gradient layout device dataType shapeIn shapeOut m. (SingI selectDim, SingI index, index `InRangeF` GetDimF selectDim shapeIn, shapeOut ~ RemoveDimF selectDim shapeIn, Catch shapeOut, SGetShape shapeIn, MonadThrow m) => Tensor gradient layout device dataType shapeIn -> m (Tensor gradient layout device dataType shapeOut)
- type family GatherDimImplF by indexDims inputDims where ...
- type family GatherDimByIndexF dimIndex indexDims inputDims where ...
- type family GatherDimByNameF dimIndex dimIndex' indexDims inputDims where ...
- type GatherDimMessage by indexDims inputDims = "Cannot gather the tensor with the dimensions" % ("" % (((" '" <> inputDims) <> "'") % ("" % ("at the dimension" % ("" % (((" '" <> by) <> "'") % ("" % ("using an index of shape" % ("" % (((" '" <> indexDims) <> "'.") % ""))))))))))
- type family GatherDimCheckF by indexDims inputDims result where ...
- type family GatherDimF selectDim indexShape inputShape where ...
- sGatherDim :: forall selectDim indexGradient inputGradient indexLayout inputLayout indexDevice inputDevice indexDataType inputDataType indexShape inputShape outputShape m. (MonadThrow m, outputShape ~ GatherDimF selectDim indexShape inputShape, Catch outputShape, Catch (indexDataType <+> 'DataType 'Int64)) => SSelectDim selectDim -> Tensor indexGradient indexLayout indexDevice indexDataType indexShape -> Tensor inputGradient inputLayout inputDevice inputDataType inputShape -> m (Tensor (indexGradient <|> inputGradient) (indexLayout <+> inputLayout) (indexDevice <+> inputDevice) inputDataType outputShape)
Documentation
>>>
import Torch.GraduallyTyped.Prelude.List (SList (..))
>>>
import Torch.GraduallyTyped
class HasCat (selectDim :: SelectDim (By Symbol Nat)) k (c :: k -> Type) (a :: k) where Source #
sCat :: forall m. MonadThrow m => SSelectDim selectDim -> c a -> m (CatF selectDim a c) Source #
Concatenates the given sequence of tensors along the given dimension. All tensors must either have the same shape (except in the concatenating dimension) or be empty.
>>>
t <- ones @('Gradient 'WithGradient) @('Layout 'Dense) @('Device 'CPU) @('DataType 'Float) @('Shape '[ 'Dim ('Name "batch") ('Size 32), 'Dim ('Name "feature") ('Size 8)])
>>>
:type cat @('SelectDim ('ByName "feature")) [t]
cat @('SelectDim ('ByName "feature")) [t] :: MonadThrow m => m (Tensor ('Gradient 'WithGradient) ('Layout 'Dense) ('Device 'CPU) ('DataType 'Float) ('Shape '[ 'Dim ('Name "batch") ('Size 32), 'Dim 'UncheckedName 'UncheckedSize]))
>>>
:type cat @('SelectDim ( 'ByIndex 0)) [t]
cat @('SelectDim ( 'ByIndex 0)) [t] :: MonadThrow m => m (Tensor ('Gradient 'WithGradient) ('Layout 'Dense) ('Device 'CPU) ('DataType 'Float) ('Shape '[ 'Dim 'UncheckedName 'UncheckedSize, 'Dim ('Name "feature") ('Size 8)]))
>>>
:type sCat (SUncheckedSelectDim (ByIndex 0)) [t]
sCat (SUncheckedSelectDim (ByIndex 0)) [t] :: MonadThrow m => m (Tensor ('Gradient 'WithGradient) ('Layout 'Dense) ('Device 'CPU) ('DataType 'Float) 'UncheckedShape)
cat :: forall m. (SingI selectDim, MonadThrow m) => c a -> m (CatF selectDim a c) Source #
Instances
Castable (CatListF selectDim (Tensor gradient layout device dataType shape)) (ForeignPtr Tensor) => HasCat selectDim Type [] (Tensor gradient layout device dataType shape :: TYPE LiftedRep) Source # | |
Defined in Torch.GraduallyTyped.Tensor.IndexingSlicingJoining sCat :: MonadThrow m => SSelectDim selectDim -> [Tensor gradient layout device dataType shape] -> m (CatF selectDim (Tensor gradient layout device dataType shape) []) Source # cat :: (SingI selectDim, MonadThrow m) => [Tensor gradient layout device dataType shape] -> m (CatF selectDim (Tensor gradient layout device dataType shape) []) Source # | |
(Castable (CatHListF selectDim tensors) (ForeignPtr Tensor), Castable (HList tensors) (ForeignPtr TensorList)) => HasCat selectDim [Type] (HList :: [Type] -> Type) (tensors :: [Type]) Source # | |
type family CatListImplF (selectDim :: SelectDim (By Symbol Nat)) (tensor :: Type) :: Maybe Type where ... Source #
CatListImplF 'UncheckedSelectDim (Tensor gradient layout device dataType _) = 'Just (Tensor gradient layout device dataType 'UncheckedShape) | |
CatListImplF ('SelectDim _) (Tensor gradient layout device dataType 'UncheckedShape) = 'Just (Tensor gradient layout device dataType 'UncheckedShape) | |
CatListImplF ('SelectDim by) (Tensor gradient layout device dataType ('Shape dims)) = MapMaybe (Tensor gradient layout device dataType) (MapMaybe 'Shape (ReplaceDimImplF by dims ('Dim 'UncheckedName 'UncheckedSize))) |
type CheckSpellingMessage = "Check the spelling of named dimensions, and make sure the number of dimensions is correct." Source #
type family CatListCheckF (selectDim :: SelectDim (By Symbol Nat)) (tensor :: Type) (result :: Maybe Type) :: Type where ... Source #
type CatListF selectDim tensor = CatListCheckF selectDim tensor (CatListImplF selectDim tensor) Source #
type family CatHListImplF selectDim tensors acc where ... Source #
CatHListImplF _ '[] 'Nothing = TypeError (ToErrorMessage "Cannot concatenate an empty list of tensors.") | |
CatHListImplF _ '[] ('Just '(gradient, layout, device, dataType, shape)) = Tensor gradient layout device dataType shape | |
CatHListImplF selectDim (Tensor gradient layout device dataType shape ': tensors) 'Nothing = CatHListImplF selectDim tensors ('Just '(gradient, layout, device, dataType, shape)) | |
CatHListImplF selectDim (Tensor gradient layout device dataType shape ': tensors) ('Just '(gradient', layout', device', dataType', shape')) = CatHListImplF selectDim tensors ('Just '(gradient <|> gradient', layout <+> layout', device <+> device', dataType <+> dataType', ReplaceDimF selectDim (shape <+> ReplaceDimF selectDim shape' (GetDimF selectDim shape)) (AddDimF (GetDimF selectDim shape) (GetDimF selectDim shape')))) | |
CatHListImplF _ (x ': _) _ = TypeError ("Cannot concatenate because" % ("" % (((" '" <> x) <> "'") % ("" % "is not a tensor type.")))) |
type CatHListF selectDim tensors = CatHListImplF selectDim tensors 'Nothing Source #
type ReshapeNumelMismatchMessage numel numel' shape shape' = "Cannot reshape the tensor. The original shape," % ("" % (((" '" <> shape) <> "',") % ("" % ("and the new shape," % ("" % (((" '" <> shape') <> "',") % ("" % ("have different total numbers of elements," % ("" % (((((" '" <> numel) <> "' versus '") <> numel') <> "',") % ("" % "respectively."))))))))))) Source #
type family ReshapeImplF numel numel' shape shape' where ... Source #
ReshapeImplF ('Just numel) ('Just numel) _ shape' = shape' | |
ReshapeImplF ('Just numel) ('Just numel') shape shape' = TypeError (ReshapeNumelMismatchMessage numel numel' shape shape') | |
ReshapeImplF 'Nothing _ _ _ = 'UncheckedShape | |
ReshapeImplF _ 'Nothing _ _ = 'UncheckedShape | |
ReshapeImplF _ _ 'UncheckedShape _ = 'UncheckedShape | |
ReshapeImplF _ _ _ 'UncheckedShape = 'UncheckedShape |
type family ReshapeF shape shape' where ... Source #
ReshapeF shape shape' = ReshapeImplF (NumelF shape) (NumelF shape') shape shape' |
sReshape :: forall m shape' gradient layout device dataType shape shape''. MonadThrow m => (shape'' ~ ReshapeF shape shape', Catch shape'') => SShape shape' -> Tensor gradient layout device dataType shape -> m (Tensor gradient layout device dataType shape'') Source #
Returns a tensor with the same data and number of elements as the input tensor, but with the specified shape:
>>>
g <- sMkGenerator (SDevice SCPU) 0
>>>
(input, _) <- sRandn (TensorSpec (SGradient SWithGradient) (SLayout SDense) (SDevice SCPU) (SDataType SFloat) (SShape $ SName @"*" :&: SSize @4 :|: SNil)) g
>>>
output <- sReshape (SShape $ SName @"*" :&: SSize @2 :|: SName @"*" :&: SSize @2 :|: SNil) input
>>>
:type output
output :: Tensor ('Gradient 'WithGradient) ('Layout 'Dense) ('Device 'CPU) ('DataType 'Float) ('Shape '[ 'Dim ('Name "*") ('Size 2), 'Dim ('Name "*") ('Size 2)])
At the value level, a single dimension may be '-1', in which case it is inferred from the remaining dimensions and the number of elements in the input:
>>>
output' <- sReshape (SShape $ SUncheckedName "*" :&: SUncheckedSize (-1) :|: SNil) output
>>>
:type output'
output' :: Tensor ('Gradient 'WithGradient) ('Layout 'Dense) ('Device 'CPU) ('DataType 'Float) 'UncheckedShape
>>>
getDims output'
[Dim {dimName = "*", dimSize = 4}]
sSetShape :: forall m shape' gradient layout device dataType shape shape''. MonadThrow m => (shape'' ~ ReshapeF shape shape', Catch shape'') => SShape shape' -> Tensor gradient layout device dataType shape -> m (Tensor gradient layout device dataType shape'') Source #
Returns a tensor with the same data and number of elements as the input tensor, but with the specified shape:
>>>
g <- sMkGenerator (SDevice SCPU) 0
>>>
(input, _) <- sRandn (TensorSpec (SGradient SWithGradient) (SLayout SDense) (SDevice SCPU) (SDataType SFloat) (SShape $ SName @"*" :&: SSize @4 :|: SNil)) g
>>>
output <- sReshape (SShape $ SName @"*" :&: SSize @2 :|: SName @"*" :&: SSize @2 :|: SNil) input
>>>
:type output
output :: Tensor ('Gradient 'WithGradient) ('Layout 'Dense) ('Device 'CPU) ('DataType 'Float) ('Shape '[ 'Dim ('Name "*") ('Size 2), 'Dim ('Name "*") ('Size 2)])
At the value level, a single dimension may be '-1', in which case it is inferred from the remaining dimensions and the number of elements in the input:
>>>
output' <- sReshape (SShape $ SUncheckedName "*" :&: SUncheckedSize (-1) :|: SNil) output
>>>
:type output'
output' :: Tensor ('Gradient 'WithGradient) ('Layout 'Dense) ('Device 'CPU) ('DataType 'Float) 'UncheckedShape
>>>
getDims output'
[Dim {dimName = "*", dimSize = 4}]
type family AllDimSizesChecked shape where ... Source #
AllDimSizesChecked 'UncheckedShape = 'False | |
AllDimSizesChecked ('Shape '[]) = 'True | |
AllDimSizesChecked ('Shape ('Dim name ('Size size) ': xs)) = AllDimSizesChecked ('Shape xs) |
reshape :: forall m shape' gradient layout device dataType shape shape''. (shape'' ~ ReshapeF shape shape', Catch shape'', When (AllDimSizesChecked shape) (shape' ~ shape''), SingI shape', MonadThrow m) => Tensor gradient layout device dataType shape -> m (Tensor gradient layout device dataType shape'') Source #
type TransposeBy0Message by0 dims = "Cannot transpose the tensor with the dimensions" % ("" % (((" '" <> dims) <> "'") % ("" % ("because the specified source dimension" % ("" % (((" '" <> by0) <> "'") % ("" % "could not be found."))))))) Source #
type TransposeBy1Message by1 dims = "Cannot transpose the tensor with the dimensions" % ("" % (((" '" <> dims) <> "'") % ("" % ("because the specified target dimension" % ("" % (((" '" <> by1) <> "'") % ("" % "could not be found."))))))) Source #
type family TransposeF selectDim0 selectDim1 shape where ... Source #
Compute transposed shapes.
>>>
type SelectBatch = 'SelectDim ('ByName "batch" :: By Symbol Nat)
>>>
type SelectFeature = 'SelectDim ('ByName "feature" :: By Symbol Nat)
>>>
type Dims = '[ 'Dim ('Name "batch") ('Size 10), 'Dim ('Name "feature") ('Size 8), 'Dim ('Name "anotherFeature") ('Size 12)]
>>>
:kind! TransposeF SelectBatch SelectFeature ('Shape Dims)
TransposeF SelectBatch SelectFeature ('Shape Dims) :: Shape [Dim (Name Symbol) (Size Natural)] = 'Shape '[ 'Dim ('Name "feature") ('Size 8), 'Dim ('Name "batch") ('Size 10), 'Dim ('Name "anotherFeature") ('Size 12)]
>>>
:kind! TransposeF SelectFeature SelectBatch ('Shape Dims)
TransposeF SelectFeature SelectBatch ('Shape Dims) :: Shape [Dim (Name Symbol) (Size Natural)] = 'Shape '[ 'Dim ('Name "feature") ('Size 8), 'Dim ('Name "batch") ('Size 10), 'Dim ('Name "anotherFeature") ('Size 12)]
TransposeF _ _ 'UncheckedShape = 'UncheckedShape | |
TransposeF _ 'UncheckedSelectDim _ = 'UncheckedShape | |
TransposeF 'UncheckedSelectDim _ _ = 'UncheckedShape | |
TransposeF ('SelectDim ('ByName name0)) ('SelectDim ('ByName name1)) ('Shape dims) = 'Shape (TransposeIndexIndexDimsF (FromMaybe (TypeError (TransposeBy0Message ('ByName name0) dims)) (GetIndexByNameF name0 dims)) (FromMaybe (TypeError (TransposeBy1Message ('ByName name1) dims)) (GetIndexByNameF name1 dims)) dims) | |
TransposeF ('SelectDim ('ByIndex index0)) ('SelectDim ('ByIndex index1)) ('Shape dims) = 'Shape (TransposeIndexIndexDimsF index0 index1 dims) | |
TransposeF ('SelectDim by0) ('SelectDim by1) _ = TypeError ("Cannot transpose the tensor. " % ("" % ("The source and target dimensions must be selected either both by name or both by index, " % ("but mixed selectors were found: " % ("" % (((((" '" <> 'SelectDim by0) <> "' and '") <> 'SelectDim by1) <> "'.") % "")))))) |
type family TransposeIndexIndexDimsF index0 index1 dims where ... Source #
TransposeIndexIndexDimsF index0 index1 dims = FromMaybe (TypeError (TransposeBy1Message ('ByIndex index1) dims)) (ReplaceDimImplF ('ByIndex index1) (FromMaybe (TypeError (TransposeBy0Message ('ByIndex index0) dims)) (ReplaceDimImplF ('ByIndex index0) dims (FromMaybe (TypeError (TransposeBy1Message ('ByIndex index1) dims)) (GetDimImplF ('ByIndex index1) dims)))) (FromMaybe (TypeError (TransposeBy0Message ('ByIndex index0) dims)) (GetDimImplF ('ByIndex index0) dims))) |
sTranspose :: forall selectDim0 selectDim1 gradient layout device dataType shape shape' m. (shape' ~ TransposeF selectDim0 selectDim1 shape, Catch shape', MonadThrow m) => SSelectDim selectDim0 -> SSelectDim selectDim1 -> Tensor gradient layout device dataType shape -> m (Tensor gradient layout device dataType shape') Source #
Returns a tensor that is a transposed version of input. The selected dimensions selectDim0 and selectDim1 are swapped.
>>>
g <- sMkGenerator (SDevice SCPU) 0
>>>
(input, _) <- sRandn (TensorSpec (SGradient SWithGradient) (SLayout SDense) (SDevice SCPU) (SDataType SFloat) (SShape $ SName @"batch" :&: SSize @10 :|: SName @"feature" :&: SSize @5 :|: SNil)) g
>>>
output <- sTranspose (SSelectDim (SByName @"batch")) (SSelectDim (SByName @"feature")) input
>>>
:type output
output :: Tensor ('Gradient 'WithGradient) ('Layout 'Dense) ('Device 'CPU) ('DataType 'Float) ('Shape '[ 'Dim ('Name "feature") ('Size 5), 'Dim ('Name "batch") ('Size 10)])
>>>
output <- sTranspose (SUncheckedSelectDim (ByIndex 0)) (SSelectDim (SByIndex @1)) input
>>>
:type output
output :: Tensor ('Gradient 'WithGradient) ('Layout 'Dense) ('Device 'CPU) ('DataType 'Float) 'UncheckedShape
>>>
getDims output
[Dim {dimName = "feature", dimSize = 5},Dim {dimName = "batch", dimSize = 10}]
data TransposeError Source #
Instances
transpose :: forall selectDim0 selectDim1 gradient layout device dataType shape shape' m. (shape' ~ TransposeF selectDim0 selectDim1 shape, Catch shape', SingI selectDim0, SingI selectDim1, MonadThrow m) => Tensor gradient layout device dataType shape -> m (Tensor gradient layout device dataType shape') Source #
type UnsqueezeByMessage (by :: By Symbol Nat) (dims :: [Dim (Name Symbol) (Size Nat)]) = "Cannot unsqueeze the tensor with the dimensions" % ("" % (((" '" <> dims) <> "'") % ("" % ("because the specified source dimension" % ("" % (((" '" <> by) <> "'") % ("" % "could not be found."))))))) Source #
type family UnsqueezeF (selectDim :: SelectDim (By Symbol Nat)) (shape :: Shape [Dim (Name Symbol) (Size Nat)]) :: Shape [Dim (Name Symbol) (Size Nat)] where ... Source #
UnsqueezeF _ 'UncheckedShape = 'UncheckedShape | |
UnsqueezeF 'UncheckedSelectDim _ = 'UncheckedShape | |
UnsqueezeF ('SelectDim ('ByName name)) ('Shape dims) = 'Shape (UnsqueezeIndexDimsF (FromMaybe (TypeError (UnsqueezeByMessage ('ByName name) dims)) (GetIndexByNameF name dims)) dims) | |
UnsqueezeF ('SelectDim ('ByIndex index)) ('Shape dims) = 'Shape (UnsqueezeIndexDimsF index dims) |
type family UnsqueezeIndexDimsF (index :: Nat) (dims :: [Dim (Name Symbol) (Size Nat)]) :: [Dim (Name Symbol) (Size Nat)] where ... Source #
UnsqueezeIndexDimsF index dims = FromMaybe (TypeError (UnsqueezeByMessage ('ByIndex index) dims)) (InsertDimImplF ('ByIndex index) dims ('Dim ('Name "*") ('Size 1))) |
sUnsqueeze :: forall selectDim gradient layout device dataType shape shape' m. (shape' ~ UnsqueezeF selectDim shape, Catch shape', MonadThrow m) => SSelectDim selectDim -> Tensor gradient layout device dataType shape -> m (Tensor gradient layout device dataType shape') Source #
Unsqueezes a tensor at the specified dimension, i.e. inserts a new dimension of size 1 at that position.
unsqueeze :: forall selectDim gradient layout device dataType shape shape' m. (shape' ~ UnsqueezeF selectDim shape, Catch shape', SingI selectDim, MonadThrow m) => Tensor gradient layout device dataType shape -> m (Tensor gradient layout device dataType shape') Source #
Unsqueezes a tensor at the specified dimension, i.e. inserts a new dimension of size 1 at that position.
type family SqueezeAllShapeF shape where ... Source #
SqueezeAllShapeF 'UncheckedShape = 'UncheckedShape | |
SqueezeAllShapeF ('Shape dims) = MaybeF 'UncheckedShape 'Shape (SqueezeAllDimsF dims) |
type family SqueezeAllDimsF dims where ... Source #
SqueezeAllDimsF '[] = 'Just '[] | |
SqueezeAllDimsF ('Dim _ 'UncheckedSize ': dims) = 'Nothing | |
SqueezeAllDimsF ('Dim _ ('Size 1) ': dims) = SqueezeAllDimsF dims | |
SqueezeAllDimsF (dim ': dims) = PrependMaybe ('Just dim) (SqueezeAllDimsF dims) |
squeezeAll :: forall gradient layout device dataType shape. Tensor gradient layout device dataType shape | input |
-> Tensor gradient layout device dataType (SqueezeAllShapeF shape) | output |
type family SqueezeDimByIndexF dimIndex dims where ... Source #
SqueezeDimByIndexF 0 (x ': xs) = If (UnifyCheck (Dim (Name Symbol) (Size Nat)) x ('Dim ('Name "*") ('Size 1))) ('Just xs) 'Nothing | |
SqueezeDimByIndexF dimIndex (x ': xs) = PrependMaybe ('Just x) (SqueezeDimByIndexF (dimIndex - 1) xs) | |
SqueezeDimByIndexF _ _ = 'Nothing |
type family SqueezeDimImplF by dims where ... Source #
SqueezeDimImplF ('ByName dimName) dims = SqueezeDimByNameF (GetIndexByNameF dimName dims) dims | |
SqueezeDimImplF ('ByIndex dimIndex) dims = SqueezeDimByIndexF dimIndex dims |
type family SqueezeDimByNameF dimIndex dims where ... Source #
SqueezeDimByNameF 'Nothing dims = 'Nothing | |
SqueezeDimByNameF ('Just dimIndex) dims = SqueezeDimByIndexF dimIndex dims |
type SqueezeDimMessage by dims = "Cannot squeeze the tensor with the dimensions" % ("" % (((" '" <> dims) <> "'") % ("" % ("at the dimension" % ("" % (((" '" <> by) <> "'.") % "")))))) Source #
type family SqueezeDimCheckF by dims result where ... Source #
SqueezeDimCheckF by dims 'Nothing = TypeError (SqueezeDimMessage by dims) | |
SqueezeDimCheckF _ _ ('Just dims) = dims |
type family SqueezeDimF selectDim shape where ... Source #
Calculates the output shape of a squeeze along a given dimension.
>>>
:kind! SqueezeDimF ('SelectDim ('ByIndex 1)) ('Shape '[ 'Dim ('Name "*") ('Size 2), 'Dim ('Name "*") ('Size 1), 'Dim ('Name "*") ('Size 2)])
... = 'Shape '[ 'Dim ('Name "*") ('Size 2), 'Dim ('Name "*") ('Size 2)]
SqueezeDimF 'UncheckedSelectDim _ = 'UncheckedShape | |
SqueezeDimF _ 'UncheckedShape = 'UncheckedShape | |
SqueezeDimF ('SelectDim by) ('Shape dims) = 'Shape (SqueezeDimCheckF by dims (SqueezeDimImplF by dims)) |
sSqueezeDim :: forall selectDim gradient layout device dataType shape shape' m. (MonadThrow m, shape' ~ SqueezeDimF selectDim shape, Catch shape') => SSelectDim selectDim -> Tensor gradient layout device dataType shape -> m (Tensor gradient layout device dataType shape') Source #
Squeezes (removes) a particular dimension, which must have size 1.
>>>
t <- sOnes $ TensorSpec (SGradient SWithoutGradient) (SLayout SDense) (SDevice SCPU) (SDataType SFloat) (SShape $ SNoName :&: SSize @2 :|: SNoName :&: SSize @1 :|: SNoName :&: SSize @2 :|: SNoName :&: SSize @1 :|: SNoName :&: SSize @2 :|: SNil)
>>>
result <- sSqueezeDim (SSelectDim $ SByIndex @1) t
>>>
:t result
result :: Tensor ('Gradient 'WithoutGradient) ('Layout 'Dense) ('Device 'CPU) ('DataType 'Float) ('Shape '[ 'Dim ('Name "*") ('Size 2), 'Dim ('Name "*") ('Size 2), 'Dim ('Name "*") ('Size 1), 'Dim ('Name "*") ('Size 2)])
>>>
result
Tensor Float [2,2,1,2] [[[[ 1.0000 , 1.0000 ]], [[ 1.0000 , 1.0000 ]]], [[[ 1.0000 , 1.0000 ]], [[ 1.0000 , 1.0000 ]]]]
sExpand :: forall shape' shape'' gradient layout device dataType shape. (shape'' ~ BroadcastShapesF shape shape', Catch shape'') | |
=> SShape shape' | new shape |
-> Tensor gradient layout device dataType shape | input tensor |
-> Tensor gradient layout device dataType shape'' | output tensor |
Expands a tensor to the specified shape.
expand :: forall shape' shape'' gradient layout device dataType shape. (SingI shape', shape'' ~ BroadcastShapesF shape shape', Catch shape'') => Tensor gradient layout device dataType shape -> Tensor gradient layout device dataType shape'' Source #
Expands a tensor to the specified shape.
sSelect :: forall selectDim index gradient layout device dataType shapeIn shapeOut m. (index `InRangeF` GetDimF selectDim shapeIn, shapeOut ~ RemoveDimF selectDim shapeIn, Catch shapeOut, SGetShape shapeIn, MonadThrow m) => SSelectDim selectDim -> SIndex index -> Tensor gradient layout device dataType shapeIn -> m (Tensor gradient layout device dataType shapeOut) Source #
Slices the input tensor along the selected dimension at the given index. This function returns a view of the original tensor with the given dimension removed.
>>>
nats <- sArangeNaturals (SGradient SWithoutGradient) (SLayout SDense) (SDevice SCPU) (SDataType SInt32) (SSize @8)
>>>
input <- sReshape (SShape $ SName @"*" :&: SSize @4 :|: SName @"*" :&: SSize @2 :|: SNil) nats
>>>
input
Tensor Int32 [4,2] [[ 0, 1], [ 2, 3], [ 4, 5], [ 6, 7]]
The index can be provided at compile time:
>>>
sSelect (SSelectDim (SByIndex @0)) (SIndex @1) input
Tensor Int32 [2] [ 2, 3]
The index can also be provided at runtime:
>>>
sSelect (SSelectDim (SByIndex @0)) (SUncheckedIndex 1) input
Tensor Int32 [2] [ 2, 3]
It throws a runtime error if the index is too large:
>>>
sSelect (SSelectDim (SByIndex @0)) (SUncheckedIndex 10) input
*** Exception: IndexOutOfBoundError {ioobeIndex = 10, ioobeDim = Dim {dimName = "*", dimSize = 4}}
data IndexOutOfBoundError Source #
select :: forall selectDim index gradient layout device dataType shapeIn shapeOut m. (SingI selectDim, SingI index, index `InRangeF` GetDimF selectDim shapeIn, shapeOut ~ RemoveDimF selectDim shapeIn, Catch shapeOut, SGetShape shapeIn, MonadThrow m) => Tensor gradient layout device dataType shapeIn -> m (Tensor gradient layout device dataType shapeOut) Source #
type family GatherDimImplF by indexDims inputDims where ... Source #
GatherDimImplF is a type-level helper function for sGatherDim.
>>>
:kind! GatherDimImplF ('ByIndex 1) '[ 'Dim ('Name "batch") ('Size 2), 'Dim ('Name "sequence") ('Size 4), 'Dim ('Name "feature") ('Size 1)] '[ 'Dim ('Name "batch") ('Size 2), 'Dim ('Name "sequence") ('Size 1), 'Dim ('Name "feature") ('Size 1)]
... = 'Just '[ 'Dim ('Name "batch") ('Size 2), 'Dim ('Name "sequence") ('Size 4), 'Dim ('Name "feature") ('Size 1)]
>>>
:kind! GatherDimImplF ('ByIndex 1) '[ 'Dim ('Name "batch") ('Size 2), 'Dim ('Name "sequence") ('Size 4), 'Dim ('Name "feature") ('Size 1)] '[ 'Dim ('Name "*") ('Size 2), 'Dim ('Name "sequence") ('Size 1), 'Dim ('Name "*") ('Size 1)]
... = 'Just '[ 'Dim ('Name "batch") ('Size 2), 'Dim ('Name "sequence") ('Size 4), 'Dim ('Name "feature") ('Size 1)]
>>>
:kind! GatherDimImplF ('ByIndex 1) '[ 'Dim ('Name "batch") ('Size 2), 'Dim ('Name "sequence") ('Size 4), 'Dim ('Name "feature") ('Size 2)] '[ 'Dim ('Name "batch") ('Size 2), 'Dim ('Name "sequence") ('Size 1), 'Dim ('Name "feature") ('Size 1)]
... = 'Nothing
>>>
:kind! GatherDimImplF ('ByIndex 1) '[ 'Dim ('Name "batch") ('Size 2), 'Dim ('Name "sequence") ('Size 4), 'Dim ('Name "feature") ('Size 1)] '[ 'Dim ('Name "batch") ('Size 2), 'Dim ('Name "boo") ('Size 1), 'Dim ('Name "feature") ('Size 1)]
... = 'Nothing
>>>
:kind! GatherDimImplF ('ByIndex 1) '[ 'Dim ('Name "batch") ('Size 2), 'Dim ('Name "sequence") ('Size 0), 'Dim ('Name "feature") ('Size 1)] '[ 'Dim ('Name "batch") ('Size 2), 'Dim ('Name "sequence") ('Size 2), 'Dim ('Name "feature") ('Size 1)]
... = 'Nothing
>>>
:kind! GatherDimImplF ('ByIndex 1) '[ 'Dim ('Name "batch") ('Size 2), 'Dim ('Name "sequence") ('Size 1)] '[ 'Dim ('Name "batch") ('Size 2), 'Dim ('Name "sequence") ('Size 1), 'Dim ('Name "feature") ('Size 1)]
... = 'Nothing
>>>
:kind! GatherDimImplF ('ByIndex 2) '[ 'Dim ('Name "batch") ('Size 2), 'Dim ('Name "sequence") ('Size 1), 'Dim ('Name "feature") ('Size 3)] '[ 'Dim ('Name "batch") ('Size 2), 'Dim ('Name "sequence") ('Size 1), 'Dim ('Name "feature") ('Size 1)]
... = 'Just '[ 'Dim ('Name "batch") ('Size 2), 'Dim ('Name "sequence") ('Size 1), 'Dim ('Name "feature") ('Size 3)]
>>>
:kind! GatherDimImplF ('ByName "feature") '[ 'Dim ('Name "batch") ('Size 2), 'Dim ('Name "sequence") ('Size 1), 'Dim ('Name "feature") ('Size 3)] '[ 'Dim ('Name "batch") ('Size 2), 'Dim ('Name "sequence") ('Size 1), 'Dim ('Name "*") ('Size 1)]
... = 'Just '[ 'Dim ('Name "batch") ('Size 2), 'Dim ('Name "sequence") ('Size 1), 'Dim ('Name "feature") ('Size 3)]
GatherDimImplF ('ByName dimName) indexDims inputDims = GatherDimByNameF (GetIndexByNameF dimName indexDims) (GetIndexByNameF dimName inputDims) indexDims inputDims | |
GatherDimImplF ('ByIndex dimIndex) indexDims inputDims = GatherDimByIndexF dimIndex indexDims inputDims |
type family GatherDimByIndexF dimIndex indexDims inputDims where ... Source #
GatherDimByIndexF 0 ('Dim _ ('Size 0) ': _) _ = 'Nothing | |
GatherDimByIndexF 0 ('Dim indexDimName indexDimSize ': indexDims) ('Dim inputDimName _ ': inputDims) = If (UnifyCheck (Name Symbol) indexDimName inputDimName && UnifyCheck [Dim (Name Symbol) (Size Nat)] indexDims inputDims) ('Just ('Dim (indexDimName <+> inputDimName) indexDimSize ': (indexDims <+> inputDims))) 'Nothing | |
GatherDimByIndexF dimIndex (indexDim ': indexDims) (inputDim ': inputDims) = If (UnifyCheck (Dim (Name Symbol) (Size Nat)) indexDim inputDim) (PrependMaybe ('Just (indexDim <+> inputDim)) (GatherDimByIndexF (dimIndex - 1) indexDims inputDims)) 'Nothing | |
GatherDimByIndexF _ _ _ = 'Nothing |
type family GatherDimByNameF dimIndex dimIndex' indexDims inputDims where ... Source #
GatherDimByNameF 'Nothing ('Just dimIndex') indexDims inputDims = GatherDimByIndexF dimIndex' indexDims inputDims | |
GatherDimByNameF ('Just dimIndex) 'Nothing indexDims inputDims = GatherDimByIndexF dimIndex indexDims inputDims | |
GatherDimByNameF ('Just dimIndex) ('Just dimIndex) indexDims inputDims = GatherDimByIndexF dimIndex indexDims inputDims | |
GatherDimByNameF _ _ _ _ = 'Nothing |
type GatherDimMessage by indexDims inputDims = "Cannot gather the tensor with the dimensions" % ("" % (((" '" <> inputDims) <> "'") % ("" % ("at the dimension" % ("" % (((" '" <> by) <> "'") % ("" % ("using an index of shape" % ("" % (((" '" <> indexDims) <> "'.") % "")))))))))) Source #
type family GatherDimCheckF by indexDims inputDims result where ... Source #
GatherDimCheckF by indexDims inputDims 'Nothing = TypeError (GatherDimMessage by indexDims inputDims) | |
GatherDimCheckF _ _ _ ('Just dims) = dims |
type family GatherDimF selectDim indexShape inputShape where ... Source #
Calculate the output shape of a gather operation for a given index shape along a given axis.
>>>
:kind! GatherDimF ('SelectDim ('ByIndex 2)) ('Shape '[ 'Dim ('Name "*") ('Size 2), 'Dim ('Name "*") ('Size 1), 'Dim ('Name "*") ('Size 3)]) ('Shape '[ 'Dim ('Name "*") ('Size 2), 'Dim ('Name "*") ('Size 1), 'Dim ('Name "*") ('Size 1)])
... = 'Shape '[ 'Dim ('Name "*") ('Size 2), 'Dim ('Name "*") ('Size 1), 'Dim ('Name "*") ('Size 3)]
GatherDimF 'UncheckedSelectDim _ _ = 'UncheckedShape | |
GatherDimF _ 'UncheckedShape _ = 'UncheckedShape | |
GatherDimF _ _ 'UncheckedShape = 'UncheckedShape | |
GatherDimF ('SelectDim by) ('Shape indexDims) ('Shape inputDims) = 'Shape (GatherDimCheckF by indexDims inputDims (GatherDimImplF by indexDims inputDims)) |
sGatherDim :: forall selectDim indexGradient inputGradient indexLayout inputLayout indexDevice inputDevice indexDataType inputDataType indexShape inputShape outputShape m. (MonadThrow m, outputShape ~ GatherDimF selectDim indexShape inputShape, Catch outputShape, Catch (indexDataType <+> 'DataType 'Int64)) | |
=> SSelectDim selectDim | |
-> Tensor indexGradient indexLayout indexDevice indexDataType indexShape | the indices of elements to gather |
-> Tensor inputGradient inputLayout inputDevice inputDataType inputShape | input |
-> m (Tensor (indexGradient <|> inputGradient) (indexLayout <+> inputLayout) (indexDevice <+> inputDevice) inputDataType outputShape) | output |
Gathers values from the input tensor along the specified dimension, at the positions given by the index tensor.
>>>
sToTensor' = sToTensor (SGradient SWithoutGradient) (SLayout SDense) (SDevice SCPU)
>>>
t <- sToTensor' [[1 :: Float, 2], [3, 4]]
>>>
idx <- sToTensor' [[0 :: Int, 0], [1, 0]]
>>>
result <- sGatherDim (SSelectDim $ SByIndex @1) idx t
>>>
:t result
result :: Tensor ('Gradient 'WithoutGradient) ('Layout 'Dense) ('Device 'CPU) ('DataType 'Float) ('Shape '[ 'Dim ('Name "*") 'UncheckedSize, 'Dim ('Name "*") 'UncheckedSize])
>>>
result
Tensor Float [2,2] [[ 1.0000 , 1.0000 ], [ 4.0000 , 3.0000 ]]
>>>
shape = SShape $ SNoName :&: SSize @2 :|: SNoName :&: SSize @2 :|: SNil
>>>
t' <- sCheckedShape shape t
>>>
idx' <- sCheckedShape shape idx
>>>
result <- sGatherDim (SSelectDim $ SByIndex @1) idx' t'
>>>
:t result
result :: Tensor ('Gradient 'WithoutGradient) ('Layout 'Dense) ('Device 'CPU) ('DataType 'Float) ('Shape '[ 'Dim ('Name "*") ('Size 2), 'Dim ('Name "*") ('Size 2)])
>>>
result
Tensor Float [2,2] [[ 1.0000 , 1.0000 ], [ 4.0000 , 3.0000 ]]