{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module Torch.Typed.NN.DataParallel where
import Control.Concurrent.Async
import Data.Kind
import GHC.TypeLits
import qualified Torch.Device as D
import Torch.HList
import qualified Torch.Internal.Cast as ATen
import qualified Torch.Internal.Class as ATen
import Torch.NN (HasForward (..))
import qualified Torch.Tensor as D
import Torch.Typed.Autograd
import Torch.Typed.Device
import Torch.Typed.Optim
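
-- | Function symbols used with @Apply'@ and 'hzipWithM' to run a model's
-- forward pass inside 'Concurrently': 'ForwardConcurrentlyF' uses the
-- deterministic 'forward', 'ForwardConcurrentlyStochF' the stochastic
-- 'forwardStoch'.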
data ForwardConcurrentlyF = ForwardConcurrentlyF | ForwardConcurrentlyStochF

instance
  ( HasForward model input output
  ) =>
  Apply' ForwardConcurrentlyF (model, input) (Concurrently output)
  where
  apply' ForwardConcurrentlyF (model, input) = Concurrently . pure . forward model $ input
  apply' ForwardConcurrentlyStochF (model, input) = Concurrently . forwardStoch model $ input
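
-- | Data-parallel forward pass: replicate the model onto @devices'@, scatter
-- the input across those devices, run the replicas concurrently, and gather
-- the outputs back onto @device'@.
--
-- A minimal usage sketch (the device list, @model@, and @input@ below are
-- illustrative assumptions, not part of this module):
--
-- > output <-
-- >   forwardConcurrently'
-- >     @'[ '( 'D.CUDA, 0), '( 'D.CUDA, 1)] -- devices to scatter over
-- >     @'( 'D.CPU, 0)                      -- device to gather the result on
-- >     model
-- >     input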
forwardConcurrently',
  forwardConcurrentlyStoch' ::
    forall devices' device' device model input output models inputs outputs.
    ( 'Just device ~ GetDevice model,
      'Just device ~ GetDevice input,
      HasScatter devices' device input inputs,
      HasReplicate devices' device model models,
      HZipWithM Concurrently ForwardConcurrentlyF models inputs outputs,
      HasGather device' devices' outputs output
    ) =>
    model ->
    input ->
    IO output
forwardConcurrently' model input = do
  let models = Torch.Typed.Device.replicate @devices' @device @model @models model
      inputs = scatter @devices' @device @input @inputs input
  outputs <- runConcurrently $ forwardConcurrently models inputs
  let output = gather @device' @devices' @outputs @output outputs
  return output
forwardConcurrentlyStoch' model input = do
  let models = Torch.Typed.Device.replicate @devices' @device @model @models model
      inputs = scatter @devices' @device @input @inputs input
  outputs <- runConcurrently $ forwardConcurrentlyStoch models inputs
  let output = gather @device' @devices' @outputs @output outputs
  return output

forwardConcurrently,
  forwardConcurrentlyStoch ::
    forall models inputs outputs.
    HZipWithM Concurrently ForwardConcurrentlyF models inputs outputs =>
    HList models ->
    HList inputs ->
    Concurrently (HList outputs)
forwardConcurrently = hzipWithM ForwardConcurrentlyF
forwardConcurrentlyStoch = hzipWithM ForwardConcurrentlyStochF
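
-- | Compute gradients for several model replicas concurrently (one loss per
-- replica) and reduce them to a single gradient list on @device'@.
--
-- A minimal usage sketch (the devices, @parameters@, and @losses@ are
-- illustrative assumptions; @parameters@ is an 'HList' of per-replica
-- parameter 'HList's, @losses@ the matching per-replica losses):
--
-- > gradients <-
-- >   runConcurrently $
-- >     gradConcurrently
-- >       @'( 'D.CPU, 0)                      -- device to reduce onto
-- >       @'[ '( 'D.CUDA, 0), '( 'D.CUDA, 1)] -- devices the replicas live on
-- >       parameters
-- >       losses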
class HasGradConcurrently device' devices parameters losses gradients | device' devices parameters losses -> gradients where
  gradConcurrently :: HList parameters -> HList losses -> Concurrently (HList gradients)

data GradConcurrentlyF = GradConcurrentlyF

instance
  ( HasGrad (HList parameters) (HList gradients),
    ATen.Castable (HList gradients) [D.ATenTensor]
  ) =>
  Apply' GradConcurrentlyF (HList parameters, Loss device dtype) (Concurrently (HList gradients))
  where
  apply' GradConcurrentlyF (parameters, loss) = Concurrently . pure . grad loss $ parameters

instance
  ( HZipWithM Concurrently GradConcurrentlyF parameters losses gradients',
    ReduceGradients device' devices gradients' gradients
  ) =>
  HasGradConcurrently device' devices parameters losses gradients
  where
  gradConcurrently parameters losses =
    let gradients = hzipWithM GradConcurrentlyF parameters losses
     in reduceGradients @device' @devices <$> gradients
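
-- | Reduce a list of per-device gradient lists to a single gradient list on
-- @device'@ by moving each device's gradients to @device'@ and combining them
-- elementwise with 'SumF'.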
class ReduceGradients (device' :: (D.DeviceType, Nat)) (devices :: [(D.DeviceType, Nat)]) xxs ys | device' devices xxs -> ys where
  reduceGradients :: HList xxs -> HList ys

instance
  {-# OVERLAPS #-}
  ( HasToDevice device' device (HList xs) (HList ys)
  ) =>
  ReduceGradients device' (device ': '[]) ((HList (xs :: [Type])) ': '[]) ys
  where
  reduceGradients (xs :. HNil) = Torch.Typed.Device.toDevice @device' @device xs

data SumF = SumF

instance Num y => Apply' SumF (y, y) y where
  -- Add the two gradient components elementwise.
  apply' _ = uncurry (+)

instance
  {-# OVERLAPPABLE #-}
  ( HasToDevice device' device (HList xs) (HList ys),
    ReduceGradients device' devices xxs ys,
    HZipWith SumF ys ys ys,
    1 <= ListLength xxs
  ) =>
  ReduceGradients device' (device ': devices) ((HList (xs :: [Type])) ': xxs) ys
  where
  reduceGradients (xs :. xxs) =
    hzipWith
      SumF
      (Torch.Typed.Device.toDevice @device' @device xs)
      (reduceGradients @device' @devices @xxs xxs)