{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeSynonymInstances #-}
{-# LANGUAGE UndecidableInstances #-}
module Torch.Tensor where
import Control.Exception.Safe (throwIO)
import Control.Monad (forM, forM_)
import Numeric.Half
import Data.Complex
import Data.Int (Int16, Int64)
import Data.List (intercalate)
import Data.Proxy
import Data.Reflection
import qualified Data.Vector as V
import Data.Word (Word8)
import Foreign.C.Types
import Foreign.ForeignPtr
import Foreign.Ptr
import Foreign.Storable
import GHC.Generics
import Numeric
import System.IO.Unsafe
import Torch.DType
import Torch.Device
import Torch.Internal.Cast
import Torch.Internal.Class (Castable (..), CppTuple2 (..), CppTuple3 (..), CppTuple4 (..))
import qualified Torch.Internal.Const as ATen
import qualified Torch.Internal.Managed.Cast as ATen
import qualified Torch.Internal.Managed.Native as ATen
import qualified Torch.Internal.Managed.TensorFactories as LibTorch
import qualified Torch.Internal.Managed.Type.Context as ATen
import qualified Torch.Internal.Managed.Type.StdArray as ATen
import qualified Torch.Internal.Managed.Type.StdString as ATen
import qualified Torch.Internal.Managed.Type.Tensor as ATen
import qualified Torch.Internal.Managed.Type.TensorIndex as ATen
import qualified Torch.Internal.Managed.Type.TensorOptions as ATen
import qualified Torch.Internal.Managed.Type.Extra as ATen
import qualified Torch.Internal.Type as ATen
import qualified Torch.Internal.Unmanaged.Type.Tensor as Unmanaged (tensor_data_ptr)
import Torch.Lens
import Torch.TensorOptions
-- | Foreign pointer to a libtorch @at::Tensor@ object.
type ATenTensor = ForeignPtr ATen.Tensor

-- | The user-facing tensor type: a newtype over 'ATenTensor'. Matching on the
-- 'Unsafe' constructor exposes the raw foreign pointer.
newtype Tensor = Unsafe ATenTensor
-- | Marshal a 'Tensor' to and from its raw 'ATenTensor' for the FFI helpers
-- ('cast1', 'cast2', ...). Both directions only unwrap or wrap the newtype.
instance Castable Tensor ATenTensor where
  cast (Unsafe ptr) k = k ptr
  uncast ptr k = k (Unsafe ptr)
-- | A 'Tensor' wrapped to mark it as mutable. Build one with
-- 'newMutableTensor'; get a 'Tensor' back with 'toImmutable'.
newtype MutableTensor = MutableTensor Tensor
  deriving (Show)
-- | Wrap the result of libtorch's @detach@ on the given tensor as a
-- 'MutableTensor'.
--
-- NOTE(review): in libtorch, @detach@ returns a tensor that shares storage
-- with its input. In-place writes through the result may therefore also be
-- visible through the original 'Tensor' -- confirm this aliasing is intended.
newMutableTensor :: Tensor -> IO MutableTensor
newMutableTensor = fmap MutableTensor . cast1 ATen.detach_t
-- | Turn a 'MutableTensor' back into a 'Tensor' by applying libtorch's
-- @detach@ to the wrapped tensor.
--
-- NOTE(review): the result presumably shares storage with the mutable tensor,
-- so later in-place writes to it may be observable here -- confirm.
toImmutable :: MutableTensor -> IO Tensor
toImmutable mutable = case mutable of
  MutableTensor inner -> cast1 ATen.detach_t inner
-- | Number of elements in the tensor, as reported by libtorch's @numel@.
numel ::
  -- | input
  Tensor ->
  -- | element count
  Int
numel tensor = unsafePerformIO (cast1 ATen.tensor_numel tensor)
-- | Size of the tensor along one dimension, as reported by libtorch's @size@.
size ::
  -- | dimension index
  Int ->
  -- | input
  Tensor ->
  -- | size of that dimension
  Int
size axis tensor = unsafePerformIO (cast2 ATen.tensor_size_l tensor axis)
-- | The sizes of every dimension of the tensor, as reported by libtorch's
-- @sizes@.
shape ::
  -- | input
  Tensor ->
  -- | one entry per dimension
  [Int]
shape tensor = unsafePerformIO (cast1 ATen.tensor_sizes tensor)
-- | Number of dimensions of the tensor, as reported by libtorch's @dim@.
dim ::
  -- | input
  Tensor ->
  -- | number of dimensions
  Int
dim tensor = unsafePerformIO (cast1 ATen.tensor_dim tensor)
-- | Number of dimensions, queried through the @tensor_dim_unsafe@ binding.
--
-- NOTE(review): how this differs from 'dim' is defined in the C++ binding and
-- is not visible here -- confirm before relying on it.
dimUnsafe ::
  -- | input
  Tensor ->
  -- | number of dimensions
  Int
dimUnsafe tensor = unsafePerformIO (cast1 ATen.tensor_dim_unsafe tensor)
-- | Number of dimensions, queried through the @tensor_dim_c_unsafe@ binding.
--
-- NOTE(review): how this differs from 'dim' and 'dimUnsafe' is defined in the
-- C++ binding and is not visible here -- confirm before relying on it.
dimCUnsafe ::
  -- | input
  Tensor ->
  -- | number of dimensions
  Int
dimCUnsafe tensor = unsafePerformIO (cast1 ATen.tensor_dim_c_unsafe tensor)
-- | The device a tensor lives on.
--
-- Returns @CPU:0@ when CUDA is unavailable or the tensor is not reported as a
-- CUDA tensor. Otherwise returns @CUDA@ with the index from libtorch's
-- @get_device@.
device ::
  -- | input
  Tensor ->
  -- | device of the input
  Device
device t = unsafePerformIO $ do
  cudaAvailable <- cast0 ATen.hasCUDA :: IO Bool
  -- Only ask the tensor about CUDA when the runtime has CUDA at all.
  onCuda <-
    if cudaAvailable
      then cast1 ATen.tensor_is_cuda t :: IO Bool
      else pure False
  if onCuda
    then cudaDevice <$> cast1 ATen.tensor_get_device t
    else pure cpuDevice
  where
    cpuDevice :: Device
    cpuDevice = Device {deviceType = CPU, deviceIndex = 0}
    cudaDevice :: Int -> Device
    cudaDevice idx = Device {deviceType = CUDA, deviceIndex = fromIntegral idx}
-- | Element type of the tensor, from libtorch's @scalar_type@.
dtype ::
  -- | input
  Tensor ->
  -- | element type
  DType
dtype tensor = unsafePerformIO (cast1 ATen.tensor_scalar_type tensor)
-- | Read a tensor's value as a 'Complex' 'Double'.
--
-- For complex dtypes, the first element is read straight from the data
-- pointer (via 'withTensor') and widened to 'Double'. For every other dtype,
-- the value comes from libtorch's @item<double>@ and gets a zero imaginary
-- part.
toComplex :: Tensor -> Complex Double
toComplex t = unsafePerformIO $ case dtype t of
  ComplexHalf -> widen <$> (firstElem :: IO (Complex Half))
  ComplexFloat -> widen <$> (firstElem :: IO (Complex Float))
  ComplexDouble -> firstElem
  _ -> (:+ 0) <$> cast1 ATen.tensor_item_double t
  where
    -- Element 0 of the tensor's data buffer, interpreted as type @e@.
    firstElem :: Storable e => IO e
    firstElem = withTensor t $ \ptr -> peekElemOff (castPtr ptr) 0
    widen :: Real r => Complex r -> Complex Double
    widen (re :+ im) = realToFrac re :+ realToFrac im
-- | Read the tensor's value via libtorch's @item<double>@.
--
-- NOTE(review): presumably the tensor must hold exactly one element -- confirm.
toDouble :: Tensor -> Double
toDouble tensor = unsafePerformIO (cast1 ATen.tensor_item_double tensor)
-- | Read the tensor's value via libtorch's @item<int64_t>@.
--
-- NOTE(review): presumably the tensor must hold exactly one element -- confirm.
toInt :: Tensor -> Int
toInt tensor = unsafePerformIO (cast1 ATen.tensor_item_int64_t tensor)
-- | Convert a single tensor to the given element type (libtorch @toType@).
_toType ::
  -- | target element type
  DType ->
  -- | input
  Tensor ->
  -- | converted tensor
  Tensor
_toType target tensor = unsafePerformIO (cast2 ATen.tensor_toType_s tensor target)
-- | A 'Tensor' contains exactly one 'Tensor': itself.
instance HasTypes Tensor Tensor where
  types_ = id

-- | Endomorphisms and plain scalars contain no tensors, so traversing them
-- leaves the value untouched.
instance HasTypes (a -> a) Tensor where
  types_ = const pure

instance HasTypes Int Tensor where
  types_ = const pure

instance HasTypes Double Tensor where
  types_ = const pure

instance HasTypes Float Tensor where
  types_ = const pure

instance HasTypes Bool Tensor where
  types_ = const pure
-- | Each scalar type contains exactly one value of its own type: itself.
instance HasTypes Int Int where
  types_ = id

instance HasTypes Float Float where
  types_ = id

instance HasTypes Double Double where
  types_ = id

instance HasTypes Bool Bool where
  types_ = id
-- | Convert every 'Tensor' reachable through 'HasTypes' inside a value to the
-- given element type.
toType :: forall a. HasTypes a Tensor => DType -> a -> a
toType target = over (types @Tensor @a) (_toType target)
-- | Move every 'Tensor' reachable through 'HasTypes' inside a value to the
-- given device.
toDevice :: forall a. HasTypes a Tensor => Device -> a -> a
toDevice target = over (types @Tensor @a) (_toDevice target)
-- | Move a single tensor to the requested device.
--
-- Supported moves (all except the first require CUDA to be available):
--
-- * same device type and index: the input is returned unchanged;
-- * @CUDA:i@ to @CUDA:j@ with @i /= j@;
-- * @CPU:0@ to @CUDA:j@ with @j >= 0@;
-- * @CUDA:i@ with @i >= 0@ to @CPU:0@.
--
-- Any other combination calls 'error'. After the move, the resulting tensor's
-- device is compared against the target, and 'error' is called on a mismatch.
_toDevice ::
  -- | target device
  Device ->
  -- | input
  Tensor ->
  -- | tensor on the target device
  Tensor
_toDevice device' t = unsafePerformIO $ do
  cudaAvailable <- cast0 ATen.hasCUDA :: IO Bool
  let source = Torch.Tensor.device t
  moved <-
    move
      (deviceType source)
      (deviceType device')
      (deviceIndex source)
      (deviceIndex device')
      cudaAvailable
  let landed = Torch.Tensor.device moved
  verify
    (deviceType device')
    (deviceType landed)
    (deviceIndex device')
    (deviceIndex landed)
  pure moved
  where
    -- Arguments: source type, target type, source index, target index,
    -- and whether CUDA is available.
    move :: DeviceType -> DeviceType -> Int16 -> Int16 -> Bool -> IO Tensor
    move srcType dstType srcIdx dstIdx _
      | srcType == dstType && srcIdx == dstIdx = pure t
    move CUDA CUDA srcIdx dstIdx True
      | srcIdx /= dstIdx = retarget (withDeviceIndex dstIdx)
    move CPU CUDA 0 dstIdx True
      | dstIdx >= 0 = retarget (withDeviceIndex dstIdx)
    move CUDA CPU srcIdx 0 True
      | srcIdx >= 0 = retarget (withDeviceType CPU)
    move srcType dstType srcIdx dstIdx _ =
      error $
        concat
          [ "cannot move tensor from \"",
            show srcType,
            ":",
            show srcIdx,
            "\" to \"",
            show dstType,
            ":",
            show dstIdx,
            "\""
          ]

    -- Read the input's options, adjust them, and convert the input with them.
    retarget :: (TensorOptions -> IO TensorOptions) -> IO Tensor
    retarget adjust = getOpts t >>= adjust >>= to t

    getOpts :: Tensor -> IO TensorOptions
    getOpts = cast1 ATen.tensor_options

    withDeviceType :: DeviceType -> TensorOptions -> IO TensorOptions
    withDeviceType dt opts = cast2 ATen.tensorOptions_device_D opts dt

    -- NOTE(review): the CPU -> CUDA move only sets the device index here. It
    -- presumably relies on the binding also switching the device type to
    -- CUDA -- confirm against the C++ side.
    withDeviceIndex :: Int16 -> TensorOptions -> IO TensorOptions
    withDeviceIndex di opts = cast2 ATen.tensorOptions_device_index_s opts di

    -- Conversion with non_blocking = False and copy = False.
    to :: Tensor -> TensorOptions -> IO Tensor
    to tensor opts = cast4 ATen.tensor_to_obb tensor opts nonBlocking copy
      where
        nonBlocking = False
        copy = False

    -- Arguments: expected type, actual type, expected index, actual index.
    verify :: DeviceType -> DeviceType -> Int16 -> Int16 -> IO ()
    verify wantType gotType wantIdx gotIdx
      | wantType == gotType && wantIdx == gotIdx = pure ()
      | otherwise =
          error $
            concat
              [ "moving of tensor failed: device should have been \"",
                show wantType,
                ":",
                show wantIdx,
                "\" but is \"",
                show gotType,
                ":",
                show gotIdx,
                "\""
              ]
-- | Apply libtorch's @tensor_to_device@ binding to @input@, with @reference@
-- as the reference tensor.
--
-- NOTE(review): presumably this moves @input@ onto @reference@'s device. The
-- exact semantics live in the C++ binding -- confirm.
toDeviceWithTensor :: Tensor -> Tensor -> Tensor
toDeviceWithTensor reference input =
  unsafePerformIO (cast2 ATen.tensor_to_device reference input)
-- | Pick a single index along a dimension (libtorch @select@).
select ::
  -- | dimension
  Int ->
  -- | index within that dimension
  Int ->
  -- | input
  Tensor ->
  -- | output
  Tensor
select axis position tensor =
  unsafePerformIO (cast3 ATen.tensor_select_ll tensor axis position)
-- | Gather entries along a dimension at the positions given by an index
-- tensor (libtorch @index_select@).
indexSelect ::
  -- | dimension
  Int ->
  -- | index tensor
  Tensor ->
  -- | input
  Tensor ->
  -- | output
  Tensor
indexSelect axis indices tensor =
  unsafePerformIO (cast3 ATen.index_select_tlt tensor axis indices)
-- | Like 'indexSelect', but the indices are given as a Haskell list.
indexSelect' ::
  -- | dimension
  Int ->
  -- | indices
  [Int] ->
  -- | input
  Tensor ->
  -- | output
  Tensor
indexSelect' axis indexList t =
  unsafePerformIO (cast3 ATen.index_select_tlt t axis indices)
  where
    -- @index_select@ requires an integer index tensor. Building it with
    -- @asTensor' indexList t@ would presumably also cast it to the input's
    -- dtype, which breaks for floating-point or Bool inputs. Instead, build it
    -- with 'asTensor' (assumed to give an Int64 tensor for [Int] -- verify)
    -- and only move it onto the input's device.
    indices = toDeviceWithTensor t (asTensor indexList)
-- | Slice the input along one dimension (libtorch @slice@).
sliceDim ::
  -- | dimension
  Int ->
  -- | start
  Int ->
  -- | end
  Int ->
  -- | step
  Int ->
  -- | input
  Tensor ->
  -- | output
  Tensor
sliceDim axis begin stop step self =
  unsafePerformIO (cast5 ATen.slice_tllll self axis begin stop step)
-- | Whether libtorch reports the tensor as contiguous (@is_contiguous@).
isContiguous ::
  -- | input
  Tensor ->
  -- | contiguity flag
  Bool
isContiguous tensor = unsafePerformIO (cast1 ATen.tensor_is_contiguous tensor)
-- | Result of libtorch's @contiguous@ on the tensor.
contiguous ::
  -- | input
  Tensor ->
  -- | output
  Tensor
contiguous tensor = unsafePerformIO (cast1 ATen.tensor_contiguous tensor)
-- | Give the tensor a new shape (libtorch @reshape@).
reshape ::
  -- | new shape
  [Int] ->
  -- | input
  Tensor ->
  -- | output
  Tensor
reshape dims tensor = unsafePerformIO (cast2 ATen.reshape_tl tensor dims)
-- | Result of libtorch's @to_sparse@ on the tensor.
toSparse :: Tensor -> Tensor
toSparse tensor = unsafePerformIO (cast1 ATen.tensor_to_sparse tensor)
-- | Result of libtorch's @to_dense@ on the tensor.
toDense :: Tensor -> Tensor
toDense tensor = unsafePerformIO (cast1 ATen.tensor_to_dense tensor)
-- | Result of libtorch's @to_mkldnn@ on the tensor.
toMKLDNN :: Tensor -> Tensor
toMKLDNN tensor = unsafePerformIO (cast1 ATen.tensor_to_mkldnn tensor)
-- | Result of libtorch's @cpu@ on the tensor.
toCPU :: Tensor -> Tensor
toCPU tensor = unsafePerformIO (cast1 ATen.tensor_cpu tensor)
-- | Result of libtorch's @cuda@ on the tensor.
toCUDA :: Tensor -> Tensor
toCUDA tensor = unsafePerformIO (cast1 ATen.tensor_cuda tensor)
-- | Foreign pointer to a libtorch vector of tensor indices.
newtype RawTensorIndexList
  = RawTensorIndexList (ForeignPtr (ATen.StdVector ATen.TensorIndex))

-- | Foreign pointer to a single libtorch tensor index.
newtype RawTensorIndex = RawTensorIndex (ForeignPtr ATen.TensorIndex)
-- | Index a tensor with any value that has a 'TensorIndex' instance.
--
-- The raw indices from 'pushIndex' are pushed, in list order, onto a fresh
-- libtorch index list, which is then passed to libtorch's @index@.
(!) :: TensorIndex a => Tensor -> a -> Tensor
Unsafe self ! idx = unsafePerformIO $ do
  indexList <- ATen.newTensorIndexList
  mapM_
    (\(RawTensorIndex raw) -> ATen.tensorIndexList_push_back indexList raw)
    (pushIndex [] idx)
  Unsafe <$> ATen.index self indexList
-- | Write a value into the positions selected by an index, returning a new
-- tensor.
--
-- The input is cloned first and the in-place @index_put_@ is applied to the
-- clone, so the original 'Tensor' is not modified. The value is converted
-- with 'asTensor'.
maskedFill :: (TensorIndex a, TensorLike t) => Tensor -> a -> t -> Tensor
maskedFill (Unsafe src) idx value = unsafePerformIO $ do
  dst <- ATen.clone_t src
  indexList <- ATen.newTensorIndexList
  mapM_
    (\(RawTensorIndex raw) -> ATen.tensorIndexList_push_back indexList raw)
    (pushIndex [] idx)
  let Unsafe fill = asTensor value
  -- index_put_ mutates dst in place; its returned handle is not needed.
  _ <- ATen.index_put_ dst indexList fill
  pure (Unsafe dst)
-- | Index marker that becomes libtorch's "none" index (see the 'TensorIndex'
-- instance).
data None = None
  deriving (Show, Eq)
-- | Index marker that becomes libtorch's ellipsis index (see the
-- 'TensorIndex' instance).
data Ellipsis = Ellipsis
  deriving (Show, Eq)
-- | Wrapper marking its payload as a slice specification when used as a
-- 'TensorIndex'.
newtype Slice a = Slice a
  deriving (Show, Eq)
-- | Marshal a 'RawTensorIndex' to and from its foreign pointer; both
-- directions only unwrap or wrap the newtype.
instance Castable RawTensorIndex (ForeignPtr ATen.TensorIndex) where
  cast (RawTensorIndex ptr) k = k ptr
  uncast ptr k = k (RawTensorIndex ptr)
-- | Values that can be used to index a 'Tensor'.
class TensorIndex a where
  -- | Add the raw libtorch index(es) for a value to an accumulator list. The
  -- instances for 'None', 'Ellipsis' and 'Bool' prepend one entry.
  pushIndex :: [RawTensorIndex] -> a -> [RawTensorIndex]

  -- | A lens onto the part of a tensor selected by an index. Viewing reads it
  -- with '(!)' and 'asValue'. Setting converts the new value with 'asTensor'
  -- and writes it back with 'maskedFill', which works on a clone.
  toLens :: TensorLike b => a -> Lens' Tensor b
  default toLens :: TensorLike b => a -> Lens' Tensor b
  toLens idx focus s = maskedFill s idx . asTensor <$> focus (asValue (s ! idx))
-- | Pushes an ATen @None@ index.
--
-- NOTE(review): 'unsafePerformIO' presents index construction as pure; this
-- assumes the FFI constructor only allocates a fresh index object.
instance {-# OVERLAPS #-} TensorIndex None where
  pushIndex vec _ =
    unsafePerformIO $ (: vec) . RawTensorIndex <$> ATen.newTensorIndexWithNone
-- | Pushes an ATen ellipsis index.
--
-- NOTE(review): 'unsafePerformIO' presents index construction as pure; this
-- assumes the FFI constructor only allocates a fresh index object.
instance {-# OVERLAPS #-} TensorIndex Ellipsis where
  pushIndex vec _ =
    unsafePerformIO $ (: vec) . RawTensorIndex <$> ATen.newTensorIndexWithEllipsis
-- | Pushes a boolean index, passed to ATen as a 'CBool' of 1 or 0.
instance {-# OVERLAPS #-} TensorIndex Bool where
  pushIndex vec b =
    let flag = if b then 1 else 0 :: CBool
     in unsafePerformIO $
          (: vec) . RawTensorIndex <$> ATen.newTensorIndexWithBool flag
-- | @Slice (start, end)@: the range from @start@ to @end@ with step 1.
--
-- The ATen binding takes 'CInt' bounds. Bounds are saturated into the 'CInt'
-- range instead of being wrapped by 'fromIntegral', so e.g. an @end@ of
-- @maxBound :: Int@ acts like the open-end sentinel @maxBound :: CInt@
-- rather than wrapping to @-1@.
instance {-# OVERLAPS #-} (Integral a) => TensorIndex (Slice (a, a)) where
  pushIndex vec (Slice (start, end)) =
    unsafePerformIO $ do
      idx <- ATen.newTensorIndexWithSlice (toCInt start) (toCInt end) 1
      return (RawTensorIndex idx : vec)
    where
      -- Clamp into [minBound, maxBound] of CInt; plain fromIntegral would wrap.
      toCInt :: Integral i => i -> CInt
      toCInt = fromInteger . max lo . min hi . toInteger
      lo = toInteger (minBound :: CInt)
      hi = toInteger (maxBound :: CInt)
-- | @Slice (start, end, step)@: a strided range.
--
-- All three components are saturated into the 'CInt' range expected by the
-- ATen binding instead of being wrapped by 'fromIntegral'.
instance {-# OVERLAPS #-} (Integral a) => TensorIndex (Slice (a, a, a)) where
  pushIndex vec (Slice (start, end, step)) =
    unsafePerformIO $ do
      idx <- ATen.newTensorIndexWithSlice (toCInt start) (toCInt end) (toCInt step)
      return (RawTensorIndex idx : vec)
    where
      -- Clamp into [minBound, maxBound] of CInt; plain fromIntegral would wrap.
      toCInt :: Integral i => i -> CInt
      toCInt = fromInteger . max lo . min hi . toInteger
      lo = toInteger (minBound :: CInt)
      hi = toInteger (maxBound :: CInt)
-- | @Slice (None, None, step)@: the whole axis with the given step.
--
-- The range is encoded as start 0 and the open-end sentinel
-- @maxBound :: CInt@. The step is saturated into the 'CInt' range instead of
-- being wrapped by 'fromIntegral'.
instance {-# OVERLAPS #-} (Integral a) => TensorIndex (Slice (None, None, a)) where
  pushIndex vec (Slice (_, _, step)) =
    unsafePerformIO $ do
      idx <- ATen.newTensorIndexWithSlice 0 (maxBound :: CInt) (toCInt step)
      return (RawTensorIndex idx : vec)
    where
      -- Clamp into [minBound, maxBound] of CInt; plain fromIntegral would wrap.
      toCInt :: Integral i => i -> CInt
      toCInt = fromInteger . max lo . min hi . toInteger
      lo = toInteger (minBound :: CInt)
      hi = toInteger (maxBound :: CInt)
-- | @Slice start@: from @start@ to the open-end sentinel @maxBound :: CInt@,
-- with step 1.
--
-- @start@ is saturated into the 'CInt' range instead of being wrapped by
-- 'fromIntegral'.
instance {-# OVERLAPS #-} (Integral a) => TensorIndex (Slice a) where
  pushIndex vec (Slice start) =
    unsafePerformIO $ do
      idx <- ATen.newTensorIndexWithSlice (toCInt start) (maxBound :: CInt) 1
      return (RawTensorIndex idx : vec)
    where
      -- Clamp into [minBound, maxBound] of CInt; plain fromIntegral would wrap.
      toCInt :: Integral i => i -> CInt
      toCInt = fromInteger . max lo . min hi . toInteger
      lo = toInteger (minBound :: CInt)
      hi = toInteger (maxBound :: CInt)
-- | @Slice (start, None)@: from @start@ to the open-end sentinel
-- @maxBound :: CInt@, with step 1.
--
-- @start@ is saturated into the 'CInt' range instead of being wrapped by
-- 'fromIntegral'.
instance {-# OVERLAPS #-} (Integral a) => TensorIndex (Slice (a, None)) where
  pushIndex vec (Slice (start, _)) =
    unsafePerformIO $ do
      idx <- ATen.newTensorIndexWithSlice (toCInt start) (maxBound :: CInt) 1
      return (RawTensorIndex idx : vec)
    where
      -- Clamp into [minBound, maxBound] of CInt; plain fromIntegral would wrap.
      toCInt :: Integral i => i -> CInt
      toCInt = fromInteger . max lo . min hi . toInteger
      lo = toInteger (minBound :: CInt)
      hi = toInteger (maxBound :: CInt)
-- | @Slice (start, None, step)@: from @start@ to the open-end sentinel
-- @maxBound :: CInt@ with the given step.
--
-- @start@ and @step@ are saturated into the 'CInt' range instead of being
-- wrapped by 'fromIntegral'.
instance {-# OVERLAPS #-} (Integral a) => TensorIndex (Slice (a, None, a)) where
  pushIndex vec (Slice (start, _, step)) =
    unsafePerformIO $ do
      idx <- ATen.newTensorIndexWithSlice (toCInt start) (maxBound :: CInt) (toCInt step)
      return (RawTensorIndex idx : vec)
    where
      -- Clamp into [minBound, maxBound] of CInt; plain fromIntegral would wrap.
      toCInt :: Integral i => i -> CInt
      toCInt = fromInteger . max lo . min hi . toInteger
      lo = toInteger (minBound :: CInt)
      hi = toInteger (maxBound :: CInt)
-- | @Slice (None, end, step)@: from 0 to @end@ with the given step.
--
-- @end@ and @step@ are saturated into the 'CInt' range instead of being
-- wrapped by 'fromIntegral'.
instance {-# OVERLAPS #-} (Integral a) => TensorIndex (Slice (None, a, a)) where
  pushIndex vec (Slice (_, end, step)) =
    unsafePerformIO $ do
      idx <- ATen.newTensorIndexWithSlice 0 (toCInt end) (toCInt step)
      return (RawTensorIndex idx : vec)
    where
      -- Clamp into [minBound, maxBound] of CInt; plain fromIntegral would wrap.
      toCInt :: Integral i => i -> CInt
      toCInt = fromInteger . max lo . min hi . toInteger
      lo = toInteger (minBound :: CInt)
      hi = toInteger (maxBound :: CInt)
-- | @Slice (None, end)@: from 0 to @end@ with step 1.
--
-- @end@ is saturated into the 'CInt' range instead of being wrapped by
-- 'fromIntegral', so @maxBound :: Int@ acts like the open-end sentinel.
instance {-# OVERLAPS #-} (Integral a) => TensorIndex (Slice (None, a)) where
  pushIndex vec (Slice (_, end)) =
    unsafePerformIO $ do
      idx <- ATen.newTensorIndexWithSlice 0 (toCInt end) 1
      return (RawTensorIndex idx : vec)
    where
      -- Clamp into [minBound, maxBound] of CInt; plain fromIntegral would wrap.
      toCInt :: Integral i => i -> CInt
      toCInt = fromInteger . max lo . min hi . toInteger
      lo = toInteger (minBound :: CInt)
      hi = toInteger (maxBound :: CInt)
-- | @Slice ()@: the whole axis, encoded as start 0, the open-end sentinel
-- @maxBound :: CInt@, and step 1.
instance {-# OVERLAPS #-} TensorIndex (Slice ()) where
  pushIndex vec (Slice ()) =
    let mkSlice = ATen.newTensorIndexWithSlice 0 (maxBound :: CInt) 1
     in unsafePerformIO $ (: vec) . RawTensorIndex <$> mkSlice
-- | Index a single position along an axis with an 'Int'.
--
-- The ATen binding takes a 'CInt'. An index outside that range is rejected
-- with an 'IOError' instead of being silently wrapped by 'fromIntegral' onto
-- an unrelated, possibly valid, position.
instance TensorIndex Int where
  pushIndex vec v =
    unsafePerformIO $ do
      i <- checkedCInt v
      idx <- ATen.newTensorIndexWithInt i
      return (RawTensorIndex idx : vec)
    where
      checkedCInt :: Int -> IO CInt
      checkedCInt n
        | toInteger n < toInteger (minBound :: CInt)
            || toInteger n > toInteger (maxBound :: CInt) =
            throwIO $ userError $ "TensorIndex Int: index " ++ show n ++ " does not fit in a CInt."
        | otherwise = return (fromIntegral n)
-- | Index a single position along an axis with an 'Integer'.
--
-- The ATen binding takes a 'CInt'. An index outside that range is rejected
-- with an 'IOError' instead of being silently wrapped by 'fromIntegral' onto
-- an unrelated, possibly valid, position.
instance TensorIndex Integer where
  pushIndex vec v =
    unsafePerformIO $ do
      i <- checkedCInt v
      idx <- ATen.newTensorIndexWithInt i
      return (RawTensorIndex idx : vec)
    where
      checkedCInt :: Integer -> IO CInt
      checkedCInt n
        | n < toInteger (minBound :: CInt) || n > toInteger (maxBound :: CInt) =
            throwIO $ userError $ "TensorIndex Integer: index " ++ show n ++ " does not fit in a CInt."
        | otherwise = return (fromInteger n)
-- | Index with another tensor, via ATen's tensor-valued @TensorIndex@.
instance TensorIndex Tensor where
  pushIndex vec v =
    unsafePerformIO $ (: vec) <$> cast1 ATen.newTensorIndexWithTensor v
-- | @()@ selects the whole axis. Like @Slice ()@, it is encoded as start 0, the
-- open-end sentinel @maxBound :: CInt@, and step 1.
instance TensorIndex () where
  pushIndex vec () =
    let mkSlice = ATen.newTensorIndexWithSlice 0 (maxBound :: CInt) 1
     in unsafePerformIO $ (: vec) . RawTensorIndex <$> mkSlice
-- | A pair indexes two consecutive axes. Components are pushed right to left,
-- so the first component ends up at the front of the accumulator.
instance (TensorIndex a, TensorIndex b) => TensorIndex (a, b) where
  pushIndex vec (a, b) =
    let afterB = pushIndex vec b
     in pushIndex afterB a
-- | A triple indexes three consecutive axes. Components are pushed right to
-- left, so the first component ends up at the front of the accumulator.
instance (TensorIndex a, TensorIndex b, TensorIndex c) => TensorIndex (a, b, c) where
  pushIndex vec (a, b, c) =
    let afterC = pushIndex vec c
        afterB = pushIndex afterC b
     in pushIndex afterB a
-- | A 4-tuple indexes four consecutive axes. Components are pushed right to
-- left, so the first component ends up at the front of the accumulator.
instance (TensorIndex a, TensorIndex b, TensorIndex c, TensorIndex d) => TensorIndex (a, b, c, d) where
  pushIndex vec (a, b, c, d) =
    let afterD = pushIndex vec d
        afterC = pushIndex afterD c
        afterB = pushIndex afterC b
     in pushIndex afterB a
-- | A 5-tuple indexes five consecutive axes. Components are pushed right to
-- left, so the first component ends up at the front of the accumulator.
instance (TensorIndex a, TensorIndex b, TensorIndex c, TensorIndex d, TensorIndex e) => TensorIndex (a, b, c, d, e) where
  pushIndex vec (a, b, c, d, e) =
    let afterE = pushIndex vec e
        afterD = pushIndex afterE d
        afterC = pushIndex afterD c
        afterB = pushIndex afterC b
     in pushIndex afterB a
-- | Convert a tensor into a Haskell value.
--
-- A tensor not on @Device CPU 0@ is first copied there with 'toCPU'. A
-- non-contiguous tensor is made contiguous before its memory is read by
-- '_asValue'.
asValue :: TensorLike a => Tensor -> a
asValue t = _asValue contTensor
  where
    cpuTensor
      | device t == Device CPU 0 = t
      | otherwise = toCPU t
    contTensor
      | isContiguous cpuTensor = cpuTensor
      | otherwise = contiguous cpuTensor
-- | Things that can describe the target options (dtype, device, ...) of a
-- tensor conversion.
class TensorOptionLike a where
  -- | Convert the given tensor according to the supplied options.
  withTensorOptions :: Tensor -> a -> Tensor
-- | Convert a tensor using explicit 'TensorOptions' via @ATen.tensor_to_obb@.
-- Both trailing flags are passed as 'False'. NOTE(review): presumably these are
-- @non_blocking@ and @copy@, mirroring @Tensor::to@.
instance TensorOptionLike TensorOptions where
  withTensorOptions t opts =
    let nonBlocking = False
        copy = False
     in unsafePerformIO (cast4 ATen.tensor_to_obb t opts nonBlocking copy)
-- | Use another tensor as the conversion target via @ATen.tensor_to_tbb@.
-- NOTE(review): presumably this adopts the other tensor's dtype and device,
-- mirroring @Tensor::to(other, non_blocking, copy)@; both flags are 'False'.
instance TensorOptionLike Tensor where
  withTensorOptions t opts =
    let nonBlocking = False
        copy = False
     in unsafePerformIO (cast4 ATen.tensor_to_tbb t opts nonBlocking copy)
-- | Haskell values that can be converted to and from tensors.
class TensorLike a where
  -- | Build a tensor from a value, then convert it with the given options.
  asTensor' :: TensorOptionLike opt => a -> opt -> Tensor
  asTensor' v = withTensorOptions (asTensor v)

  -- | Build a tensor from a value.
  asTensor :: a -> Tensor

  -- | Read a value back out of a tensor (see 'asValue' for the safe wrapper).
  _asValue :: Tensor -> a

  -- | The element dtype used when building a tensor from this type.
  _dtype :: DType

  -- | Dimensions of a value, following the first element at each nesting level.
  _dims :: a -> [Int]

  -- | Dimensions of a value, or 'Nothing' if nested sizes are inconsistent.
  _deepDims :: a -> Maybe [Int]

  -- | Read a value at an element offset, given the remaining dimensions.
  _peekElemOff :: Ptr () -> Int -> [Int] -> IO a

  -- | Write a value at an element offset.
  _pokeElemOff :: Ptr () -> Int -> a -> IO ()
-- | 'defaultOpts' with dtype 'Bool'.
bool_opts :: TensorOptions
bool_opts = withDType Bool $ defaultOpts

-- | 'defaultOpts' with dtype 'UInt8'.
uint8_opts :: TensorOptions
uint8_opts = withDType UInt8 $ defaultOpts

-- | 'defaultOpts' with dtype 'Int64'.
int64_opts :: TensorOptions
int64_opts = withDType Int64 $ defaultOpts

-- | 'defaultOpts' with dtype 'Float'.
float_opts :: TensorOptions
float_opts = withDType Float $ defaultOpts

-- | 'defaultOpts' with dtype 'Double'.
double_opts :: TensorOptions
double_opts = withDType Double $ defaultOpts
-- | Run an action on the raw data pointer of a tensor.
--
-- A non-contiguous tensor is first made contiguous, so the pointer always
-- refers to contiguous storage. The action then runs through '_withTensor'.
withTensor :: Tensor -> (Ptr () -> IO a) -> IO a
withTensor t
  | isContiguous t = _withTensor t
  | otherwise = _withTensor (contiguous t)
-- | Run an action on the raw data pointer of a tensor as-is (no contiguity
-- check). The underlying 'ForeignPtr' is held via 'withForeignPtr' for the
-- duration of the action.
_withTensor :: Tensor -> (Ptr () -> IO a) -> IO a
_withTensor t fn =
  cast t $ \fptr ->
    withForeignPtr fptr $ \rawPtr ->
      fn =<< Unmanaged.tensor_data_ptr rawPtr
-- | Scalar instance for any 'Storable' type whose dtype is reified via
-- 'Reifies'. Scalars map to 0-dimensional tensors.
instance {-# OVERLAPPING #-} (Reifies a DType, Storable a) => TensorLike a where
  asTensor v = unsafePerformIO $ do
    let newEmpty = cast2 ATen.new_empty_tensor :: [Int] -> TensorOptions -> IO Tensor
    t <- newEmpty [] (withDType (_dtype @a) defaultOpts)
    _withTensor t $ \ptr -> _pokeElemOff ptr 0 v
    return t

  -- Reads element 0 after checking the tensor's dtype matches this type's.
  _asValue t
    | expected == actual = unsafePerformIO $ withTensor t $ \ptr -> _peekElemOff ptr 0 []
    | otherwise = unsafePerformIO $ throwIO $ userError $ concat
        [ "The infered DType of asValue is "
        , show expected
        , ", but the DType of tensor on memory is "
        , show actual
        , "."
        ]
    where
      expected = _dtype @a
      actual = dtype t

  _dtype = reflect (Proxy :: Proxy a)

  _dims _ = []

  _deepDims _ = Just []

  _peekElemOff ptr offset _ = peekElemOff (castPtr ptr) offset

  _pokeElemOff ptr = pokeElemOff (castPtr ptr)
-- | 'Bool' scalars. Stored as one byte per element: 1 for 'True', 0 for
-- 'False'; any non-zero byte reads back as 'True'.
instance {-# OVERLAPPING #-} TensorLike Bool where
  asTensor v = unsafePerformIO $ do
    let newEmpty = cast2 ATen.new_empty_tensor :: [Int] -> TensorOptions -> IO Tensor
    t <- newEmpty [] (withDType (_dtype @Bool) defaultOpts)
    _withTensor t $ \ptr -> _pokeElemOff ptr 0 v
    return t

  -- Reads element 0 after checking the tensor's dtype is the Bool dtype.
  _asValue t
    | expected == actual = unsafePerformIO $ withTensor t $ \ptr -> _peekElemOff ptr 0 []
    | otherwise = unsafePerformIO $ throwIO $ userError $ concat
        [ "The infered DType of asValue is "
        , show expected
        , ", but the DType of tensor on memory is "
        , show actual
        , "."
        ]
    where
      expected = _dtype @Bool
      actual = dtype t

  _dtype = reflect (Proxy :: Proxy Bool)

  _dims _ = []

  _deepDims _ = Just []

  _peekElemOff ptr offset _ =
    fmap (/= (0 :: Word8)) (peekElemOff (castPtr ptr) offset)

  _pokeElemOff ptr offset v =
    let byte = if v then 1 else 0 :: Word8
     in pokeElemOff (castPtr ptr) offset byte
-- | A 'Tensor' is trivially tensor-like: conversion is the identity, and
-- 'asTensor'' just applies the options. The element-level methods are not
-- supported and raise an error.
instance {-# OVERLAPPING #-} TensorLike Tensor where
  asTensor' = withTensorOptions

  asTensor = id

  _asValue = id

  _dtype = error "Not implemented for Tensor-type"

  _dims _ = error "Not implemented for Tensor-type"

  _deepDims _ = error "Not implemented for Tensor-type"

  _peekElemOff = error "Not implemented for Tensor-type"

  _pokeElemOff = error "Not implemented for Tensor-type"
-- | Homogeneous pairs, converted via the two-element list @[a, b]@.
--
-- Only 'asTensor' and '_asValue' are supported; the element-level methods
-- raise an error.
instance {-# OVERLAPPING #-} TensorLike a => TensorLike (a, a) where
  asTensor (a, b) = asTensor [a, b]

  -- Reads the tensor as a list and requires exactly two elements. An explicit
  -- 'case' gives a descriptive error instead of an opaque irrefutable-pattern
  -- failure when the size is wrong.
  _asValue v =
    case _asValue v of
      [a, b] -> (a, b)
      xs ->
        error $
          "TensorLike (a, a): expected a tensor with exactly 2 elements along its outermost dimension, but got "
            ++ show (length xs)
            ++ "."

  _dtype = error "Not implemented for tuple-type"

  _dims _ = error "Not implemented for tuple-type"

  _deepDims _ = error "Not implemented for tuple-type"

  _peekElemOff = error "Not implemented for tuple-type"

  _pokeElemOff = error "Not implemented for tuple-type"
-- | Lists become tensors whose outermost dimension is the list length. The
-- element dtype comes from the element type.
instance {-# OVERLAPPING #-} TensorLike a => TensorLike [a] where
  asTensor v = unsafePerformIO $ do
    t <- ((cast2 ATen.new_empty_tensor) :: [Int] -> TensorOptions -> IO Tensor) (_dims v) $ withDType (_dtype @a) defaultOpts
    _withTensor t $ \ptr -> _pokeElemOff ptr 0 v
    return t

  -- Reads the whole tensor using its shape, after checking the dtype matches.
  _asValue t = unsafePerformIO $ do
    if _dtype @a == dtype t
      then withTensor t $ \ptr -> _peekElemOff ptr 0 (shape t)
      else
        throwIO $
          userError $
            "The infered DType of asValue is "
              ++ show (_dtype @a)
              ++ ", but the DType of tensor on memory is "
              ++ show (dtype t)
              ++ "."

  _dtype = _dtype @a

  -- Follows the first element only; an empty list has a single dimension of 0.
  _dims [] = [0]
  _dims v@(x : _) = length v : _dims x

  -- Like '_dims', but checks that every element has the same nested dims.
  _deepDims [] = Just [0]
  _deepDims v@(x : xs) = do
    deepDimsX <- _deepDims x
    deepDimsXs <- traverse _deepDims xs
    if all (== deepDimsX) deepDimsXs
      then return $ length v : deepDimsX
      else Nothing

  -- Row-major read: element i of the outer dimension starts at
  -- @offset + i * product dims@.
  _peekElemOff _ _ [] = return []
  _peekElemOff ptr offset (d : dims) =
    let width = product dims
     in forM [0 .. (d - 1)] $ \i ->
          _peekElemOff ptr (offset + i * width) dims

  -- Row-major write. Every element must have exactly the same nested
  -- dimensions as the first one. The tensor shape comes from the first
  -- element only, so comparing just the product of dims would accept ragged
  -- input such as [[[1,2],[3,4]],[[1,2,3,4]]] and write it with the wrong
  -- layout.
  _pokeElemOff _ _ [] = return ()
  _pokeElemOff ptr offset v@(x : _) =
    let expected = _dims x
        width = product expected
     in forM_ (zip [0 ..] v) $ \(i, d) ->
          if _dims d == expected
            then (_pokeElemOff @a) ptr (offset + i * width) d
            else throwIO $ userError "There are lists having different length."
-- | Values that can be flattened into a vector of tensors. The default
-- implementation walks the value's 'Generic' representation.
class AsTensors as where
  toTensors :: as -> V.Vector Tensor
  default toTensors :: (Generic as, GAsTensors (Rep as)) => as -> V.Vector Tensor
  toTensors = gToTensors . from
-- | Any single 'TensorLike' value yields a one-element vector.
instance TensorLike a => AsTensors a where
  toTensors v = V.singleton (asTensor v)
-- | Generic helper behind the default 'toTensors': instances walk a
-- 'GHC.Generics' representation and collect the tensors of its fields.
class GAsTensors record where
  gToTensors :: record as -> V.Vector Tensor

-- | A constructor with no fields contributes no tensors. Without this
-- instance, a type with a nullary constructor (e.g. @data M = A Tensor | B@)
-- could not use the 'Generic' default of 'toTensors'.
instance GAsTensors U1 where
  gToTensors _ = V.empty
-- | A product of fields contributes the left operand's tensors followed by
-- the right operand's.
instance (GAsTensors ls, GAsTensors rs) => GAsTensors (ls :*: rs) where
  gToTensors (left :*: right) = V.concat [gToTensors left, gToTensors right]
-- | A sum type contributes the tensors of whichever alternative is present.
instance (GAsTensors ls, GAsTensors rs) => GAsTensors (ls :+: rs) where
  gToTensors sumRep = case sumRep of
    L1 left -> gToTensors left
    R1 right -> gToTensors right
-- | Metadata wrappers (datatype, constructor, selector) are transparent:
-- unwrap and recurse.
instance (GAsTensors ls) => GAsTensors (M1 i c ls) where
  gToTensors = gToTensors . unM1
-- | A single field is converted with 'asTensor' into a one-element vector.
instance (TensorLike ls) => GAsTensors (K1 i ls) where
  gToTensors = V.singleton . asTensor . unK1
-- | Renders a header (@"Tensor "@, the dtype, the shape) followed by the
-- elements as nested bracketed lists.
--
-- A tensor not on @Device CPU 0@ is first copied with 'toCPU'. Elements are
-- read one slice at a time with '!', so rendering large tensors is slow.
instance Show Tensor where
  show original = header ++ body
    where
      cpuTensor
        | device original == Device CPU 0 = original
        | otherwise = toCPU original

      elemType = dtype cpuTensor

      body = case dim cpuTensor of
        0 -> renderScalar cpuTensor
        1 -> renderVector cpuTensor
        rank -> renderNd rank 0 cpuTensor

      header = "Tensor " ++ show elemType ++ " " ++ show (shape cpuTensor) ++ " "

      -- Blank run as wide as the header, so continuation rows of an N-d
      -- tensor are indented past it.
      padding = replicate (length header) ' '

      -- Render every slice along dimension 0 with 'renderOne' and join them
      -- with 'separator' inside square brackets.
      renderElems renderOne separator tensor =
        "[" ++ intercalate separator (map (renderOne . (tensor !)) indices) ++ "]"
        where
          indices = [0 .. size 0 tensor - 1]

      -- Prefix non-negative numbers with a space (room for a minus sign).
      padPositive v str
        | v >= 0 = " " ++ str
        | otherwise = str

      -- Suffix values with magnitude >= 0.1 with trailing space.
      padLarge v str
        | abs v >= 0.1 = str ++ " "
        | otherwise = str

      formatFloat v = padLarge v (padPositive v (showGFloat (Just 4) v ""))

      -- Render a 0-d tensor according to the (shared) element dtype.
      renderScalar s
        | isIntegral elemType = let n = toInt s in padPositive n (show n)
        | isComplex elemType =
            let re :+ im = toComplex s
             in formatFloat re ++ " + i" ++ formatFloat im
        | otherwise = formatFloat (toDouble s)

      renderVector = renderElems renderScalar ", "

      -- Rank >= 2: recurse until rank 2, then print rows as vectors; each
      -- nesting level indents continuation rows by one more space.
      renderNd rank offset = renderElems inner separator
        where
          inner
            | rank == 2 = renderVector
            | otherwise = renderNd (rank - 1) (offset + 1)
          separator = ",\n " ++ padding ++ replicate offset ' '
-- | Converts between a Haskell list of tensors and a native @TensorList@.
--
-- Each 'Tensor' is unwrapped to its @ForeignPtr ATen.Tensor@, and the list
-- conversion itself is delegated to the 'Castable' instance for lists of
-- those foreign pointers; 'uncast' performs the same steps in reverse.
instance Castable [Tensor] (ForeignPtr ATen.TensorList) where
  cast tensors k = do
    rawPtrs <- forM tensors $ \tensor ->
      (cast tensor return :: IO (ForeignPtr ATen.Tensor))
    cast rawPtrs k
  uncast listPtr k =
    uncast listPtr $ \rawPtrs -> do
      tensors <-
        forM rawPtrs $ \(rawPtr :: ForeignPtr ATen.Tensor) ->
          uncast rawPtr return
      k tensors
-- | Converts between a Haskell list of tensors and a native
-- @C10List Tensor@.
--
-- Each 'Tensor' is unwrapped to its @ForeignPtr ATen.Tensor@, and the list
-- conversion itself is delegated to the 'Castable' instance for lists of
-- those foreign pointers; 'uncast' performs the same steps in reverse.
instance Castable [Tensor] (ForeignPtr (ATen.C10List ATen.Tensor)) where
  cast inputs continue = do
    handles <- forM inputs $ \input ->
      (cast input return :: IO (ForeignPtr ATen.Tensor))
    cast handles continue
  uncast nativeList continue =
    uncast nativeList $ \handles -> do
      outputs <-
        forM handles $ \(handle :: ForeignPtr ATen.Tensor) ->
          uncast handle return
      continue outputs
-- | Converts between a Haskell list of tensors and a native
-- @C10List (C10Optional Tensor)@.
--
-- Each 'Tensor' is unwrapped to its @ForeignPtr ATen.Tensor@, and the list
-- conversion itself is delegated to the 'Castable' instance for lists of
-- those foreign pointers; 'uncast' performs the same steps in reverse.
-- NOTE(review): how optional (absent) entries are handled is decided by that
-- delegated instance, not here — confirm there.
instance Castable [Tensor] (ForeignPtr (ATen.C10List (ATen.C10Optional ATen.Tensor))) where
  cast items next = do
    itemPtrs <- forM items $ \item ->
      (cast item return :: IO (ForeignPtr ATen.Tensor))
    cast itemPtrs next
  uncast optList next =
    uncast optList $ \itemPtrs -> do
      items <-
        forM itemPtrs $ \(itemPtr :: ForeignPtr ATen.Tensor) ->
          uncast itemPtr return
      next items