{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}

module Torch.Internal.Managed.Cast where

import Control.Exception.Safe (throwIO)
import Foreign.ForeignPtr
import Foreign.C.Types
import Data.Int
import Control.Monad

import Torch.Internal.Class
import Torch.Internal.Cast
import Torch.Internal.Type
import Torch.Internal.Managed.Type.IntArray
import Torch.Internal.Managed.Type.TensorList
import Torch.Internal.Managed.Type.C10List
import Torch.Internal.Managed.Type.IValueList
import Torch.Internal.Managed.Type.C10Tuple
import Torch.Internal.Managed.Type.C10Dict
import Torch.Internal.Managed.Type.StdVector

instance Castable Int (ForeignPtr IntArray) where
  -- Wrap a single Int as a one-element IntArray; the element is pushed
  -- via intArray_push_back_l, which takes an Int64.
  cast x k = do
    arr <- newIntArray
    intArray_push_back_l arr (fromIntegral x)
    k arr
  -- Read only element 0 of the IntArray and hand it back as an Int.
  uncast arr k = intArray_at_s arr 0 >>= k . fromIntegral

instance Castable [Int] (ForeignPtr IntArray) where
  -- Convert each Int to Int64 and load the whole list into a new IntArray.
  cast ints k = do
    arr <- newIntArray
    let wide = fromIntegral <$> ints
    intArray_fromList arr wide
    k arr
  -- Copy the IntArray out as Int64s and narrow each element back to Int.
  uncast arr k = do
    wide <- intArray_toList arr
    k (fromIntegral <$> wide)

instance Castable [Double] (ForeignPtr (StdVector CDouble)) where
  -- Push every Double (converted to CDouble) onto a fresh std::vector wrapper.
  cast ds k = do
    vec <- newStdVectorDouble
    mapM_ (stdVectorDouble_push_back vec . realToFrac) ds
    k vec
  -- Read every element back, converting CDouble to Double.
  uncast vec k = do
    n <- stdVectorDouble_size vec
    -- n is an unsigned CSize: n - 1 would wrap around when n == 0,
    -- so the empty vector is handled before any range is built.
    doubles <-
      if n == 0
        then pure []
        else mapM (fmap realToFrac . stdVectorDouble_at vec) [0 .. n - 1]
    k doubles

instance Castable [ForeignPtr Tensor] (ForeignPtr TensorList) where
  -- Build a fresh TensorList and push every tensor onto it, in order.
  cast xs f = do
    l <- newTensorList
    forM_ xs $ tensorList_push_back_t l
    f l
  -- Read every element of the TensorList back into a Haskell list.
  uncast xs f = do
    len <- tensorList_size xs
    -- len is an unsigned CSize, so [0 .. len - 1] would wrap to
    -- [0 .. maxBound] for an empty list and index out of bounds.
    -- takeWhile (< len) yields no indices when len == 0.
    f =<< mapM (tensorList_at_s xs) (takeWhile (< len) [0 ..])

instance Castable [ForeignPtr Tensor] (ForeignPtr (C10List Tensor)) where
  -- Build a fresh C10List of tensors and push every element, in order.
  cast xs f = do
    l <- newC10ListTensor
    forM_ xs $ c10ListTensor_push_back l
    f l
  -- Read every element of the C10List back into a Haskell list.
  uncast xs f = do
    len <- c10ListTensor_size xs
    -- len is an unsigned CSize, so [0 .. len - 1] would wrap to
    -- [0 .. maxBound] for an empty list and index out of bounds.
    -- takeWhile (< len) yields no indices when len == 0.
    f =<< mapM (c10ListTensor_at xs) (takeWhile (< len) [0 ..])

instance Castable [ForeignPtr Tensor] (ForeignPtr (C10List (C10Optional Tensor))) where
  -- Build a fresh C10List of optional tensors and push every element, in order.
  cast xs f = do
    l <- newC10ListOptionalTensor
    forM_ xs $ c10ListOptionalTensor_push_back l
    f l
  -- Read every element of the C10List back into a Haskell list.
  uncast xs f = do
    len <- c10ListOptionalTensor_size xs
    -- len is an unsigned CSize, so [0 .. len - 1] would wrap to
    -- [0 .. maxBound] for an empty list and index out of bounds.
    -- takeWhile (< len) yields no indices when len == 0.
    f =<< mapM (c10ListOptionalTensor_at xs) (takeWhile (< len) [0 ..])

instance Castable [CDouble] (ForeignPtr (C10List CDouble)) where
  -- Build a fresh C10List of doubles and push every element, in order.
  cast xs f = do
    l <- newC10ListDouble
    forM_ xs $ c10ListDouble_push_back l
    f l
  -- Read every element of the C10List back into a Haskell list.
  uncast xs f = do
    len <- c10ListDouble_size xs
    -- len is an unsigned CSize, so [0 .. len - 1] would wrap to
    -- [0 .. maxBound] for an empty list and index out of bounds.
    -- takeWhile (< len) yields no indices when len == 0.
    f =<< mapM (c10ListDouble_at xs) (takeWhile (< len) [0 ..])

instance Castable [Int64] (ForeignPtr (C10List Int64)) where
  -- Build a fresh C10List of Int64 and push every element, in order.
  cast xs f = do
    l <- newC10ListInt
    forM_ xs $ c10ListInt_push_back l
    f l
  -- Read every element of the C10List back into a Haskell list.
  uncast xs f = do
    len <- c10ListInt_size xs
    -- len is an unsigned CSize, so [0 .. len - 1] would wrap to
    -- [0 .. maxBound] for an empty list and index out of bounds.
    -- takeWhile (< len) yields no indices when len == 0.
    f =<< mapM (c10ListInt_at xs) (takeWhile (< len) [0 ..])

instance Castable [CBool] (ForeignPtr (C10List CBool)) where
  -- Build a fresh C10List of booleans and push every element, in order.
  cast xs f = do
    l <- newC10ListBool
    forM_ xs $ c10ListBool_push_back l
    f l
  -- Read every element of the C10List back into a Haskell list.
  uncast xs f = do
    len <- c10ListBool_size xs
    -- len is an unsigned CSize, so [0 .. len - 1] would wrap to
    -- [0 .. maxBound] for an empty list and index out of bounds.
    -- takeWhile (< len) yields no indices when len == 0.
    f =<< mapM (c10ListBool_at xs) (takeWhile (< len) [0 ..])

instance Castable [ForeignPtr IValue] (ForeignPtr IValueList) where
  -- Build a fresh IValueList and push every IValue onto it, in order.
  cast xs f = do
    l <- newIValueList
    forM_ xs $ ivalueList_push_back l
    f l
  -- Read every element of the IValueList back into a Haskell list.
  uncast xs f = do
    len <- ivalueList_size xs
    -- len is an unsigned CSize, so [0 .. len - 1] would wrap to
    -- [0 .. maxBound] for an empty list and index out of bounds.
    -- takeWhile (< len) yields no indices when len == 0.
    f =<< mapM (ivalueList_at xs) (takeWhile (< len) [0 ..])

instance Castable [ForeignPtr IValue] (ForeignPtr (C10Ptr IVTuple)) where
  -- Go through the IValueList instance first, then build the tuple from that
  -- list with newC10Tuple_tuple.
  cast xs f =
    cast xs $ \ivalueList ->
      f =<< newC10Tuple_tuple ivalueList
  -- Read every tuple element back into a Haskell list.
  uncast xs f = do
    len <- c10Tuple_size xs
    -- len is an unsigned CSize, so [0 .. len - 1] would wrap to
    -- [0 .. maxBound] for an empty tuple and index out of bounds.
    -- takeWhile (< len) yields no indices when len == 0.
    f =<< mapM (c10Tuple_at xs) (takeWhile (< len) [0 ..])

instance Castable [ForeignPtr IValue] (ForeignPtr (C10List IValue)) where
  -- newC10ListIValue takes an IValue argument, so an empty input cannot be
  -- converted and is rejected with an IOError. (Presumably the first element
  -- fixes the list's element type; TODO confirm against the C++ binding.)
  cast [] _ = throwIO $ userError "[ForeignPtr IValue]'s length must be one or more."
  -- Matching (x : _) replaces the partial `head`; afterwards every element,
  -- including x, is pushed onto the new list.
  cast xs@(x : _) f = do
    l <- newC10ListIValue x
    forM_ xs $ c10ListIValue_push_back l
    f l
  -- Read every element of the C10List back into a Haskell list.
  uncast xs f = do
    len <- c10ListIValue_size xs
    -- len is an unsigned CSize, so [0 .. len - 1] would wrap to
    -- [0 .. maxBound] for an empty list and index out of bounds.
    -- takeWhile (< len) yields no indices when len == 0.
    f =<< mapM (c10ListIValue_at xs) (takeWhile (< len) [0 ..])

instance Castable [(ForeignPtr IValue,ForeignPtr IValue)] (ForeignPtr (C10Dict '(IValue,IValue))) where
  -- newC10Dict takes a key and a value, so an empty input cannot be converted
  -- and is rejected with an IOError. (Presumably the first pair fixes the
  -- dict's key/value types; TODO confirm against the C++ binding.)
  cast [] _ = throwIO $ userError "[(ForeignPtr IValue,ForeignPtr IValue)]'s length must be one or more."
  -- Matching ((k0, v0) : _) replaces the partial `head`; afterwards every
  -- pair, including the first, is inserted with c10Dict_insert.
  cast xs@((k0, v0) : _) f = do
    l <- newC10Dict k0 v0
    forM_ xs $ \(k, v) -> c10Dict_insert l k v
    f l
  -- Convert the dict back into a list of (key, value) pairs.
  uncast xs f = f =<< c10Dict_toList xs