{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

module Torch.Typed.NN.DataParallel where

import Control.Concurrent.Async
import Control.Exception (evaluate)
import Data.Kind
import GHC.TypeLits
import qualified Torch.Device as D
import Torch.HList
import qualified Torch.Internal.Cast as ATen
import qualified Torch.Internal.Class as ATen
import Torch.NN (HasForward (..))
import qualified Torch.Tensor as D
import Torch.Typed.Autograd
import Torch.Typed.Device
import Torch.Typed.Optim
import Torch.Typed.Optim

-- | Tag handed to 'hzipWithM' that selects how each replicated model is run
-- on its slice of the input.
data ForwardConcurrentlyF
  = -- | Run the pure 'forward' of the model.
    ForwardConcurrentlyF
  | -- | Run the 'IO'-based 'forwardStoch' of the model.
    ForwardConcurrentlyStochF

-- | Turn one @(model, input)@ pair into a 'Concurrently' action that runs the
-- model on the input, using 'forward' or 'forwardStoch' depending on the tag.
instance
  ( HasForward model input output
  ) =>
  Apply' ForwardConcurrentlyF (model, input) (Concurrently output)
  where
  -- Wrapping the result with 'pure' would only hand back an unevaluated thunk,
  -- so the forward pass would run later, on whichever thread first demands the
  -- output, and not inside the concurrent action. 'evaluate' forces the result
  -- to weak head normal form within the action instead.
  -- NOTE(review): only the outermost constructor of @output@ is forced; for a
  -- single tensor this is assumed to trigger the computation — confirm for
  -- structured outputs (tuples, HLists).
  apply' ForwardConcurrentlyF (model, input) =
    Concurrently . evaluate . forward model $ input
  -- 'forwardStoch' is already an 'IO' action, so it runs inside the action as is.
  apply' ForwardConcurrentlyStochF (model, input) =
    Concurrently . forwardStoch model $ input

-- Run a `model` concurrently on an `input`.
--
-- The `model` is replicated over the supplied `devices'`, and the `input` is
-- scattered over them as well. Then the `forward` function of the replicated
-- `models` is run concurrently on the scattered `inputs`. Finally, the
-- `outputs` are gathered on the target `device'`.
--
-- `forwardConcurrentlyStoch'` works the same way, but runs `forwardStoch`
-- (an `IO` action) on each replica instead of the pure `forward`.
--
-- Both functions take their device arguments as type applications: first the
-- list of replica devices `devices'`, then the target device `device'`. The
-- source device `device` is fixed by the `GetDevice` constraints on `model`
-- and `input`.
--
-- The examples below use a `'( 'D.CUDA, 0)` replica and therefore need a CUDA
-- device.
--
-- >>> model <- A.sample (LinearSpec @1 @1 @'D.Float @'( 'D.CPU, 0))
-- >>> t = ones @'[2, 1] @'D.Float @'( 'D.CPU, 0)
--
-- >>> :t forward model t
-- forward model t :: Tensor '( 'D.CPU, 0) 'D.Float '[2, 1]
-- >>> forward model t
-- Tensor Float [2,1] [[ 0.2478   ],
--                     [ 0.2478   ]]
--
-- >>> :t forwardConcurrently' @'[ '( 'D.CPU, 0), '( 'D.CUDA, 0)] @'( 'D.CPU, 0) model t
-- forwardConcurrently' @'[ '( 'D.CPU, 0), '( 'D.CUDA, 0)] @'( 'D.CPU, 0) model t
--   :: IO (Tensor '( 'D.CPU, 0) 'D.Float '[2, 1])
-- >>> forwardConcurrently' @'[ '( 'D.CPU, 0), '( 'D.CUDA, 0)] @'( 'D.CPU, 0) model t
-- Tensor Float [2,1] [[ 0.2478   ],
--                     [ 0.2478   ]]
forwardConcurrently',
  forwardConcurrentlyStoch' ::
    forall devices' device' device model input output models inputs outputs.
    ( 'Just device ~ GetDevice model,
      'Just device ~ GetDevice input,
      HasScatter devices' device input inputs,
      HasReplicate devices' device model models,
      HZipWithM Concurrently ForwardConcurrentlyF models inputs outputs,
      HasGather device' devices' outputs output
    ) =>
    model ->
    input ->
    IO output
forwardConcurrently' model input = do
  -- One model replica and one input slice per device in devices'.
  let models = Torch.Typed.Device.replicate @devices' @device @model @models model
      inputs = scatter @devices' @device @input @inputs input
  outputs <- runConcurrently $ forwardConcurrently models inputs
  pure $ gather @device' @devices' @outputs @output outputs
forwardConcurrentlyStoch' model input = do
  -- One model replica and one input slice per device in devices'.
  let models = Torch.Typed.Device.replicate @devices' @device @model @models model
      inputs = scatter @devices' @device @input @inputs input
  outputs <- runConcurrently $ forwardConcurrentlyStoch models inputs
  pure $ gather @device' @devices' @outputs @output outputs

-- | Pair each replicated model with its input slice and run every pair in its
-- own 'Concurrently' branch, collecting the results in an 'HList'.
--
-- 'forwardConcurrently' uses 'forward' on each pair;
-- 'forwardConcurrentlyStoch' uses 'forwardStoch'.
forwardConcurrently,
  forwardConcurrentlyStoch ::
    forall models inputs outputs.
    HZipWithM Concurrently ForwardConcurrentlyF models inputs outputs =>
    HList models ->
    HList inputs ->
    Concurrently (HList outputs)
forwardConcurrently replicas slices =
  hzipWithM ForwardConcurrentlyF replicas slices
forwardConcurrentlyStoch replicas slices =
  hzipWithM ForwardConcurrentlyStochF replicas slices

-- | Compute, as a 'Concurrently' action, the gradients of a list of losses
-- with respect to a matching list of parameter sets.
--
-- @device'@ and @devices@ do not occur in the type of 'gradConcurrently',
-- so callers select them with type applications.
class
  HasGradConcurrently device' devices parameters losses gradients
    | device' devices parameters losses -> gradients
  where
  gradConcurrently :: HList parameters -> HList losses -> Concurrently (HList gradients)

-- | Tag handed to 'hzipWithM' to compute one set of gradients per
-- @(parameters, loss)@ pair.
data GradConcurrentlyF
  = GradConcurrentlyF

-- | Turn one @(parameters, loss)@ pair into a 'Concurrently' action that
-- computes the gradients of @loss@ with respect to @parameters@ via 'grad'.
instance
  ( HasGrad (HList parameters) (HList gradients),
    ATen.Castable (HList gradients) [D.ATenTensor]
  ) =>
  Apply' GradConcurrentlyF (HList parameters, Loss device dtype) (Concurrently (HList gradients))
  where
  -- Wrapping the result with 'pure' would only hand back an unevaluated thunk,
  -- deferring the gradient computation until after the concurrent phase.
  -- 'evaluate' forces the result to weak head normal form inside the action.
  -- NOTE(review): this forces only the outermost 'HList' constructor; it is
  -- assumed that doing so runs the underlying 'grad' computation — confirm
  -- against the definition of 'grad'.
  apply' GradConcurrentlyF (parameters, loss) =
    Concurrently . evaluate . grad loss $ parameters

-- | Compute the gradients for each @(parameters, loss)@ pair concurrently,
-- then reduce the per-device gradient lists onto @device'@ with
-- 'reduceGradients'.
instance
  ( HZipWithM Concurrently GradConcurrentlyF parameters losses gradients',
    ReduceGradients device' devices gradients' gradients
  ) =>
  HasGradConcurrently device' devices parameters losses gradients
  where
  gradConcurrently parameters losses =
    fmap
      (reduceGradients @device' @devices)
      (hzipWithM GradConcurrentlyF parameters losses)

-- | Reduce one gradient list per source device into a single gradient list
-- placed on the target device @device'@.
--
-- The @i@-th entry of @devices@ is the device holding the @i@-th gradient list
-- in @xxs@. The result type @ys@ is determined by the other parameters.
class
  ReduceGradients (device' :: (D.DeviceType, Nat)) (devices :: [(D.DeviceType, Nat)]) xxs ys
    | device' devices xxs -> ys
  where
  reduceGradients :: HList xxs -> HList ys

-- | Base case: exactly one device. Its gradient list is moved to @device'@.
instance
  {-# OVERLAPS #-}
  ( HasToDevice device' device (HList xs) (HList ys)
  ) =>
  ReduceGradients device' (device ': '[]) ((HList (xs :: [Type])) ': '[]) ys
  where
  reduceGradients (onlyGradients :. HNil) =
    Torch.Typed.Device.toDevice @device' @device onlyGradients

-- | Tag handed to 'hzipWith' to add two gradient lists element by element.
data SumF
  = SumF

-- | Add the two components of a pair. This is used by 'hzipWith' to sum
-- gradients coming from different devices.
instance Num y => Apply' SumF (y, y) y where
  -- Add both components explicitly. The previous definition, @apply' _ = sum@,
  -- went through the 'Foldable' instance for @(,) a@, which folds only the
  -- second component. So @sum (a, b) == b@, and the first gradient was
  -- silently dropped.
  apply' _ (a, b) = a + b

-- | Inductive case: more than one device.
--
-- The first device's gradient list is moved to @device'@. It is then combined
-- element-wise, via 'SumF', with the reduction of the remaining devices.
-- When only one device is left, the overlap pragmas make the more specific
-- single-device instance take precedence. @1 <= ListLength xxs@ additionally
-- requires at least one remaining gradient list.
instance
  {-# OVERLAPPABLE #-}
  ( HasToDevice device' device (HList xs) (HList ys),
    ReduceGradients device' devices xxs ys,
    HZipWith SumF ys ys ys,
    1 <= ListLength xxs
  ) =>
  ReduceGradients device' (device ': devices) ((HList (xs :: [Type])) ': xxs) ys
  where
  reduceGradients (headGradients :. tailGradients) =
    let moved = Torch.Typed.Device.toDevice @device' @device headGradients
        reducedTail = reduceGradients @device' @devices @xxs tailGradients
     in hzipWith SumF moved reducedTail