{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Torch.Dimname where

import Data.String
import Foreign.ForeignPtr
import System.IO.Unsafe
import Torch.Internal.Class (Castable (..))
import qualified Torch.Internal.Const as ATen
import qualified Torch.Internal.Managed.Type.Dimname as ATen
import qualified Torch.Internal.Managed.Type.StdString as ATen
import qualified Torch.Internal.Managed.Type.Symbol as ATen
import qualified Torch.Internal.Type as ATen

-- | A named tensor dimension.
--
-- Wraps a 'ForeignPtr' to the underlying libtorch @Dimname@ object; the
-- constructor is exported so the 'Castable' instances can wrap/unwrap it.
newtype Dimname
  = Dimname (ForeignPtr ATen.Dimname)

-- | Build a 'Dimname' from a string, so that with @OverloadedStrings@ a
-- literal such as @"batch"@ can be used wherever a 'Dimname' is expected.
--
-- The string is converted in three FFI steps: Haskell 'String' ->
-- @std::string@ ('ATen.newStdString_s'), @std::string@ -> @Symbol@
-- ('ATen.dimname_s'), @Symbol@ -> @Dimname@ ('ATen.fromSymbol_s').
--
-- 'unsafePerformIO' is needed because 'fromString' must be pure.
-- NOTE(review): this is only sound if those three calls have no observable
-- effect beyond allocating the objects they return, so that two evaluations
-- of @fromString s@ produce interchangeable values — confirm against the
-- C++ bindings.
instance IsString Dimname where
  fromString str = unsafePerformIO $ do
    stdStr <- ATen.newStdString_s str
    symbol <- ATen.dimname_s stdStr
    Dimname <$> ATen.fromSymbol_s symbol

-- | Marshal a single 'Dimname' to and from its libtorch pointer.
--
-- 'Dimname' is a newtype over the pointer, so both directions simply
-- unwrap or wrap the constructor and pass the result to the continuation;
-- no FFI call is made.
instance Castable Dimname (ForeignPtr ATen.Dimname) where
  cast (Dimname ptr) f = f ptr
  uncast ptr f = f (Dimname ptr)

-- | Marshal a list of 'Dimname's to and from a libtorch @DimnameList@.
--
-- Each element is only a newtype around @ForeignPtr ATen.Dimname@, so both
-- directions unwrap/wrap the elements and delegate the actual list
-- conversion to the pointer-level instance
-- @Castable [ForeignPtr ATen.Dimname] (ForeignPtr ATen.DimnameList)@.
-- NOTE(review): that instance is not defined in this module; it is assumed
-- to come into scope transitively (presumably from "Torch.Internal.Cast") —
-- confirm.
--
-- The delegation must happen at the pointer-list type. If the elements were
-- re-wrapped as @[Dimname]@ before calling 'cast'/'uncast', instance
-- resolution would pick this instance again and the call would never
-- terminate.
instance Castable [Dimname] (ForeignPtr ATen.DimnameList) where
  cast xs f = cast ptrList f
    where
      -- The explicit element type selects the pointer-list instance.
      ptrList :: [ForeignPtr ATen.Dimname]
      ptrList = map (\(Dimname ptr) -> ptr) xs
  uncast xs f = uncast xs $ \(ptrList :: [ForeignPtr ATen.Dimname]) ->
    -- Pattern type signature (ScopedTypeVariables) forces the pointer-list
    -- instance for the inner 'uncast'.
    f (map Dimname ptrList)