{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RecordWildCards #-}

module Torch.Random
  ( mkGenerator,
    Generator,
    randn,
    randn',
    rand,
    rand',
    randint,
    randint',
    normal,
    normal',
  )
where

import Control.Concurrent
import Control.Concurrent.STM
import Control.Monad.IO.Class
import Control.Monad.STM
import Data.Int
import Data.Word
import Foreign.ForeignPtr
import System.IO.Unsafe
import Torch.Device
import Torch.Internal.Cast
import Torch.Internal.Class (Castable (..))
import qualified Torch.Internal.Const as ATen
import qualified Torch.Internal.Managed.TensorFactories as LibTorch
import qualified Torch.Internal.Managed.Type.Generator as ATen
import qualified Torch.Internal.Type as ATen
import Torch.Tensor
import Torch.TensorOptions

-- | Placeholder 'Show' instance for the mutable cell stored inside
-- 'Generator', which lets 'Generator' derive 'Show'. The contents of the
-- 'TVar' are never inspected; the cell always renders as @"_"@.
instance Show (TVar (Either (Word64, Device) (ForeignPtr ATen.Generator))) where
  show _ = "_"

-- | Handle to a random number generator.
--
-- The wrapped 'TVar' holds one of two states:
--
-- * @Right ptr@: a live ATen generator object.
-- * @Left (seed, device)@: the seed and device from which an ATen generator
--   can be created again.
--
-- The module's export list exposes the type without the 'UnsafeGenerator'
-- constructor; values are obtained through 'mkGenerator'.
newtype Generator = UnsafeGenerator
  { unGenerator :: TVar (Either (Word64, Device) (ForeignPtr ATen.Generator))
  }
  deriving (Eq, Show)

-- | Create a 'Generator' on the given device, seeded with the given value.
--
-- For a CPU device the seed is passed directly to @ATen.newCPUGenerator@.
-- For a CUDA device the generator is first created for the device index and
-- the seed is then installed with @ATen.generator_set_current_seed@.
--
-- NOTE(review): only the @CPU@ and @CUDA@ device types are matched here; if
-- 'DeviceType' has further constructors this match is partial. Confirm
-- against "Torch.Device".
mkGenerator :: Device -> Word64 -> IO Generator
mkGenerator device seed = do
  genPtr <- case device of
    Device CPU _ -> ATen.newCPUGenerator seed
    Device CUDA idx -> do
      gen <- ATen.newCUDAGenerator (fromIntegral idx)
      ATen.generator_set_current_seed gen seed
      return gen
  -- A freshly made generator always starts in the "live pointer" state.
  UnsafeGenerator <$> newTVarIO (Right genPtr)

-- | Shape of the low-level sampling functions accepted by 'generatorFactory':
-- given a size array, a generator and tensor options (all as foreign
-- pointers), produce a tensor in 'IO'.
type RandomGenFunc = ForeignPtr ATen.IntArray -> ForeignPtr ATen.Generator -> ForeignPtr ATen.TensorOptions -> IO (ForeignPtr ATen.Tensor)

-- | Lift a low-level sampling function into a pure-looking function that
-- takes a 'Generator' and returns the sampled 'Tensor' together with a new
-- 'Generator'.
--
-- How the input generator's cell is handled:
--
-- * If it holds @Right ptr@, the cell is overwritten with
--   @Left (seed, device)@ (both read from @ptr@) and @ptr@ itself is used for
--   sampling.
-- * If it holds @Left (seed, device)@, a new ATen generator is created from
--   that seed and device and used for sampling; the cell is left unchanged.
--
-- The returned 'Generator' wraps a fresh 'TVar' holding the pointer that was
-- used for sampling.
--
-- NOTE(review): the whole computation runs under 'unsafePerformIO'. The
-- replacement of the live pointer by @(seed, device)@ looks intended to make
-- repeated use of the same 'Generator' value repeatable; confirm before
-- changing evaluation order or strictness here.
generatorFactory :: RandomGenFunc -> [Int] -> TensorOptions -> Generator -> (Tensor, Generator)
generatorFactory func size options (UnsafeGenerator generator) =
  unsafePerformIO $ do
    mGenerator <- atomically $ do
      v <- readTVar generator
      case v of
        Right v' -> do
          let device =
                if generatorIsCuda v'
                  then Device {deviceType = CUDA, deviceIndex = fromIntegral $ generatorDevice v'}
                  else Device {deviceType = CPU, deviceIndex = 0}
              seed = generatorSeed v'
          -- 'writeTVar' does not evaluate the stored value. The 'seq' chain
          -- makes sure that, once the stored 'Either' is evaluated, the seed
          -- and both device fields are evaluated with it.
          writeTVar generator $ seed `seq` deviceType device `seq` deviceIndex device `seq` Left (seed, device)
          return $ Right v'
        Left seedAndDevice -> return (Left seedAndDevice)
    genPtr <- case mGenerator of
      Right gen -> return gen
      Left (seed, device) -> case device of
        Device CPU _ -> ATen.newCPUGenerator seed
        Device CUDA idx -> do
          gen <- ATen.newCUDAGenerator (fromIntegral idx)
          ATen.generator_set_current_seed gen seed
          return gen
    tensor <- cast3 func size genPtr options
    nextGenerator <- newTVarIO (Right genPtr)
    return (tensor, UnsafeGenerator nextGenerator)
  where
    -- Queries on the underlying ATen generator, wrapped with
    -- 'unsafePerformIO' so they can be used in the pure 'let' above.
    generatorIsCuda :: ForeignPtr ATen.Generator -> Bool
    generatorIsCuda gen = unsafePerformIO $ cast1 ATen.generator_is_cuda gen

    generatorDevice :: ForeignPtr ATen.Generator -> Int
    generatorDevice gen = unsafePerformIO $ cast1 ATen.generator_get_device gen

    generatorSeed :: ForeignPtr ATen.Generator -> Word64
    generatorSeed gen = unsafePerformIO $ cast1 ATen.generator_current_seed gen

-- | Sample a tensor of the given size with @LibTorch.randn_lGo@, returning it
-- together with a new generator.
--
-- NOTE(review): presumably standard-normal samples, as with @torch.randn@;
-- confirm against the LibTorch binding.
randn ::
  -- | size
  [Int] ->
  -- | options
  TensorOptions ->
  -- | generator
  Generator ->
  -- | output
  (Tensor, Generator)
randn = generatorFactory LibTorch.randn_lGo

-- | Variant of 'randn' that uses 'defaultOpts' as the tensor options.
randn' ::
  -- | size
  [Int] ->
  -- | generator
  Generator ->
  -- | output
  (Tensor, Generator)
randn' size = randn size defaultOpts

-- | Sample a tensor of the given size with @LibTorch.rand_lGo@, returning it
-- together with a new generator.
--
-- NOTE(review): presumably uniform samples on @[0, 1)@, as with
-- @torch.rand@; confirm against the LibTorch binding.
rand ::
  -- | size
  [Int] ->
  -- | options
  TensorOptions ->
  -- | generator
  Generator ->
  -- | output
  (Tensor, Generator)
rand = generatorFactory LibTorch.rand_lGo

-- | Variant of 'rand' that uses 'defaultOpts' as the tensor options.
rand' ::
  -- | size
  [Int] ->
  -- | generator
  Generator ->
  -- | output
  (Tensor, Generator)
rand' size = rand size defaultOpts

-- | Sample an integer-valued tensor of the given size with
-- @LibTorch.randint_lllGo@, returning it together with a new generator.
-- Both bounds are converted with 'fromIntegral' before being passed on.
--
-- NOTE(review): presumably @low@ is inclusive and @high@ exclusive, as with
-- @torch.randint@; confirm against the LibTorch binding.
randint ::
  -- | low
  Int ->
  -- | high
  Int ->
  -- | size
  [Int] ->
  -- | options
  TensorOptions ->
  -- | generator
  Generator ->
  -- | output
  (Tensor, Generator)
randint low high =
  generatorFactory (LibTorch.randint_lllGo (fromIntegral low) (fromIntegral high))

-- | Variant of 'randint' that uses 'defaultOpts' as the tensor options.
randint' ::
  -- | low
  Int ->
  -- | high
  Int ->
  -- | size
  [Int] ->
  -- | generator
  Generator ->
  -- | output
  (Tensor, Generator)
randint' low high size = randint low high size defaultOpts

-- | Sample a tensor of the given size with @LibTorch.normal_ddlGo@, returning
-- it together with a new generator. The mean and standard deviation are
-- converted with 'realToFrac' before being passed on.
normal ::
  -- | mean
  Double ->
  -- | std
  Double ->
  -- | size
  [Int] ->
  -- | options
  TensorOptions ->
  -- | generator
  Generator ->
  -- | output
  (Tensor, Generator)
normal mean std =
  generatorFactory (LibTorch.normal_ddlGo (realToFrac mean) (realToFrac std))

-- | Variant of 'normal' that uses 'defaultOpts' as the tensor options.
normal' ::
  -- | mean
  Double ->
  -- | std
  Double ->
  -- | size
  [Int] ->
  -- | generator
  Generator ->
  -- | output
  (Tensor, Generator)
normal' mean std size = normal mean std size defaultOpts