{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module Torch.Typed.DType where
import Data.Kind (Type)
import GHC.Generics
import GHC.TypeLits
import GHC.TypeLits.Extra
import System.IO.Unsafe
import qualified Torch.DType as D
import qualified Torch.Device as D
import Torch.HList
import qualified Torch.Internal.Cast as ATen
import qualified Torch.Internal.Class as ATen
import qualified Torch.Internal.Managed.Autograd as LibTorch
import qualified Torch.Tensor as D
import Torch.Typed.Parameter
import Torch.Typed.Tensor
-- | Conversion between two types @f@ and @g@ that are related by a change
-- of data type from @dtype@ to @dtype'@.
--
-- The functional dependencies say that, once both data types are fixed,
-- the source type determines the target type and the target type
-- determines the source type.
class
  HasToDType dtype' dtype f g
    | dtype' dtype f -> g,
      dtype' dtype g -> f
  where
  -- | Turn a value of the source type @f@ into a value of the target type @g@.
  toDType :: f -> g
-- | Type-level rewrite of @f@ that swaps the data type @dtype@ for @dtype'@.
--
-- Equations are tried top to bottom:
--
-- 1. an application whose argument is exactly @dtype@ has that argument
--    replaced by @dtype'@ (the head of the application is left untouched);
-- 2. any other application is rewritten in both its head and its argument;
-- 3. everything else is returned unchanged.
type family ReplaceDType (f :: k) (dtype' :: D.DType) (dtype :: D.DType) :: k where
  ReplaceDType (con old) new old = con new
  ReplaceDType (con arg) new old = (ReplaceDType con new old) (ReplaceDType arg new old)
  ReplaceDType other _ _ = other
-- | Variant of the data type rewrite that does not need to know the old
-- data type: any argument of kind 'D.DType' is a candidate for replacement.
--
-- Equations are tried top to bottom:
--
-- 1. an application whose argument has kind 'D.DType' gets that argument
--    replaced by @dtype'@ (the head of the application is left untouched);
-- 2. any other application is rewritten in both its head and its argument;
-- 3. everything else is returned unchanged.
type family ReplaceDType' (f :: k) (dtype' :: D.DType) :: k where
  ReplaceDType' (con (old :: D.DType)) new = con new
  ReplaceDType' (con arg) new = (ReplaceDType' con new) (ReplaceDType' arg new)
  ReplaceDType' other _ = other
-- | Generic fallback instance: a value of any type with a 'Generic'
-- representation is converted by taking it apart with 'from', converting the
-- representation with 'gToDType', and rebuilding the result with 'to'.
--
-- The two equality constraints relate @f@ and @g@ through 'ReplaceDType' in
-- both directions: @g@ is @f@ with @dtype@ replaced by @dtype'@, and @f@ is
-- @g@ with @dtype'@ replaced by @dtype@.
instance
  ( g ~ ReplaceDType f dtype' dtype,
    f ~ ReplaceDType g dtype dtype',
    Generic f,
    Generic g,
    GHasToDType dtype' dtype (Rep f) (Rep g)
  ) =>
  HasToDType dtype' dtype f g
  where
  -- The method has type @f -> g@ here. No instance signature is written
  -- because InstanceSigs is not among this module's LANGUAGE pragmas.
  -- The type applications are required: neither data type appears in the
  -- type of 'gToDType', so they cannot be inferred.
  toDType = to . gToDType @dtype' @dtype . from
-- | Generic counterpart of 'HasToDType', working on representation types
-- (kind @Type -> Type@) rather than on ordinary types.
--
-- @f@ is the source representation and @g@ the target representation; the
-- two data types are only carried along as class parameters.
class GHasToDType (dtype' :: D.DType) (dtype :: D.DType) (f :: Type -> Type) (g :: Type -> Type) where
  -- | Convert a source representation value into the target representation.
  gToDType :: forall a. f a -> g a
-- | Products are converted component-wise: the left and the right factor are
-- each converted with 'gToDType' and paired up again.
instance
  ( GHasToDType dtype' dtype l l',
    GHasToDType dtype' dtype r r'
  ) =>
  GHasToDType dtype' dtype (l :*: r) (l' :*: r')
  where
  -- The method has type @(l :*: r) a -> (l' :*: r') a@ here. The target
  -- representations @l'@ and @r'@ of the two recursive calls are fixed by
  -- the result type, so no local signatures are needed.
  gToDType (l :*: r) =
    gToDType @dtype' @dtype l :*: gToDType @dtype' @dtype r
-- | Typed tensors are converted directly with 'Torch.Typed.Tensor.toDType';
-- only the data type index changes, device and shape are kept.
--
-- The @OVERLAPS@ pragma lets this instance be chosen over the more general
-- 'Generic'-based instance of 'HasToDType'.
instance {-# OVERLAPS #-} (KnownDType dtype') => HasToDType dtype' dtype (Tensor device dtype shape) (Tensor device dtype' shape) where
  -- The method has type
  -- @Tensor device dtype shape -> Tensor device dtype' shape@ here.
  -- The qualified name is needed to refer to the tensor-level function
  -- rather than to this class's method of the same name.
  toDType = Torch.Typed.Tensor.toDType
-- | Typed parameters are converted with
-- 'Torch.Typed.Parameter.parameterToDType'; only the data type index
-- changes, device and shape are kept.
--
-- The @OVERLAPS@ pragma lets this instance be chosen over the more general
-- 'Generic'-based instance of 'HasToDType'.
instance {-# OVERLAPS #-} (KnownDType dtype') => HasToDType dtype' dtype (Parameter device dtype shape) (Parameter device dtype' shape) where
  -- The method has type
  -- @Parameter device dtype shape -> Parameter device dtype' shape@ here.
  toDType = Torch.Typed.Parameter.parameterToDType
-- | A field of a generic representation is converted by unwrapping it,
-- converting the contained value with 'Torch.Typed.DType.toDType', and
-- wrapping the result again.
--
-- The @OVERLAPPABLE@ pragma allows a more specific instance for particular
-- field types to take precedence over this one.
instance {-# OVERLAPPABLE #-} (HasToDType dtype' dtype f g) => GHasToDType dtype' dtype (K1 i f) (K1 i g) where
  -- The method has type @K1 i f a -> K1 i g a@ here.
  -- The method name is qualified with this module's name to distinguish it
  -- from the function of the same name in "Torch.Typed.Tensor".
  gToDType = K1 . Torch.Typed.DType.toDType @dtype' @dtype . unK1
-- | Metadata wrappers are transparent: the wrapped representation is
-- converted and the same metadata is put back around the result.
instance (GHasToDType dtype' dtype f g) => GHasToDType dtype' dtype (M1 i t f) (M1 i t g) where
  -- The method has type @M1 i t f a -> M1 i t g a@ here.
  gToDType = M1 . gToDType @dtype' @dtype . unM1
-- | The unit representation (a constructor without fields) contains nothing
-- to convert and is returned as is.
instance GHasToDType dtype' dtype U1 U1 where
  -- The method has type @U1 a -> U1 a@ here.
  gToDType = id
-- | Type-level lookup of a data type inside a type.
--
-- Equations are tried top to bottom:
--
-- 1. an application whose argument has kind 'D.DType' yields that argument
--    wrapped in @'Just@;
-- 2. any other application continues the search in its head, ignoring the
--    argument;
-- 3. anything that is not an application yields @'Nothing@.
type family GetDType (f :: k) :: Maybe D.DType where
  GetDType (con (found :: D.DType)) = 'Just found
  GetDType (con arg) = GetDType con
  GetDType other = 'Nothing