{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module Torch.Typed.Device where
import Data.Kind (Type)
import Data.Proxy (Proxy (..))
import GHC.Generics
import GHC.TypeLits
import GHC.TypeLits.Extra
import System.IO.Unsafe
import qualified Torch.DType as D
import qualified Torch.Device as D
import Torch.HList
import qualified Torch.Internal.Cast as ATen
import qualified Torch.Internal.Class as ATen
import qualified Torch.Internal.Managed.Autograd as LibTorch
import qualified Torch.Tensor as D
import Torch.Typed.Auxiliary
import Torch.Typed.Functional
import Torch.Typed.Parameter
import Torch.Typed.Tensor
-- | Relates a type @f@ to its counterpart @g@ when moving from device
-- @device@ to device @device'@, together with the conversion 'toDevice'.
--
-- Note the parameter order: the /target/ device @device'@ comes first and
-- the /source/ device @device@ second. This is also the order used for type
-- applications, e.g. @toDevice \@device' \@device@.
--
-- The two functional dependencies mean that, once both devices are fixed,
-- either of @f@ or @g@ determines the other.
class
  HasToDevice
    (device' :: (D.DeviceType, Nat))
    (device :: (D.DeviceType, Nat))
    (f :: Type)
    (g :: Type)
    | device' device f -> g,
      device' device g -> f
  where
  -- | Convert a value of type @f@ into the corresponding value of type @g@.
  toDevice :: f -> g
-- | @ReplaceDevice f device' device@ walks the type-application spine of
-- @f@ and swaps arguments equal to @device@ for @device'@.
--
-- The equations are tried in order:
--
-- 1. an application whose argument is exactly @device@ has that argument
--    replaced by @device'@ (its head is /not/ visited further);
-- 2. any other application is rewritten recursively in both its head and
--    its argument;
-- 3. anything else is returned unchanged.
type family ReplaceDevice (f :: k) (device' :: (D.DeviceType, Nat)) (device :: (D.DeviceType, Nat)) :: k where
  ReplaceDevice (con device) device' device = con device'
  ReplaceDevice (con arg) device' device = (ReplaceDevice con device' device) (ReplaceDevice arg device' device)
  ReplaceDevice other _ _ = other
-- | @ReplaceDevice' f device'@ is like 'ReplaceDevice' but does not need the
-- old device. Any application whose argument has kind
-- @(D.DeviceType, Nat)@ gets that argument replaced by @device'@, and its
-- head is not visited further. Other applications are rewritten
-- recursively in head and argument. Everything else is left unchanged.
type family ReplaceDevice' (f :: k) (device' :: (D.DeviceType, Nat)) :: k where
  ReplaceDevice' (con (old :: (D.DeviceType, Nat))) device' = con device'
  ReplaceDevice' (con arg) device' = (ReplaceDevice' con device') (ReplaceDevice' arg device')
  ReplaceDevice' other _ = other
-- | Generic fallback: any type with a 'Generic' representation is converted
-- field by field through 'GHasToDevice'.
--
-- The two equality constraints compute the target type @g@ from @f@ (and
-- @f@ back from @g@) with 'ReplaceDevice', so @g@ is @f@ with @device@
-- replaced by @device'@ as that type family defines.
instance
  ( g ~ ReplaceDevice f device' device,
    f ~ ReplaceDevice g device device',
    Generic f,
    Generic g,
    GHasToDevice device' device (Rep f) (Rep g)
  ) =>
  HasToDevice device' device f g
  where
  -- 'from' exposes the generic representation, 'gToDevice' converts it, and
  -- 'to' rebuilds a value of the target type @g@.
  toDevice = to . gToDevice @device' @device . from
-- | Generic-representation counterpart of 'HasToDevice', used by the
-- 'Generic' fallback instance of 'HasToDevice'.
--
-- @f@ and @g@ are representation types (as produced by 'Rep') of the source
-- and target values. Device parameters come in the same order as in
-- 'HasToDevice': target @device'@ first, source @device@ second.
class
  GHasToDevice
    (device' :: (D.DeviceType, Nat))
    (device :: (D.DeviceType, Nat))
    (f :: Type -> Type)
    (g :: Type -> Type)
  where
  -- | Convert one generic representation into the other.
  gToDevice :: forall a. f a -> g a
-- | Products of fields: the left and right components are converted
-- independently and recombined with ':*:'.
instance
  ( GHasToDevice device' device l l',
    GHasToDevice device' device r r'
  ) =>
  GHasToDevice device' device (l :*: r) (l' :*: r')
  where
  -- The component result types are fixed by the expected method type
  -- @(l' :*: r') a@, so no local annotations are needed.
  gToDevice (left :*: right) =
    gToDevice @device' @device left :*: gToDevice @device' @device right
-- | 'Double' values carry no device in their type, so they are returned
-- unchanged.
instance {-# OVERLAPS #-} HasToDevice device' device Double Double where
  toDevice = id
-- | Typed tensors are moved with 'Torch.Typed.Tensor.toDevice'. Only the
-- device in the type changes; @dtype@ and @shape@ are preserved.
-- @KnownDevice device'@ is required by that underlying function.
instance {-# OVERLAPS #-} (KnownDevice device') => HasToDevice device' device (Tensor device dtype shape) (Tensor device' dtype shape) where
  -- Qualified on purpose: "Torch.Typed.Tensor" also exports a @toDevice@,
  -- so an unqualified reference would be ambiguous.
  toDevice = Torch.Typed.Tensor.toDevice
-- | Typed parameters are moved with
-- 'Torch.Typed.Parameter.parameterToDevice'. Only the device in the type
-- changes; @dtype@ and @shape@ are preserved.
instance {-# OVERLAPS #-} (KnownDevice device') => HasToDevice device' device (Parameter device dtype shape) (Parameter device' dtype shape) where
  toDevice = Torch.Typed.Parameter.parameterToDevice
-- | The empty 'HList' converts to itself.
instance {-# OVERLAPS #-} HasToDevice device' device (HList ('[] :: [Type])) (HList ('[] :: [Type])) where
  toDevice = id
-- | Non-empty 'HList': convert the head element, then recursively the tail.
-- Every element uses the same pair of devices.
instance
  {-# OVERLAPS #-}
  ( HasToDevice device' device x x',
    HasToDevice device' device (HList xs) (HList xs')
  ) =>
  HasToDevice device' device (HList (x ': xs)) (HList (x' ': xs'))
  where
  -- The class method is referenced qualified because "Torch.Typed.Tensor"
  -- also exports a @toDevice@.
  toDevice (x :. xs) =
    Torch.Typed.Device.toDevice @device' @device x
      :. Torch.Typed.Device.toDevice @device' @device xs
-- | Leaves of the generic representation: each field is unwrapped, converted
-- with the 'HasToDevice' instance for its own type, and rewrapped.
instance {-# OVERLAPPABLE #-} (HasToDevice device' device f g) => GHasToDevice device' device (K1 i f) (K1 i g) where
  gToDevice = K1 . Torch.Typed.Device.toDevice @device' @device . unK1
-- | Metadata wrappers ('M1') are stripped, the wrapped representation is
-- converted, and the wrapper is restored with the same @i@ and @t@.
instance (GHasToDevice device' device f g) => GHasToDevice device' device (M1 i t f) (M1 i t g) where
  gToDevice = M1 . gToDevice @device' @device . unM1
-- | Constructors without fields ('U1') have nothing to convert.
instance GHasToDevice device' device U1 U1 where
  gToDevice = id
-- | @replicate \@devices' \@device x@ converts @x@ (on @device@) once for
-- every target device in the type-level list @devices'@ and returns the
-- results, in order, as an 'HList'.
--
-- The result types @gs@ are determined by the devices and @f@.
class
  HasReplicate
    (devices' :: [(D.DeviceType, Nat)])
    (device :: (D.DeviceType, Nat))
    (f :: Type)
    (gs :: [Type])
    | devices' device f -> gs
  where
  -- | One converted copy of the input per target device.
  replicate :: f -> HList gs
-- | No target devices: the input is ignored and the result is empty.
instance HasReplicate '[] device f '[] where
  replicate _ = HNil
-- | One more target device @device'@: convert the input for it with
-- @toDevice@ and prepend the result to the copies for the remaining
-- devices @devices'@.
instance
  ( HasReplicate devices' device f gs,
    HasToDevice device' device f g
  ) =>
  HasReplicate (device' ': devices') device f (g ': gs)
  where
  -- Both names are qualified: @toDevice@ is also exported by
  -- "Torch.Typed.Tensor", and @replicate@ clashes with the Prelude.
  replicate f =
    Torch.Typed.Device.toDevice @device' @device f
      :. Torch.Typed.Device.replicate @devices' @device @f @gs f
-- | Element-wise 'HasToDevice' over 'HList's: the @i@-th element of @fs@ is
-- moved from the @i@-th device of @devices@ to the @i@-th device of
-- @devices'@, producing the @i@-th element of @gs@.
--
-- Once both device lists are fixed, either list of element types
-- determines the other.
class
  HasToDevices
    (devices' :: [(D.DeviceType, Nat)])
    (devices :: [(D.DeviceType, Nat)])
    (fs :: [Type])
    (gs :: [Type])
    | devices' devices fs -> gs,
      devices' devices gs -> fs
  where
  -- | Convert every element of the list to its own target device.
  toDevices :: HList fs -> HList gs
-- | Empty device lists and empty value list: nothing to convert.
instance HasToDevices '[] '[] '[] '[] where
  toDevices HNil = HNil
-- | Convert the head element from @device@ to @device'@, then recurse on the
-- tails of all four lists in lock-step.
instance
  ( HasToDevices devices' devices fs gs,
    HasToDevice device' device f g
  ) =>
  HasToDevices (device' ': devices') (device ': devices) (f ': fs) (g ': gs)
  where
  -- @toDevice@ is qualified because "Torch.Typed.Tensor" also exports one.
  toDevices (f :. fs) =
    Torch.Typed.Device.toDevice @device' @device f
      :. toDevices @devices' @devices @fs @gs fs
-- | The device a type is indexed by, if any.
--
-- Peels arguments off the type's application spine from the outside in and
-- returns the first one (i.e. the right-most) whose kind is
-- @(D.DeviceType, Nat)@. Arguments themselves are not searched. If no such
-- argument exists the result is 'Nothing'.
--
-- Promoted constructors are ticked so they are unambiguously type-level
-- ('DataKinds') and do not trigger @-Wunticked-promoted-constructors@.
type family GetDevice (f :: k) :: Maybe (D.DeviceType, Nat) where
  GetDevice (t (device :: (D.DeviceType, Nat))) = 'Just device
  GetDevice (t a) = GetDevice t
  GetDevice t = 'Nothing
-- | Collects 'GetDevice' over a type-level list, combining the per-element
-- results with @MaybePrepend@.
--
-- NOTE(review): @MaybePrepend@ is defined outside this module; presumably it
-- prepends the device for a 'Just' and skips a 'Nothing' — confirm.
type family GetDevices (fs :: [k]) :: [(D.DeviceType, Nat)] where
  GetDevices '[] = '[]
  GetDevices (x ': xs) = MaybePrepend (GetDevice x) (GetDevices xs)
-- | Split a value living on @device@ across the target devices @devices'@,
-- returning one piece per target device as an 'HList'.
--
-- The piece types @gs@ are determined by the devices and @f@.
class
  HasScatter devices' device f gs
    | devices' device f -> gs
  where
  -- | Split the input and move each piece to its target device.
  scatter :: f -> HList gs
-- | Tensors are split along dimension 0 into as many chunks as there are
-- target devices (@chunks ~ ListLength devices'@). Chunk @i@ is then moved
-- to the @i@-th device of @devices'@ with 'toDevices'.
instance
  ( chunks ~ ListLength devices',
    tensorChunks ~ Chunk chunks 0 shape dtype device,
    ATen.Castable (HList tensorChunks) [D.ATenTensor],
    devices ~ HReplicateR chunks device,
    HasToDevices devices' devices tensorChunks gs,
    KnownNat chunks
  ) =>
  HasScatter devices' device (Tensor device dtype shape) gs
  where
  -- Every chunk starts on the source device, so the source device list is
  -- @HReplicateR chunks device@ (presumably @device@ repeated @chunks@
  -- times), which 'toDevices' pairs position-wise with @devices'@.
  scatter = toDevices @devices' @devices . Torch.Typed.Functional.chunk @chunks @0
-- | Collect values spread over the devices @devices@ onto the single device
-- @device'@.
--
-- The result type @g@ is determined by the devices and the element types
-- @fs@.
class
  HasGather device' devices fs g
    | device' devices fs -> g
  where
  -- | Move every element to @device'@ and combine them into one value.
  gather :: HList fs -> g
-- | Tensors: each input is first moved from its own device (read off its type
-- with 'GetDevices') to @device'@. The moved tensors are then concatenated
-- along dimension 0 with 'Torch.Typed.Functional.cat'.
instance
  ( chunks ~ ListLength fs,
    devices ~ GetDevices fs,
    devices' ~ HReplicateR chunks device',
    HasToDevices devices' devices fs tensorChunks,
    '(shape, dtype, device') ~ Cat 0 tensorChunks,
    ATen.Castable (HList tensorChunks) [D.ATenTensor]
  ) =>
  HasGather device' devices fs (Tensor device' dtype shape)
  where
  -- The target list @devices'@ is @HReplicateR chunks device'@, presumably
  -- one copy of @device'@ per input tensor.
  gather = Torch.Typed.Functional.cat @0 . toDevices @devices' @devices