{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeSynonymInstances #-}
{-# LANGUAGE UndecidableInstances #-}

module Torch.Tensor where

import Control.Exception.Safe (throwIO)
import Control.Monad (forM, forM_)
import Numeric.Half
import Data.Complex
import Data.Int (Int16, Int64)
import Data.List (intercalate)
import Data.Proxy
import Data.Reflection
import qualified Data.Vector as V
import Data.Word (Word8)
import Foreign.C.Types
import Foreign.ForeignPtr
import Foreign.Ptr
import Foreign.Storable
import GHC.Generics
import Numeric
import System.IO.Unsafe
import Torch.DType
import Torch.Device
import Torch.Internal.Cast
import Torch.Internal.Class (Castable (..), CppTuple2 (..), CppTuple3 (..), CppTuple4 (..))
import qualified Torch.Internal.Const as ATen
import qualified Torch.Internal.Managed.Cast as ATen
import qualified Torch.Internal.Managed.Native as ATen
import qualified Torch.Internal.Managed.TensorFactories as LibTorch
import qualified Torch.Internal.Managed.Type.Context as ATen
import qualified Torch.Internal.Managed.Type.StdArray as ATen
import qualified Torch.Internal.Managed.Type.StdString as ATen
import qualified Torch.Internal.Managed.Type.Tensor as ATen
import qualified Torch.Internal.Managed.Type.TensorIndex as ATen
import qualified Torch.Internal.Managed.Type.TensorOptions as ATen
import qualified Torch.Internal.Managed.Type.Extra as ATen
import qualified Torch.Internal.Type as ATen
import qualified Torch.Internal.Unmanaged.Type.Tensor as Unmanaged (tensor_data_ptr)
import Torch.Lens
import Torch.TensorOptions

-- | Foreign pointer to an @ATen.Tensor@ object.
type ATenTensor = ForeignPtr ATen.Tensor

-- | Hasktorch tensor: a newtype around an 'ATenTensor' pointer.
--
-- Do not use the 'Unsafe' constructor directly.
newtype Tensor = Unsafe ATenTensor

-- Marshalling between 'Tensor' and 'ATenTensor' only wraps or unwraps the
-- pointer; no data is copied.
instance Castable Tensor ATenTensor where
  cast (Unsafe ptr) k = k ptr
  uncast ptr k = k (Unsafe ptr)

-- | Wrapper that marks a 'Tensor' as mutable.
--
-- Convert to it with 'newMutableTensor' and back with 'toImmutable'.
newtype MutableTensor
  = MutableTensor Tensor
  deriving (Show)

-- | Wraps the result of ATen @detach@ on the given tensor as a 'MutableTensor'.
--
-- NOTE(review): in PyTorch, @detach@ returns a tensor that shares storage with
-- its input but has no autograd history. Confirm that this is the intended
-- aliasing here.
newMutableTensor :: Tensor -> IO MutableTensor
newMutableTensor tensor = do
  detached <- cast1 ATen.detach_t tensor
  pure (MutableTensor detached)

-- | Returns the result of ATen @detach@ on the tensor wrapped by a
-- 'MutableTensor'.
toImmutable :: MutableTensor -> IO Tensor
toImmutable mutable =
  case mutable of
    MutableTensor inner -> cast1 ATen.detach_t inner

--------------------------------------------------------------------------------
-- Basic tensor properties
--------------------------------------------------------------------------------

-- | Returns the total number of elements in the input tensor
-- (ATen @numel@).
numel ::
  -- | input
  Tensor ->
  -- | number of elements in tensor
  Int
numel input = unsafePerformIO (cast1 ATen.tensor_numel input)

-- | Returns the size of the input tensor along the given dimension
-- (ATen @size(dim)@).
size ::
  -- | dimension
  Int ->
  -- | input
  Tensor ->
  Int
size axis input = unsafePerformIO (cast2 ATen.tensor_size_l input axis)

-- | Returns the shape of the tensor, i.e. its ATen @sizes()@ as a list.
shape ::
  -- | input
  Tensor ->
  -- | list of integers representing the shape of the tensor
  [Int]
shape input = unsafePerformIO (cast1 ATen.tensor_sizes input)

-- | Returns the number of dimensions of the input tensor (ATen @dim()@).
dim ::
  -- | input
  Tensor ->
  -- | output
  Int
dim input = unsafePerformIO (cast1 ATen.tensor_dim input)

-- | Returns the number of dimensions of the input tensor, using the
-- @tensor_dim_unsafe@ binding.
--
-- NOTE(review): how this differs from 'dim' is not visible in this module;
-- confirm against the C++ binding.
dimUnsafe ::
  -- | input
  Tensor ->
  -- | output
  Int
dimUnsafe input = unsafePerformIO (cast1 ATen.tensor_dim_unsafe input)

-- | Returns the number of dimensions of the input tensor, using the
-- @tensor_dim_c_unsafe@ binding.
--
-- NOTE(review): how this differs from 'dim' and 'dimUnsafe' is not visible in
-- this module; confirm against the C++ binding.
dimCUnsafe ::
  -- | input
  Tensor ->
  -- | output
  Int
dimCUnsafe input = unsafePerformIO (cast1 ATen.tensor_dim_c_unsafe input)

-- | Returns the device on which the tensor is currently allocated.
--
-- Only CPU and CUDA are distinguished. If CUDA is unavailable, or the tensor
-- is not a CUDA tensor, the result is @cpu:0@. Otherwise it is @cuda:i@,
-- where @i@ is the tensor's device index.
device ::
  -- | input
  Tensor ->
  -- | object representing the device
  Device
device t = unsafePerformIO $ do
  cudaAvailable <- cast0 ATen.hasCUDA :: IO Bool
  -- Only ask whether the tensor is on CUDA when CUDA exists at all.
  onCUDA <-
    if cudaAvailable
      then (cast1 ATen.tensor_is_cuda t :: IO Bool)
      else pure False
  if onCUDA
    then cudaDevice <$> (cast1 ATen.tensor_get_device t :: IO Int)
    else pure cpuDevice
  where
    cpuDevice :: Device
    cpuDevice = Device {deviceType = CPU, deviceIndex = 0}
    cudaDevice :: Int -> Device
    cudaDevice idx = Device {deviceType = CUDA, deviceIndex = fromIntegral idx}

-- | Returns the data type of the input tensor (ATen @scalar_type()@).
dtype ::
  -- | input
  Tensor ->
  -- | data type of the input tensor
  DType
dtype input = unsafePerformIO (cast1 ATen.tensor_scalar_type input)

-- | Reads a value from the tensor as a @'Complex' 'Double'@.
--
-- * For complex dtypes ('ComplexHalf', 'ComplexFloat', 'ComplexDouble'), the
--   element at offset 0 is read directly from the tensor's data pointer
--   (via 'withTensor') and widened to 'Double'.
-- * For any other dtype, the value is read with ATen @item<double>@ and given
--   a zero imaginary part.
--
-- NOTE(review): the complex branches read element 0 of the raw buffer without
-- checking 'numel' or the device. Confirm whether 'withTensor' guards against
-- multi-element or non-CPU tensors.
toComplex :: Tensor -> Complex Double
toComplex t = unsafePerformIO $
  case dtype t of
    ComplexHalf -> widen <$> (peekFirst :: IO (Complex Half))
    ComplexFloat -> widen <$> (peekFirst :: IO (Complex Float))
    ComplexDouble -> peekFirst
    _ -> (\re -> re :+ 0) <$> cast1 ATen.tensor_item_double t
  where
    -- Read the first element of the tensor's buffer at the requested type.
    peekFirst :: Storable a => IO a
    peekFirst = withTensor t $ \ptr -> peekElemOff (castPtr ptr) 0
    -- Convert both components to 'Double'.
    widen :: Real a => Complex a -> Complex Double
    widen (re :+ im) = realToFrac re :+ realToFrac im

-- | Reads the tensor's value as a 'Double' via ATen @item<double>@.
--
-- NOTE(review): presumably requires a single-element tensor (as PyTorch's
-- @item()@ does); confirm.
toDouble :: Tensor -> Double
toDouble input = unsafePerformIO (cast1 ATen.tensor_item_double input)

-- | Reads the tensor's value as an 'Int' via ATen @item<int64_t>@.
--
-- NOTE(review): presumably requires a single-element tensor (as PyTorch's
-- @item()@ does); confirm.
toInt :: Tensor -> Int
toInt input = unsafePerformIO (cast1 ATen.tensor_item_int64_t input)

-- | Casts the input tensor to the given data type (ATen @toType@).
_toType ::
  -- | data type to cast input to
  DType ->
  -- | input
  Tensor ->
  -- | output
  Tensor
_toType targetType input =
  unsafePerformIO (cast2 ATen.tensor_toType_s input targetType)

-- A 'Tensor' is its own single 'Tensor' target.
instance HasTypes Tensor Tensor where
  types_ visit = visit

-- The following types contain no tensors. The traversal visits nothing and
-- returns the value unchanged.
instance HasTypes (a -> a) Tensor where
  types_ _ leaf = pure leaf

instance HasTypes Int Tensor where
  types_ _ leaf = pure leaf

instance HasTypes Double Tensor where
  types_ _ leaf = pure leaf

instance HasTypes Float Tensor where
  types_ _ leaf = pure leaf

instance HasTypes Bool Tensor where
  types_ _ leaf = pure leaf

-- Each scalar type is its own single target.
instance HasTypes Int Int where
  types_ visit = visit

instance HasTypes Float Float where
  types_ visit = visit

instance HasTypes Double Double where
  types_ visit = visit

instance HasTypes Bool Bool where
  types_ visit = visit

-- | Casts every 'Tensor' contained in a value to the given data type.
--
-- The tensors are located through the value's 'HasTypes' instance, and each
-- one is converted with '_toType'.
toType :: forall a. HasTypes a Tensor => DType -> a -> a
toType targetType = over (types @Tensor @a) (_toType targetType)

-- | Moves every 'Tensor' contained in a value to the given device.
--
-- The tensors are located through the value's 'HasTypes' instance, and each
-- one is moved with '_toDevice'.
toDevice :: forall a. HasTypes a Tensor => Device -> a -> a
toDevice target = over (types @Tensor @a) (_toDevice target)

-- | Moves the input tensor to the given device.
--
-- Supported moves; any other combination calls 'error':
--
--   * source and target device identical: the input is returned as is;
--   * @cuda:i@ to @cuda:j@ with @i /= j@, when CUDA is available;
--   * @cpu:0@ to @cuda:j@ with @j >= 0@, when CUDA is available;
--   * @cuda:i@ to @cpu:0@ with @i >= 0@, when CUDA is available.
--
-- After the move, the device of the result is re-read and compared with the
-- requested one. A mismatch also calls 'error'.
_toDevice ::
  -- | device to cast input to
  Device ->
  -- | input
  Tensor ->
  -- | output
  Tensor
_toDevice target t = unsafePerformIO $ do
  cudaAvailable <- cast0 ATen.hasCUDA :: IO Bool
  let source = Torch.Tensor.device t
  moved <-
    move
      (deviceType source)
      (deviceType target)
      (deviceIndex source)
      (deviceIndex target)
      cudaAvailable
  let reached = Torch.Tensor.device moved
  verify
    (deviceType target)
    (deviceType reached)
    (deviceIndex target)
    (deviceIndex reached)
  pure moved
  where
    -- Dispatch on (source type, target type, CUDA available). Alternatives
    -- whose guard fails fall through to the next one, ending in an error.
    move :: DeviceType -> DeviceType -> Int16 -> Int16 -> Bool -> IO Tensor
    move srcType dstType srcIdx dstIdx cudaOk =
      case (srcType, dstType, cudaOk) of
        _
          | srcType == dstType && srcIdx == dstIdx ->
              pure t -- already on the target device
        (CUDA, CUDA, True)
          | srcIdx /= dstIdx ->
              -- cuda:srcIdx -> cuda:dstIdx
              currentOptions >>= setDeviceIndex dstIdx >>= applyOptions
        (CPU, CUDA, True)
          | srcIdx == 0 && dstIdx >= 0 ->
              -- cpu:0 -> cuda:dstIdx
              currentOptions >>= setDeviceIndex dstIdx >>= applyOptions
        (CUDA, CPU, True)
          | dstIdx == 0 && srcIdx >= 0 ->
              -- cuda:srcIdx -> cpu:0
              currentOptions >>= setDeviceType CPU >>= applyOptions
        _ ->
          error $
            concat
              [ "cannot move tensor from \"",
                show srcType,
                ":",
                show srcIdx,
                "\" to \"",
                show dstType,
                ":",
                show dstIdx,
                "\""
              ]

    -- Tensor options of the input tensor.
    currentOptions :: IO TensorOptions
    currentOptions = cast1 ATen.tensor_options t

    setDeviceType :: DeviceType -> TensorOptions -> IO TensorOptions
    setDeviceType dt opts = cast2 ATen.tensorOptions_device_D opts dt

    -- careful, setting the device index implies setting the device type to CUDA!
    setDeviceIndex :: Int16 -> TensorOptions -> IO TensorOptions
    setDeviceIndex di opts = cast2 ATen.tensorOptions_device_index_s opts di

    -- Apply the options to the input via ATen @to@ (non_blocking = False,
    -- copy = False).
    applyOptions :: TensorOptions -> IO Tensor
    applyOptions opts = cast4 ATen.tensor_to_obb t opts False False

    -- Fail loudly if the tensor did not end up on the requested device.
    verify :: DeviceType -> DeviceType -> Int16 -> Int16 -> IO ()
    verify wantType gotType wantIdx gotIdx
      | wantType == gotType && wantIdx == gotIdx = pure ()
      | otherwise =
          error $
            concat
              [ "moving of tensor failed: device should have been \"",
                show wantType,
                ":",
                show wantIdx,
                "\" but is \"",
                show gotType,
                ":",
                show gotIdx,
                "\""
              ]

-- | Moves @input@ according to @reference@ via the ATen @to_device@ binding.
--
-- NOTE(review): presumably this places @input@ on @reference@'s device.
-- Confirm against the C++ binding.
toDeviceWithTensor :: Tensor -> Tensor -> Tensor
toDeviceWithTensor ref input =
  unsafePerformIO (cast2 ATen.tensor_to_device ref input)

-- | Slices the input tensor along the selected dimension at the given index
-- (ATen @select(dim, index)@).
select ::
  -- | dimension to slice along
  Int ->
  -- | index in the given dimension
  Int ->
  -- | input
  Tensor ->
  -- | output
  Tensor
select axis position input =
  unsafePerformIO (cast3 ATen.tensor_select_ll input axis position)

-- | Returns a new tensor that indexes the input tensor along dimension @dim@
-- using the entries of @indexTensor@ (ATen @index_select@).
--
-- @indexTensor@ is expected to be an integer (LongTensor) index.
indexSelect ::
  -- | dim
  Int ->
  -- | indexTensor
  Tensor ->
  -- | input
  Tensor ->
  -- | output
  Tensor
indexSelect axis indices input =
  unsafePerformIO (cast3 ATen.index_select_tlt input axis indices)

-- | Like 'indexSelect', but takes the indices as a plain Haskell list.
--
-- The list becomes an @int64@ tensor on the same device as the input. It is
-- then passed to 'indexSelect'.
--
-- Previously the index tensor was built with @asTensor' indexList t@. That
-- copies the input's tensor options, including its dtype (NOTE(review): per
-- the @TensorOptionLike Tensor@ instance; confirm). For a floating-point
-- input this produced a floating-point index, which @index_select@ rejects.
-- Now only the device is taken from the input.
indexSelect' ::
  -- | dim
  Int ->
  -- | indexList
  [Int] ->
  -- | input
  Tensor ->
  -- | output
  Tensor
indexSelect' dim indexList t = indexSelect dim indexTensor t
  where
    -- Force int64 explicitly, regardless of how 'asTensor' maps 'Int', then
    -- move the index to the input's device.
    indexTensor :: Tensor
    indexTensor =
      _toDevice (Torch.Tensor.device t) (_toType Int64 (asTensor indexList))

-- | Slices the input tensor along the selected dimension over the range
-- @[start, end)@ with the given step (ATen @slice@).
sliceDim ::
  -- | dim
  Int ->
  -- | start
  Int ->
  -- | end
  Int ->
  -- | step
  Int ->
  -- | input
  Tensor ->
  Tensor
sliceDim axis from to stride input =
  unsafePerformIO (cast5 ATen.slice_tllll input axis from to stride)

-- | Whether the tensor is contiguous in memory (ATen @is_contiguous()@).
isContiguous ::
  Tensor ->
  Bool
isContiguous input = unsafePerformIO (cast1 ATen.tensor_is_contiguous input)

-- | Returns a contiguous version of the tensor (ATen @contiguous()@).
contiguous ::
  Tensor ->
  Tensor
contiguous input = unsafePerformIO (cast1 ATen.tensor_contiguous input)

-- | Returns a tensor with the same data and number of elements as the input,
-- but with the specified shape (ATen @reshape@).
reshape ::
  [Int] ->
  Tensor ->
  Tensor
reshape newShape input = unsafePerformIO (cast2 ATen.reshape_tl input newShape)

--------------------------------------------------------------------------------
-- Move backend
--------------------------------------------------------------------------------

-- | Converts the tensor to a sparse layout (ATen @to_sparse()@).
toSparse :: Tensor -> Tensor
toSparse input = unsafePerformIO (cast1 ATen.tensor_to_sparse input)

-- | Converts the tensor to a dense layout (ATen @to_dense()@).
toDense :: Tensor -> Tensor
toDense input = unsafePerformIO (cast1 ATen.tensor_to_dense input)

-- | Converts the tensor to the MKL-DNN layout (ATen @to_mkldnn()@).
toMKLDNN :: Tensor -> Tensor
toMKLDNN input = unsafePerformIO (cast1 ATen.tensor_to_mkldnn input)

-- | Returns the tensor on the CPU (ATen @cpu()@).
toCPU :: Tensor -> Tensor
toCPU input = unsafePerformIO (cast1 ATen.tensor_cpu input)

-- | Returns the tensor on a CUDA device (ATen @cuda()@).
toCUDA :: Tensor -> Tensor
toCUDA input = unsafePerformIO (cast1 ATen.tensor_cuda input)

--------------------------------------------------------------------------------
-- Indexing support
--------------------------------------------------------------------------------

-- TensorIndex mirrors PyTorch's tensor indexing (slicing) syntax.
--
-- There is a one-to-one correspondence between PyTorch and Hasktorch tensor index types:
-- PyTorch                 | Hasktorch
-- -----------------------------------------------------
-- `None`                  | `None`
-- `Ellipsis`              | `Ellipsis`
-- `...`                   | `Ellipsis`
-- `123`                   | `123`
-- `True` / `False`        | `True` / `False`
-- `:`                     | `Slice ()`
-- `::`                    | `Slice ()`
-- `1:`                    | `Slice (1, None)`
-- `1::`                   | `Slice (1, None)`
-- `:3`                    | `Slice (None, 3)`
-- `:3:`                   | `Slice (None, 3)`
-- `::2`                   | `Slice (None, None, 2)`
-- `1:3`                   | `Slice (1, 3)`
-- `1::2`                  | `Slice (1, None, 2)`
-- `:3:2`                  | `Slice (None, 3, 2)`
-- `1:3:2`                 | `Slice (1, 3, 2)`
-- `torch.tensor([1, 2])`  | `asTensor [1, 2 :: Int]`

-- | Foreign pointer to an ATen @StdVector@ of @ATen.TensorIndex@ values.
newtype RawTensorIndexList = RawTensorIndexList (ForeignPtr (ATen.StdVector ATen.TensorIndex))

-- | Foreign pointer to a single @ATen.TensorIndex@, as produced by 'pushIndex'.
newtype RawTensorIndex = RawTensorIndex (ForeignPtr ATen.TensorIndex)

-- | Indexes a tensor with any 'TensorIndex' value (PyTorch-style @t[idx]@).
--
-- The steps are:
--
-- 1. Expand the index into raw ATen indices with 'pushIndex'.
-- 2. Push them into a freshly allocated ATen index list.
-- 3. Pass that list to ATen @index@.
(!) :: TensorIndex a => Tensor -> a -> Tensor
(!) (Unsafe self) ix = unsafePerformIO $ do
  indexList <- ATen.newTensorIndexList
  mapM_
    (\(RawTensorIndex raw) -> ATen.tensorIndexList_push_back indexList raw)
    (pushIndex [] ix)
  Unsafe <$> ATen.index self indexList

-- | Returns a copy of the tensor whose positions selected by the index are
-- overwritten with the given value.
--
-- The input is cloned (ATen @clone@) and the copy is updated in place with
-- ATen @index_put_@, so the original tensor is left untouched. The value is
-- converted with 'asTensor'.
maskedFill :: (TensorIndex a, TensorLike t) => Tensor -> a -> t -> Tensor
maskedFill (Unsafe source) ix fillValue = unsafePerformIO $ do
  target <- ATen.clone_t source
  indexList <- ATen.newTensorIndexList
  mapM_
    (\(RawTensorIndex raw) -> ATen.tensorIndexList_push_back indexList raw)
    (pushIndex [] ix)
  _ <- ATen.index_put_ target indexList valuePtr
  pure (Unsafe target)
  where
    Unsafe valuePtr = asTensor fillValue

-- | Index marker corresponding to PyTorch's @None@.
data None
  = None
  deriving (Eq, Show)

-- | Index marker corresponding to PyTorch's @...@ / @Ellipsis@.
data Ellipsis
  = Ellipsis
  deriving (Eq, Show)

-- | Slice index. The payload (e.g. @()@, @(start, end)@ or
-- @(start, end, step)@) describes the range, as in PyTorch's @start:end:step@.
newtype Slice a
  = Slice a
  deriving (Eq, Show)

-- Marshalling between 'RawTensorIndex' and its foreign pointer only wraps or
-- unwraps the newtype.
instance Castable RawTensorIndex (ForeignPtr ATen.TensorIndex) where
  cast (RawTensorIndex raw) k = k raw
  uncast raw k = k (RawTensorIndex raw)

-- | Values usable as an index into a 'Tensor'.
--
-- Each instance translates itself into 'RawTensorIndex' entries via
-- 'pushIndex'; tuple instances combine the entries of their components.
class TensorIndex a where
  -- | Add the raw index entries for this value to the front of the
  -- accumulator list.
  pushIndex :: [RawTensorIndex] -> a -> [RawTensorIndex]

  -- | A lens onto the part of a tensor selected by this index.
  --
  -- Reading goes through @asValue (s ! idx)@; writing converts the new value
  -- with 'asTensor' and stores it with 'maskedFill'.
  toLens :: TensorLike b => a -> Lens' Tensor b
  default toLens :: TensorLike b => a -> Lens' Tensor b
  toLens idx focus s = maskedFill s idx . asTensor <$> focus (asValue (s ! idx))

-- | 'None' becomes an index created by 'ATen.newTensorIndexWithNone'.
instance {-# OVERLAPS #-} TensorIndex None where
  -- unsafePerformIO: the call only allocates a fresh index object
  -- (assumed otherwise side-effect free — confirm against the binding).
  pushIndex vec _ =
    unsafePerformIO $ (: vec) . RawTensorIndex <$> ATen.newTensorIndexWithNone

-- | 'Ellipsis' becomes an index created by 'ATen.newTensorIndexWithEllipsis'.
instance {-# OVERLAPS #-} TensorIndex Ellipsis where
  -- unsafePerformIO: the call only allocates a fresh index object
  -- (assumed otherwise side-effect free — confirm against the binding).
  pushIndex vec _ =
    unsafePerformIO $ (: vec) . RawTensorIndex <$> ATen.newTensorIndexWithEllipsis

-- | A 'Bool' becomes an index created by 'ATen.newTensorIndexWithBool',
-- passed as the C flag 1 (True) or 0 (False).
instance {-# OVERLAPS #-} TensorIndex Bool where
  pushIndex vec b =
    unsafePerformIO $ (: vec) . RawTensorIndex <$> ATen.newTensorIndexWithBool flag
    where
      flag :: CBool
      flag = if b then 1 else 0

-- | @Slice (start, end)@: slice from @start@ to @end@ with step 1.
instance {-# OVERLAPS #-} (Integral a) => TensorIndex (Slice (a, a)) where
  -- NOTE(review): fromIntegral wraps values outside the CInt range.
  pushIndex vec (Slice (start, end)) =
    unsafePerformIO $
      (: vec) . RawTensorIndex
        <$> ATen.newTensorIndexWithSlice (fromIntegral start) (fromIntegral end) (1 :: CInt)

-- | @Slice (start, end, step)@: fully specified slice.
instance {-# OVERLAPS #-} (Integral a) => TensorIndex (Slice (a, a, a)) where
  -- NOTE(review): fromIntegral wraps values outside the CInt range.
  pushIndex vec (Slice (start, end, step)) =
    unsafePerformIO $
      (: vec) . RawTensorIndex <$> ATen.newTensorIndexWithSlice lo hi stride
    where
      lo, hi, stride :: CInt
      lo = fromIntegral start
      hi = fromIntegral end
      stride = fromIntegral step

-- | @Slice (None, None, step)@: whole range with the given step; the bounds
-- are 0 and @maxBound :: CInt@.
instance {-# OVERLAPS #-} (Integral a) => TensorIndex (Slice (None, None, a)) where
  -- NOTE(review): fromIntegral wraps values outside the CInt range.
  pushIndex vec (Slice (_, _, step)) =
    unsafePerformIO $
      (: vec) . RawTensorIndex
        <$> ATen.newTensorIndexWithSlice 0 (maxBound :: CInt) (fromIntegral step)

-- | @Slice start@: from @start@ with step 1; @maxBound :: CInt@ is used as
-- the end because none is given.
instance {-# OVERLAPS #-} (Integral a) => TensorIndex (Slice a) where
  -- NOTE(review): fromIntegral wraps values outside the CInt range.
  pushIndex vec (Slice start) =
    unsafePerformIO $
      (: vec) . RawTensorIndex
        <$> ATen.newTensorIndexWithSlice (fromIntegral start) (maxBound :: CInt) 1

-- | @Slice (start, None)@: from @start@ with step 1; @maxBound :: CInt@ is
-- used as the end.
instance {-# OVERLAPS #-} (Integral a) => TensorIndex (Slice (a, None)) where
  -- NOTE(review): fromIntegral wraps values outside the CInt range.
  pushIndex vec (Slice (start, _)) =
    unsafePerformIO $
      (: vec) . RawTensorIndex
        <$> ATen.newTensorIndexWithSlice (fromIntegral start) (maxBound :: CInt) 1

-- | @Slice (start, None, step)@: from @start@ with the given step;
-- @maxBound :: CInt@ is used as the end.
instance {-# OVERLAPS #-} (Integral a) => TensorIndex (Slice (a, None, a)) where
  -- NOTE(review): fromIntegral wraps values outside the CInt range.
  pushIndex vec (Slice (start, _, step)) =
    unsafePerformIO $
      (: vec) . RawTensorIndex <$> ATen.newTensorIndexWithSlice lo maxBound stride
    where
      lo, stride :: CInt
      lo = fromIntegral start
      stride = fromIntegral step

-- | @Slice (None, end, step)@: from 0 to @end@ with the given step.
instance {-# OVERLAPS #-} (Integral a) => TensorIndex (Slice (None, a, a)) where
  -- NOTE(review): fromIntegral wraps values outside the CInt range.
  pushIndex vec (Slice (_, end, step)) =
    unsafePerformIO $
      (: vec) . RawTensorIndex <$> ATen.newTensorIndexWithSlice 0 hi stride
    where
      hi, stride :: CInt
      hi = fromIntegral end
      stride = fromIntegral step

-- | @Slice (None, end)@: from 0 to @end@ with step 1.
instance {-# OVERLAPS #-} (Integral a) => TensorIndex (Slice (None, a)) where
  -- NOTE(review): fromIntegral wraps values outside the CInt range.
  pushIndex vec (Slice (_, end)) =
    unsafePerformIO $
      (: vec) . RawTensorIndex
        <$> ATen.newTensorIndexWithSlice 0 (fromIntegral end :: CInt) 1

-- | @Slice ()@: the whole range (0 to @maxBound :: CInt@, step 1).
instance {-# OVERLAPS #-} TensorIndex (Slice ()) where
  pushIndex vec (Slice ()) =
    unsafePerformIO $
      (: vec) . RawTensorIndex
        <$> ATen.newTensorIndexWithSlice 0 (maxBound :: CInt) 1

-- | An 'Int' selects a single position, via 'ATen.newTensorIndexWithInt'.
--
-- The binding takes a 'CInt'. Indices that do not survive the round trip
-- through 'CInt' are rejected with an error instead of silently wrapping to
-- a different (wrong) position.
instance TensorIndex Int where
  pushIndex vec v =
    unsafePerformIO $ (: vec) . RawTensorIndex <$> ATen.newTensorIndexWithInt cIndex
    where
      cIndex :: CInt
      cIndex
        | fromIntegral narrowed == v = narrowed
        | otherwise =
            error $
              "Torch.Tensor.pushIndex: index " ++ show v ++ " does not fit in a CInt"
        where
          narrowed = fromIntegral v

-- | An 'Integer' selects a single position, via 'ATen.newTensorIndexWithInt'.
--
-- The binding takes a 'CInt'. Indices that do not survive the round trip
-- through 'CInt' are rejected with an error instead of silently wrapping to
-- a different (wrong) position.
instance TensorIndex Integer where
  pushIndex vec v =
    unsafePerformIO $ (: vec) . RawTensorIndex <$> ATen.newTensorIndexWithInt cIndex
    where
      cIndex :: CInt
      cIndex
        | fromIntegral narrowed == v = narrowed
        | otherwise =
            error $
              "Torch.Tensor.pushIndex: index " ++ show v ++ " does not fit in a CInt"
        where
          narrowed = fromIntegral v

-- | A 'Tensor' used as an index, wrapped by 'ATen.newTensorIndexWithTensor'.
instance TensorIndex Tensor where
  pushIndex vec v =
    unsafePerformIO $ (: vec) <$> cast1 ATen.newTensorIndexWithTensor v

-- | @()@ selects the whole range (0 to @maxBound :: CInt@, step 1), the same
-- index as @Slice ()@.
instance TensorIndex () where
  pushIndex vec _ =
    unsafePerformIO $
      (: vec) . RawTensorIndex
        <$> ATen.newTensorIndexWithSlice 0 (maxBound :: CInt) 1

-- | A pair indexes two dimensions.
--
-- The components are pushed right to left. Because every leaf instance
-- prepends, the entries for @a@ end up first, then those for @b@, then the
-- incoming list.
instance (TensorIndex a, TensorIndex b) => TensorIndex (a, b) where
  pushIndex vec (a, b) = pushIndex (pushIndex vec b) a

-- | A triple indexes three dimensions; components are pushed right to left,
-- so their entries appear in source order at the head of the result.
instance (TensorIndex a, TensorIndex b, TensorIndex c) => TensorIndex (a, b, c) where
  pushIndex vec (a, b, c) =
    let withC = pushIndex vec c
        withB = pushIndex withC b
     in pushIndex withB a

-- | A 4-tuple indexes four dimensions; components are pushed right to left,
-- so their entries appear in source order at the head of the result.
instance (TensorIndex a, TensorIndex b, TensorIndex c, TensorIndex d) => TensorIndex (a, b, c, d) where
  pushIndex vec (a, b, c, d) =
    let withD = pushIndex vec d
        withC = pushIndex withD c
        withB = pushIndex withC b
     in pushIndex withB a

-- | A 5-tuple indexes five dimensions; components are pushed right to left,
-- so their entries appear in source order at the head of the result.
instance (TensorIndex a, TensorIndex b, TensorIndex c, TensorIndex d, TensorIndex e) => TensorIndex (a, b, c, d, e) where
  pushIndex vec (a, b, c, d, e) =
    let withE = pushIndex vec e
        withD = pushIndex withE d
        withC = pushIndex withD c
        withB = pushIndex withC b
     in pushIndex withB a

--------------------------------------------------------------------------------
-- Scalar <-> Tensor promotion
--------------------------------------------------------------------------------

-- | Decode a tensor into a Haskell value.
--
-- The tensor is first copied to @Device CPU 0@ (via 'toCPU') if it lives
-- anywhere else, then made contiguous if needed, and finally decoded by
-- '_asValue'.
asValue :: TensorLike a => Tensor -> a
asValue t = _asValue (ensureContiguous (ensureCPU t))
  where
    ensureCPU x
      | device x == Device CPU 0 = x
      | otherwise = toCPU x
    ensureContiguous x
      | isContiguous x = x
      | otherwise = contiguous x

-- | Things that can specify the target options when converting a tensor:
-- a 'TensorOptions' value or another 'Tensor' (see the instances).
class TensorOptionLike a where
  -- | Convert the tensor according to the given options.
  withTensorOptions :: Tensor -> a -> Tensor

-- | Convert via 'ATen.tensor_to_obb' with non_blocking = False and
-- copy = False.
instance TensorOptionLike TensorOptions where
  withTensorOptions t opts =
    let nonBlocking = False
        copy = False
     in unsafePerformIO (cast4 ATen.tensor_to_obb t opts nonBlocking copy)

-- | Convert via 'ATen.tensor_to_tbb', using another tensor as the reference
-- (presumably for its dtype/device — confirm), with non_blocking = False
-- and copy = False.
instance TensorOptionLike Tensor where
  withTensorOptions t reference =
    let nonBlocking = False
        copy = False
     in unsafePerformIO (cast4 ATen.tensor_to_tbb t reference nonBlocking copy)

-- | Haskell values that can be converted to and from tensors.
class TensorLike a where
  -- | Like 'asTensor', then apply 'withTensorOptions' to the result.
  asTensor' :: TensorOptionLike opt => a -> opt -> Tensor
  asTensor' = withTensorOptions . asTensor

  -- | Build a tensor holding this value.
  asTensor :: a -> Tensor

  -- | Decode a tensor without moving it to the CPU or making it contiguous
  -- first; prefer 'asValue'.
  _asValue :: Tensor -> a

  -- Internal functions(like "_xxx") are below. Do not use them directly.

  -- | Element dtype used for this Haskell type.
  _dtype :: DType
  -- | Shape of a value, following only the first element at each level.
  _dims :: a -> [Int]
  -- | Shape of a value, or 'Nothing' if nested lists are ragged.
  _deepDims :: a -> Maybe [Int]
  -- | Read a value of the given shape starting at an element offset.
  _peekElemOff :: Ptr () -> Int -> [Int] -> IO a
  -- | Write a value starting at an element offset.
  _pokeElemOff :: Ptr () -> Int -> a -> IO ()

-- | 'defaultOpts' with the dtype set to 'Bool'.
bool_opts :: TensorOptions
bool_opts = withDType Bool defaultOpts

-- | 'defaultOpts' with the dtype set to 'UInt8'.
uint8_opts :: TensorOptions
uint8_opts = withDType UInt8 defaultOpts

-- | 'defaultOpts' with the dtype set to 'Int64'.
int64_opts :: TensorOptions
int64_opts = withDType Int64 defaultOpts

-- | 'defaultOpts' with the dtype set to 'Float'.
float_opts :: TensorOptions
float_opts = withDType Float defaultOpts

-- | 'defaultOpts' with the dtype set to 'Double'.
double_opts :: TensorOptions
double_opts = withDType Double defaultOpts

-- | Run an action on the raw data pointer of a tensor.
--
-- A non-contiguous tensor is replaced by a contiguous copy first, so the
-- pointer always refers to contiguous storage.
withTensor :: Tensor -> (Ptr () -> IO a) -> IO a
withTensor t fn = _withTensor layoutOk fn
  where
    layoutOk
      | isContiguous t = t
      | otherwise = contiguous t

-- | The internal function of withTensor. It does not check contiguous memory-layout.
--
-- The foreign tensor is kept alive (via 'withForeignPtr') for the duration
-- of the action; the data pointer comes from 'Unmanaged.tensor_data_ptr'.
_withTensor :: Tensor -> (Ptr () -> IO a) -> IO a
_withTensor t fn =
  cast t $ \foreignT ->
    withForeignPtr foreignT $ \rawT -> do
      dataPtr <- Unmanaged.tensor_data_ptr rawT
      fn dataPtr

-- | Scalars: any 'Storable' type with a reflected 'DType' becomes a
-- zero-dimensional tensor.
instance {-# OVERLAPPING #-} (Reifies a DType, Storable a) => TensorLike a where
  asTensor v = unsafePerformIO $ do
    let newEmpty :: [Int] -> TensorOptions -> IO Tensor
        newEmpty = cast2 ATen.new_empty_tensor
    t <- newEmpty [] (withDType (_dtype @a) defaultOpts)
    _withTensor t $ \ptr -> _pokeElemOff ptr 0 v
    pure t

  -- Only the first element is read; the dtype must match exactly.
  _asValue t =
    unsafePerformIO $
      if _dtype @a /= dtype t
        then
          throwIO . userError $
            "The infered DType of asValue is "
              ++ show (_dtype @a)
              ++ ", but the DType of tensor on memory is "
              ++ show (dtype t)
              ++ "."
        else withTensor t $ \ptr -> _peekElemOff ptr 0 []

  _dtype = reflect (Proxy :: Proxy a)

  _dims _ = []

  _deepDims _ = Just []

  _peekElemOff ptr offset _ = peekElemOff (castPtr ptr) offset

  _pokeElemOff ptr offset v = pokeElemOff (castPtr ptr) offset v

-- | 'Bool' scalars, stored in tensor memory as one 'Word8' per element
-- (1 for True, 0 for False; any non-zero byte reads back as True).
instance {-# OVERLAPPING #-} TensorLike Bool where
  asTensor v = unsafePerformIO $ do
    let newEmpty :: [Int] -> TensorOptions -> IO Tensor
        newEmpty = cast2 ATen.new_empty_tensor
    t <- newEmpty [] (withDType (_dtype @Bool) defaultOpts)
    _withTensor t $ \ptr -> _pokeElemOff ptr 0 v
    pure t

  _asValue t =
    unsafePerformIO $
      if _dtype @Bool /= dtype t
        then
          throwIO . userError $
            "The infered DType of asValue is "
              ++ show (_dtype @Bool)
              ++ ", but the DType of tensor on memory is "
              ++ show (dtype t)
              ++ "."
        else withTensor t $ \ptr -> _peekElemOff ptr 0 []

  _dtype = reflect (Proxy :: Proxy Bool)

  _dims _ = []

  _deepDims _ = Just []

  _peekElemOff ptr offset _ = do
    byte <- peekElemOff (castPtr ptr) offset :: IO Word8
    pure (byte /= 0)

  _pokeElemOff ptr offset v = pokeElemOff (castPtr ptr) offset byte
    where
      byte :: Word8
      byte = if v then 1 else 0

-- | A 'Tensor' is trivially tensor-like: conversion is the identity. The
-- internal element-level methods are not supported and raise an error.
instance {-# OVERLAPPING #-} TensorLike Tensor where
  asTensor' = withTensorOptions
  asTensor = id
  _asValue = id
  _dtype = error "Not implemented for Tensor-type"
  _dims _ = error "Not implemented for Tensor-type"
  _deepDims _ = error "Not implemented for Tensor-type"
  _peekElemOff = error "Not implemented for Tensor-type"
  _pokeElemOff = error "Not implemented for Tensor-type"

-- | Homogeneous pairs are converted through a two-element list.
--
-- '_asValue' decodes the tensor as a list and requires exactly two
-- elements. Previously a partial @let [a, b] = ...@ pattern made any other
-- length fail with an opaque irrefutable-pattern error; it now reports the
-- actual element count. The element-level internal methods are not
-- supported for tuples.
instance {-# OVERLAPPING #-} TensorLike a => TensorLike (a, a) where
  asTensor (a, b) = asTensor [a, b]
  _asValue v =
    case _asValue v of
      [a, b] -> (a, b)
      xs ->
        error $
          "Torch.Tensor.asValue: expected exactly 2 elements to build a pair, but got "
            ++ show (length xs)
  _dtype = error "Not implemented for tuple-type"
  _dims _ = error "Not implemented for tuple-type"
  _deepDims _ = error "Not implemented for tuple-type"
  _peekElemOff = error "Not implemented for tuple-type"
  _pokeElemOff = error "Not implemented for tuple-type"

-- | Nested lists become multi-dimensional tensors, written in row-major
-- order. The shape is taken from the first element at each level.
instance {-# OVERLAPPING #-} TensorLike a => TensorLike [a] where
  asTensor v = unsafePerformIO $ do
    let newEmpty :: [Int] -> TensorOptions -> IO Tensor
        newEmpty = cast2 ATen.new_empty_tensor
    t <- newEmpty (_dims v) (withDType (_dtype @a) defaultOpts)
    _withTensor t $ \ptr -> _pokeElemOff ptr 0 v
    return t

  _asValue t =
    unsafePerformIO $
      if _dtype @a == dtype t
        then withTensor t $ \ptr -> _peekElemOff ptr 0 (shape t)
        else
          throwIO . userError $
            "The infered DType of asValue is "
              ++ show (_dtype @a)
              ++ ", but the DType of tensor on memory is "
              ++ show (dtype t)
              ++ "."

  _dtype = _dtype @a

  _dims [] = [0]
  _dims v@(x : _) = length v : _dims x

  -- 'Nothing' when sibling elements have different shapes.
  _deepDims [] = Just [0]
  _deepDims v@(x : xs) = do
    deepDimsX <- _deepDims x
    deepDimsXs <- traverse _deepDims xs
    if all (== deepDimsX) deepDimsXs
      then Just (length v : deepDimsX)
      else Nothing

  -- Element @i@ of the outermost dimension starts @i * product dims@
  -- scalars after @offset@.
  _peekElemOff _ _ [] = return []
  _peekElemOff ptr offset (d : dims) =
    let width = product dims
     in forM [0 .. d - 1] $ \i -> _peekElemOff ptr (offset + i * width) dims

  -- Every element must have the same shape as the first one; comparing
  -- full shapes (not just their element counts) rejects ragged input such
  -- as sub-shapes [1,2] vs [2,1] that would otherwise be laid out under
  -- the wrong shape. This validation may be slow.
  _pokeElemOff _ _ [] = return ()
  _pokeElemOff ptr offset v@(x : _) =
    forM_ (zip [0 ..] v) $ \(i, d) ->
      if _dims d == dimsX
        then _pokeElemOff @a ptr (offset + i * width) d
        else throwIO $ userError "There are lists having different length."
    where
      dimsX = _dims x
      width = product dimsX

-- | Values that can be flattened into a vector of tensors.
--
-- The default implementation walks the value's 'Generic' representation.
class AsTensors as where
  toTensors :: as -> V.Vector Tensor
  default toTensors :: (Generic as, GAsTensors (Rep as)) => as -> V.Vector Tensor
  toTensors = gToTensors . from

-- | A single 'TensorLike' value becomes a one-element vector holding its
-- 'asTensor' conversion.
instance TensorLike a => AsTensors a where
  toTensors x = V.singleton (asTensor x)

-- | Generic worker behind the default 'toTensors'. It collects the tensors
-- contributed by the fields of a 'Generic' representation.
class GAsTensors f where
  gToTensors :: f p -> V.Vector Tensor

-- | Products: the left component's tensors come first, then the right's.
instance (GAsTensors ls, GAsTensors rs) => GAsTensors (ls :*: rs) where
  gToTensors (l :*: r) = gToTensors l <> gToTensors r

-- | Sums: only the alternative actually present contributes tensors.
instance (GAsTensors ls, GAsTensors rs) => GAsTensors (ls :+: rs) where
  gToTensors s = case s of
    L1 l -> gToTensors l
    R1 r -> gToTensors r

-- | Metadata wrappers ('M1') are transparent: recurse into the wrapped value.
instance (GAsTensors ls) => GAsTensors (M1 i c ls) where
  gToTensors = gToTensors . unM1

-- | A single field ('K1') becomes exactly one tensor via 'asTensor'.
instance (TensorLike ls) => GAsTensors (K1 i ls) where
  gToTensors = V.singleton . asTensor . unK1

--------------------------------------------------------------------------------
-- Show
--------------------------------------------------------------------------------

instance Show Tensor where
  -- Rendering layout:
  --   * a header "Tensor <dtype> <shape> " comes first;
  --   * a 0-d tensor follows with a bare scalar;
  --   * a 1-d tensor follows with "[a, b, ...]";
  --   * higher-rank tensors follow with nested brackets, rows separated by
  --     ",\n" and indented to line up under the header.
  -- A tensor whose device is not @Device CPU 0@ is copied with 'toCPU' first.
  show t' = header ++ body
    where
      cpuT :: Tensor
      cpuT
        | device t' == Device CPU 0 = t'
        | otherwise = toCPU t'

      body :: [Char]
      body = case dim cpuT of
        0 -> renderScalar cpuT
        1 -> renderVector cpuT
        rank -> renderNd rank 0 cpuT

      dt :: DType
      dt = dtype cpuT

      header :: [Char]
      header = "Tensor " ++ show dt ++ " " ++ show (shape cpuT) ++ " "

      -- Blank prefix as wide as the header, used to indent continuation rows.
      indent :: [Char]
      indent = replicate (length header) ' '

      -- Render every slice along dimension 0 with @renderElem@ and join the
      -- results with @sep@ inside square brackets.
      -- TODO: this is obviously not the right way to do it,
      -- and will be terribly slow, so please fix it.
      bracketed :: (Tensor -> [Char]) -> [Char] -> Tensor -> [Char]
      bracketed renderElem sep x =
        "[" ++ intercalate sep [renderElem (x ! i) | i <- [0 .. size 0 x - 1]] ++ "]"

      -- Prefix non-negative values with one space (presumably so they align
      -- with negative values' leading '-').
      padPositive :: (Ord a, Num a) => a -> ShowS
      padPositive v s
        | v >= 0 = ' ' : s
        | otherwise = s

      -- Append three spaces when |v| >= 0.1.
      -- TODO: this assumes that scientific notation only uses one-digit exponents, which is not
      --       true in general
      padLarge :: (Ord a, Fractional a) => a -> ShowS
      padLarge v s
        | abs v >= 0.1 = s ++ "   "
        | otherwise = s

      -- A Double with four decimal places ('showGFloat') plus both paddings.
      renderFloat :: Double -> [Char]
      renderFloat v = padLarge v $ padPositive v $ showGFloat (Just 4) v ""

      renderScalar :: Tensor -> [Char]
      renderScalar x
        | isIntegral dt = let n = toInt x in padPositive n (show n)
        | isComplex dt =
          let re :+ im = toComplex x
           in renderFloat re ++ " + i" ++ renderFloat im
        | otherwise = renderFloat (toDouble x)

      renderVector :: Tensor -> [Char]
      renderVector = bracketed renderScalar ", "

      -- Render a rank-@rank@ slice at nesting level @depth@ (rank >= 2).
      renderNd :: Int -> Int -> Tensor -> [Char]
      renderNd rank depth = bracketed inner (",\n " ++ indent ++ replicate depth ' ')
        where
          inner
            | rank == 2 = renderVector
            | otherwise = renderNd (rank - 1) (depth + 1)

--------------------------------------------------------------------------------
-- Castable instances
--------------------------------------------------------------------------------

-- NB: ATen only defines Castable [ForeignPtr ATen.Tensor] (ForeignPtr ATen.TensorList),
-- so the instance below first unwraps each Tensor to its ForeignPtr ATen.Tensor.
instance Castable [Tensor] (ForeignPtr ATen.TensorList) where
  -- Unwrap every Tensor to its raw ForeignPtr ATen.Tensor, then let the
  -- ATen-level list instance build the TensorList for the continuation.
  cast tensors k = do
    raw <- traverse (\tensor -> (cast tensor return :: IO (ForeignPtr ATen.Tensor))) tensors
    cast raw k
  -- Inverse direction: split the TensorList into raw pointers, wrap each one
  -- back into a Tensor, and pass the list to the continuation.
  uncast listPtr k =
    uncast listPtr $ \raw ->
      traverse (\(p :: ForeignPtr ATen.Tensor) -> uncast p return) raw >>= k

instance Castable [Tensor] (ForeignPtr (ATen.C10List ATen.Tensor)) where
  -- Unwrap every Tensor to its raw ForeignPtr ATen.Tensor, then delegate the
  -- construction of the c10::List to the ATen-level instance.
  cast tensors k = do
    raw <- traverse (\tensor -> (cast tensor return :: IO (ForeignPtr ATen.Tensor))) tensors
    cast raw k
  -- Inverse direction: read the raw pointers out of the c10::List and wrap
  -- each one back into a Tensor before calling the continuation.
  uncast listPtr k =
    uncast listPtr $ \raw ->
      traverse (\(p :: ForeignPtr ATen.Tensor) -> uncast p return) raw >>= k

instance Castable [Tensor] (ForeignPtr (ATen.C10List (ATen.C10Optional ATen.Tensor))) where
  -- Unwrap every Tensor to its raw ForeignPtr ATen.Tensor, then delegate the
  -- construction of the list of optional tensors to the ATen-level instance.
  cast tensors k = do
    raw <- traverse (\tensor -> (cast tensor return :: IO (ForeignPtr ATen.Tensor))) tensors
    cast raw k
  -- Inverse direction: read the raw pointers back out and wrap each one into
  -- a Tensor before calling the continuation.
  uncast listPtr k =
    uncast listPtr $ \raw ->
      traverse (\(p :: ForeignPtr ATen.Tensor) -> uncast p return) raw >>= k