{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE UndecidableSuperClasses #-}
{-# LANGUAGE NoStarIsType #-}
module Torch.Typed.Factories where
import Control.Arrow ((&&&))
import Data.Default.Class
import Data.Finite
import Data.Kind (Constraint)
import Data.Proxy
import Data.Reflection
import GHC.TypeLits
import System.IO.Unsafe
import qualified Torch.DType as D
import qualified Torch.Device as D
import qualified Torch.Functional as D
import Torch.Internal.Cast
import qualified Torch.Scalar as D
import qualified Torch.Tensor as D
import qualified Torch.TensorFactories as D
import qualified Torch.TensorOptions as D
import Torch.Typed.Auxiliary
import Torch.Typed.Tensor
import Prelude hiding (sin)
-- | The default typed tensor is 'zeros' at the tensor's own shape, dtype and
-- device (all taken from the instance head).
instance (TensorOptions shape dtype device) => Default (Tensor device dtype shape) where
  def = zeros @shape @dtype @device

-- | The default named tensor is an unnamed 'zeros' tensor (of the
-- 'ToNats'-converted shape) wrapped with 'fromUnnamed'.
instance (TensorOptions shape' dtype device, shape' ~ ToNats shape) => Default (NamedTensor device dtype shape) where
  def = fromUnnamed zeros
-- | A tensor whose shape, dtype and device are all read off its type, with
-- every element set to zero (built with 'D.zeros').
zeros ::
  forall shape dtype device.
  (TensorOptions shape dtype device) =>
  Tensor device dtype shape
zeros = UnsafeMkTensor (D.zeros dims opts)
  where
    -- Runtime mirrors of the type-level shape, device and dtype.
    dims = optionsRuntimeShape @shape @dtype @device
    opts =
      D.withDevice
        (optionsRuntimeDevice @shape @dtype @device)
        (D.withDType (optionsRuntimeDType @shape @dtype @device) D.defaultOpts)
-- | A tensor whose shape, dtype and device are all read off its type, with
-- every element set to the given scalar (built with 'D.full').
full ::
  forall shape dtype device a.
  (TensorOptions shape dtype device, D.Scalar a) =>
  a ->
  Tensor device dtype shape
full fillValue = UnsafeMkTensor (D.full dims fillValue opts)
  where
    -- Runtime mirrors of the type-level shape, device and dtype.
    dims = optionsRuntimeShape @shape @dtype @device
    opts =
      D.withDevice
        (optionsRuntimeDevice @shape @dtype @device)
        (D.withDType (optionsRuntimeDType @shape @dtype @device) D.defaultOpts)
-- | A tensor whose shape, dtype and device are all read off its type, with
-- every element set to one (built with 'D.ones').
ones ::
  forall shape dtype device.
  (TensorOptions shape dtype device) =>
  Tensor device dtype shape
ones = UnsafeMkTensor (D.ones dims opts)
  where
    -- Runtime mirrors of the type-level shape, device and dtype.
    dims = optionsRuntimeShape @shape @dtype @device
    opts =
      D.withDevice
        (optionsRuntimeDevice @shape @dtype @device)
        (D.withDType (optionsRuntimeDType @shape @dtype @device) D.defaultOpts)
-- | Compile-time gate used by 'rand', 'randn' and 'randint' on the
-- (device, dtype) pair:
--
-- * on @'( 'D.CPU, 0)@ the dtype must satisfy both 'DTypeIsNotBool' and
--   'DTypeIsNotHalf';
-- * on any @'D.CUDA@ device no constraint is imposed;
-- * any other device is rejected through 'UnsupportedDTypeForDevice'.
--
-- This is a closed family: equations match top to bottom, so the catch-all
-- equation must remain last.
type family RandDTypeIsValid (device :: (D.DeviceType, Nat)) (dtype :: D.DType) :: Constraint where
  RandDTypeIsValid '( 'D.CPU, 0) dt =
    (DTypeIsNotBool '( 'D.CPU, 0) dt, DTypeIsNotHalf '( 'D.CPU, 0) dt)
  RandDTypeIsValid '( 'D.CUDA, _) _ = ()
  RandDTypeIsValid '(otherDevice, _) dt = UnsupportedDTypeForDevice otherDevice dt
-- | Samples a tensor of the type-level shape/dtype/device with 'D.randIO'
-- (libtorch @rand@). The dtype/device pair must pass 'RandDTypeIsValid'.
rand ::
  forall shape dtype device.
  ( TensorOptions shape dtype device,
    RandDTypeIsValid device dtype
  ) =>
  IO (Tensor device dtype shape)
rand = fmap UnsafeMkTensor (D.randIO dims opts)
  where
    -- Runtime mirrors of the type-level shape, device and dtype.
    dims = optionsRuntimeShape @shape @dtype @device
    opts =
      D.withDevice
        (optionsRuntimeDevice @shape @dtype @device)
        (D.withDType (optionsRuntimeDType @shape @dtype @device) D.defaultOpts)
-- | Samples a tensor of the type-level shape/dtype/device with 'D.randnIO'
-- (libtorch @randn@). The dtype/device pair must pass 'RandDTypeIsValid'.
randn ::
  forall shape dtype device.
  ( TensorOptions shape dtype device,
    RandDTypeIsValid device dtype
  ) =>
  IO (Tensor device dtype shape)
randn = fmap UnsafeMkTensor (D.randnIO dims opts)
  where
    -- Runtime mirrors of the type-level shape, device and dtype.
    dims = optionsRuntimeShape @shape @dtype @device
    opts =
      D.withDevice
        (optionsRuntimeDevice @shape @dtype @device)
        (D.withDType (optionsRuntimeDType @shape @dtype @device) D.defaultOpts)
-- | Samples a tensor of random integers with 'D.randintIO'. The bounds are
-- passed through unchanged; in libtorch @randint@ the upper bound is
-- exclusive. The dtype/device pair must pass 'RandDTypeIsValid'.
randint ::
  forall shape dtype device.
  ( TensorOptions shape dtype device,
    RandDTypeIsValid device dtype
  ) =>
  Int ->
  Int ->
  IO (Tensor device dtype shape)
randint low high = UnsafeMkTensor <$> D.randintIO low high dims opts
  where
    -- Runtime mirrors of the type-level shape, device and dtype.
    dims = optionsRuntimeShape @shape @dtype @device
    opts =
      D.withDevice
        (optionsRuntimeDevice @shape @dtype @device)
        (D.withDType (optionsRuntimeDType @shape @dtype @device) D.defaultOpts)
-- | A one-dimensional 'D.Float' tensor of @steps@ points from the start value
-- to the end value, built with 'D.linspace'. The point count comes from the
-- type-level @steps@.
linspace ::
  forall steps device start end.
  ( D.Scalar start,
    D.Scalar end,
    KnownNat steps,
    TensorOptions '[steps] 'D.Float device
  ) =>
  start ->
  end ->
  Tensor device 'D.Float '[steps]
linspace startVal endVal =
  UnsafeMkTensor (D.linspace startVal endVal (natValI @steps) opts)
  where
    -- The dtype is fixed to Float; only the device varies.
    opts =
      D.withDevice
        (optionsRuntimeDevice @'[steps] @'D.Float @device)
        (D.withDType (optionsRuntimeDType @'[steps] @'D.Float @device) D.defaultOpts)
-- | An @n@ by @n@ identity matrix built with 'D.eyeSquare', where @n@ is
-- taken from the type-level shape @'[n, n]@.
eyeSquare ::
  forall n dtype device.
  ( KnownNat n,
    TensorOptions '[n, n] dtype device
  ) =>
  Tensor device dtype '[n, n]
eyeSquare = UnsafeMkTensor (D.eyeSquare (natValI @n) opts)
  where
    -- Runtime mirrors of the type-level device and dtype.
    opts =
      D.withDevice
        (optionsRuntimeDevice @'[n, n] @dtype @device)
        (D.withDType (optionsRuntimeDType @'[n, n] @dtype @device) D.defaultOpts)