{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE UndecidableSuperClasses #-}
{-# LANGUAGE NoStarIsType #-}

module Torch.Typed.Factories where

import Control.Arrow ((&&&))
import Data.Default.Class
import Data.Finite
import Data.Kind (Constraint)
import Data.Proxy
import Data.Reflection
import GHC.TypeLits
import System.IO.Unsafe
import qualified Torch.DType as D
import qualified Torch.Device as D
import qualified Torch.Functional as D
import Torch.Internal.Cast
import qualified Torch.Scalar as D
import qualified Torch.Tensor as D
import qualified Torch.TensorFactories as D
import qualified Torch.TensorOptions as D
import Torch.Typed.Auxiliary
import Torch.Typed.Tensor
import Prelude hiding (sin)

-- | The default value of a typed tensor is the all-zeros tensor whose shape,
-- dtype, and device are determined by its type (see 'zeros').
instance (TensorOptions shape dtype device) => Default (Tensor device dtype shape) where
  def = zeros

-- | The default value of a named tensor is an unnamed all-zeros tensor of the
-- corresponding plain shape (@shape' ~ ToNats shape@), converted with
-- 'fromUnnamed'.
instance (TensorOptions shape' dtype device, shape' ~ ToNats shape) => Default (NamedTensor device dtype shape) where
  def = fromUnnamed zeros

-- | Create a tensor filled with zeros.
--
-- The runtime shape, dtype, and device passed to 'D.zeros' are all derived
-- from the type-level @shape@, @dtype@, and @device@ through
-- 'TensorOptions'. Supply them with type applications or let the use site
-- infer them.
zeros ::
  forall shape dtype device.
  (TensorOptions shape dtype device) =>
  -- | output
  Tensor device dtype shape
zeros =
  UnsafeMkTensor $
    D.zeros
      (optionsRuntimeShape @shape @dtype @device)
      ( D.withDevice (optionsRuntimeDevice @shape @dtype @device)
          . D.withDType (optionsRuntimeDType @shape @dtype @device)
          $ D.defaultOpts
      )

-- | Create a tensor with every element set to the given scalar.
--
-- The runtime shape, dtype, and device passed to 'D.full' are derived from
-- the type-level @shape@, @dtype@, and @device@ through 'TensorOptions'.
full ::
  forall shape dtype device a.
  (TensorOptions shape dtype device, D.Scalar a) =>
  -- | fill value
  a ->
  -- | output
  Tensor device dtype shape
full value =
  UnsafeMkTensor $
    D.full
      (optionsRuntimeShape @shape @dtype @device)
      value
      ( D.withDevice (optionsRuntimeDevice @shape @dtype @device)
          . D.withDType (optionsRuntimeDType @shape @dtype @device)
          $ D.defaultOpts
      )

-- | Create a tensor filled with ones.
--
-- The runtime shape, dtype, and device passed to 'D.ones' are derived from
-- the type-level @shape@, @dtype@, and @device@ through 'TensorOptions'.
ones ::
  forall shape dtype device.
  (TensorOptions shape dtype device) =>
  -- | output
  Tensor device dtype shape
ones =
  UnsafeMkTensor $
    D.ones
      (optionsRuntimeShape @shape @dtype @device)
      ( D.withDevice (optionsRuntimeDevice @shape @dtype @device)
          . D.withDType (optionsRuntimeDType @shape @dtype @device)
          $ D.defaultOpts
      )

-- | Constraint on the dtype accepted by the random factories in this module
-- ('rand', 'randn', 'randint') for a given device.
--
-- * @'( 'D.CPU, 0)@: requires both 'DTypeIsNotBool' and 'DTypeIsNotHalf'.
-- * Any CUDA device: no constraint.
-- * Anything else, including CPU devices with an index other than 0: reduces
--   to 'UnsupportedDTypeForDevice'.
--
-- This is a closed family, so equation order matters.
type family RandDTypeIsValid (device :: (D.DeviceType, Nat)) (dtype :: D.DType) :: Constraint where
  RandDTypeIsValid '( 'D.CPU, 0) dt =
    ( DTypeIsNotBool '( 'D.CPU, 0) dt,
      DTypeIsNotHalf '( 'D.CPU, 0) dt
    )
  RandDTypeIsValid '( 'D.CUDA, _) _ = ()
  RandDTypeIsValid '(devType, _) dt = UnsupportedDTypeForDevice devType dt

-- | Create a tensor of random values using 'D.randIO'.
--
-- NOTE(review): presumably samples uniformly from [0, 1) like @torch.rand@;
-- confirm against 'D.randIO'.
--
-- The runtime shape, dtype, and device are derived from the type-level
-- @shape@, @dtype@, and @device@. 'RandDTypeIsValid' rejects dtype/device
-- combinations this factory does not support.
rand ::
  forall shape dtype device.
  ( TensorOptions shape dtype device,
    RandDTypeIsValid device dtype
  ) =>
  -- | output
  IO (Tensor device dtype shape)
rand =
  UnsafeMkTensor
    <$> D.randIO
      (optionsRuntimeShape @shape @dtype @device)
      ( D.withDevice (optionsRuntimeDevice @shape @dtype @device)
          . D.withDType (optionsRuntimeDType @shape @dtype @device)
          $ D.defaultOpts
      )

-- | Create a tensor of random values using 'D.randnIO'.
--
-- NOTE(review): presumably samples from a standard normal distribution like
-- @torch.randn@; confirm against 'D.randnIO'.
--
-- The runtime shape, dtype, and device are derived from the type-level
-- @shape@, @dtype@, and @device@. 'RandDTypeIsValid' rejects dtype/device
-- combinations this factory does not support.
randn ::
  forall shape dtype device.
  ( TensorOptions shape dtype device,
    RandDTypeIsValid device dtype
  ) =>
  -- | output
  IO (Tensor device dtype shape)
randn =
  UnsafeMkTensor
    <$> D.randnIO
      (optionsRuntimeShape @shape @dtype @device)
      ( D.withDevice (optionsRuntimeDevice @shape @dtype @device)
          . D.withDType (optionsRuntimeDType @shape @dtype @device)
          $ D.defaultOpts
      )

-- | Create a tensor of random integers between @low@ and @high@ using
-- 'D.randintIO'.
--
-- NOTE(review): presumably @low@ is inclusive and @high@ exclusive, as in
-- @torch.randint@; confirm against 'D.randintIO'. No check that
-- @low < high@ is done here.
--
-- The runtime shape, dtype, and device are derived from the type-level
-- @shape@, @dtype@, and @device@. 'RandDTypeIsValid' rejects dtype/device
-- combinations this factory does not support.
randint ::
  forall shape dtype device.
  ( TensorOptions shape dtype device,
    RandDTypeIsValid device dtype
  ) =>
  -- | low
  Int ->
  -- | high
  Int ->
  -- | output
  IO (Tensor device dtype shape)
randint low high =
  UnsafeMkTensor
    <$> D.randintIO
      low
      high
      (optionsRuntimeShape @shape @dtype @device)
      ( D.withDevice (optionsRuntimeDevice @shape @dtype @device)
          . D.withDType (optionsRuntimeDType @shape @dtype @device)
          $ D.defaultOpts
      )

-- | Create a one-dimensional 'D.Float' tensor of @steps@ evenly spaced
-- values from @start@ to @end@, with both endpoints included.
--
-- The number of points comes from the type-level @steps@ via 'natValI'. The
-- device comes from the type-level @device@.
--
-- >>> dtype &&& shape &&& (\t' -> D.asValue (toDynamic t') :: [Float]) $ linspace @7 @'( 'D.CPU, 0) 0 3
-- (Float,([7],[0.0,0.5,1.0,1.5,2.0,2.5,3.0]))
-- >>> dtype &&& shape &&& (\t' -> D.asValue (toDynamic t') :: [Float]) $ linspace @3 @'( 'D.CPU, 0) 0 2
-- (Float,([3],[0.0,1.0,2.0]))
linspace ::
  forall steps device start end.
  ( D.Scalar start,
    D.Scalar end,
    KnownNat steps,
    TensorOptions '[steps] 'D.Float device
  ) =>
  -- | start
  start ->
  -- | end
  end ->
  -- | output
  Tensor device 'D.Float '[steps]
linspace start end =
  UnsafeMkTensor $
    D.linspace
      start
      end
      (natValI @steps)
      ( D.withDevice (optionsRuntimeDevice @'[steps] @D.Float @device)
          . D.withDType (optionsRuntimeDType @'[steps] @D.Float @device)
          $ D.defaultOpts
      )

-- | Create an @n@-by-@n@ tensor with 'D.eyeSquare'. Following the usual
-- @eye@ convention, this is presumably the identity matrix.
--
-- @n@ comes from the type level via 'natValI'. The dtype and device come
-- from the type-level @dtype@ and @device@.
eyeSquare ::
  forall n dtype device.
  ( KnownNat n,
    TensorOptions '[n, n] dtype device
  ) =>
  -- | output
  Tensor device dtype '[n, n]
eyeSquare =
  UnsafeMkTensor $
    D.eyeSquare
      (natValI @n)
      ( D.withDevice (optionsRuntimeDevice @'[n, n] @dtype @device)
          . D.withDType (optionsRuntimeDType @'[n, n] @dtype @device)
          $ D.defaultOpts
      )