{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}

module Torch.Typed.VLTensor where

import Data.Proxy
import GHC.TypeLits
import qualified Torch.DType as D
import qualified Torch.Device as D
import qualified Torch.Functional as Untyped
import qualified Torch.Functional.Internal as Internal
import qualified Torch.Tensor as Untyped
import Torch.Typed.Auxiliary
import Torch.Typed.Tensor
import Unsafe.Coerce (unsafeCoerce)

-- | A variable length tensor. The length cannot be determined in advance.
--
-- The size @n@ of the leading dimension is existentially quantified: it is
-- hidden from the type of 'VLTensor' and only the trailing dimensions
-- (@shape@) are tracked statically. The constructor also carries the
-- @KnownNat n@ evidence, so pattern matching on 'VLTensor' brings the hidden
-- length back into scope as a type-level natural.
data VLTensor (device :: (D.DeviceType, Nat)) (dtype :: D.DType) (shape :: [Nat]) = forall n. KnownNat n => VLTensor (Tensor device dtype (n : shape))

-- | Shows the wrapped tensor using the 'Show' instance of the underlying
-- typed 'Tensor'; the existential wrapper adds nothing to the output.
instance Show (VLTensor device dtype shape) where
  show (VLTensor v) = show v

-- | Try to recover a statically sized tensor from a 'VLTensor'.
--
-- The caller picks the expected leading dimension @n@ (through a type
-- application or through the expected result type). The result is 'Just' the
-- wrapped tensor when its shape, as reported by 'shape', equals the
-- type-level shape @n : shape@, and 'Nothing' otherwise.
fromVLTensor ::
  forall n device dtype shape.
  ( KnownNat n,
    TensorOptions shape dtype device,
    KnownShape shape
  ) =>
  VLTensor device dtype shape ->
  Maybe (Tensor device dtype (n : shape))
fromVLTensor (VLTensor input)
  -- The hidden leading dimension of @input@ cannot be unified with @n@ by the
  -- type checker, so the tensor is re-typed with 'unsafeCoerce'. The coercion
  -- is only reached after the shape comparison has succeeded, so the new
  -- type-level shape agrees with the shape that was just observed; only the
  -- type index of the tensor changes, not the value.
  | shape input == shapeVal @(n : shape) = Just (unsafeCoerce input)
  | otherwise = Nothing

-- | Select entries of @input@ along its leading dimension using a boolean
-- tensor of the same length.
--
-- Both arguments are converted to untyped tensors and combined with the
-- untyped indexing operator @Untyped.!@. The size of the leading dimension of
-- the result is read back at run time and hidden inside the returned
-- 'VLTensor'.
selectIndexes ::
  forall n device dtype shape.
  Tensor device dtype (n : shape) ->
  Tensor device 'D.Bool '[n] ->
  VLTensor device dtype shape
selectIndexes input boolTensor =
  case Untyped.shape output of
    len : _ ->
      -- Reflect the runtime length to a type-level natural @b@ so that the
      -- result can be re-typed as @b : shape@ and wrapped existentially.
      withNat len $ \(Proxy :: Proxy b) ->
        VLTensor $ UnsafeMkTensor @device @dtype @(b : shape) output
    -- Explicit failure instead of the uninformative error of a partial 'head'.
    [] -> error "Torch.Typed.VLTensor.selectIndexes: indexed result has no dimensions"
  where
    output = toDynamic input Untyped.! toDynamic boolTensor

-- | Stack a list of equally shaped tensors along a new leading dimension and
-- hide the size of that dimension inside a 'VLTensor'.
--
-- The tensors are converted to untyped tensors and combined with
-- @Untyped.stack (Untyped.Dim 0)@; the size of the leading dimension of the
-- result is then read back at run time.
--
-- NOTE(review): the behaviour for an empty input list is whatever
-- @Untyped.stack@ does for an empty list - presumably a backend error;
-- confirm before relying on @pack []@.
pack :: forall device dtype shape. [Tensor device dtype shape] -> VLTensor device dtype shape
pack input =
  case Untyped.shape output of
    len : _ ->
      -- Reflect the runtime length to a type-level natural @n@ so that the
      -- result can be re-typed as @n : shape@ and wrapped existentially.
      withNat len $ \(Proxy :: Proxy n) ->
        VLTensor $ UnsafeMkTensor @device @dtype @(n : shape) output
    -- Explicit failure instead of the uninformative error of a partial 'head'.
    [] -> error "Torch.Typed.VLTensor.pack: stacked result has no dimensions"
  where
    output = Untyped.stack (Untyped.Dim 0) (map toDynamic input)

-- | Split a 'VLTensor' into the list of its slices along the leading
-- dimension.
--
-- The wrapped tensor is converted to an untyped tensor, split with
-- @Internal.unbind@ along dimension 0, and every resulting piece is re-typed
-- with the statically known trailing @shape@.
unpack :: forall device dtype shape. VLTensor device dtype shape -> [Tensor device dtype shape]
unpack (VLTensor packed) =
  map (UnsafeMkTensor @device @dtype @shape) (Internal.unbind (toDynamic packed) 0)