| Safe Haskell | Safe-Inferred |
|--------------|---------------|
| Language     | Haskell2010   |
Documentation
data ForwardConcurrentlyF Source #
Instances
HasForward model input output => Apply' ForwardConcurrentlyF (model, input) (Concurrently output) Source #
Defined in Torch.Typed.NN.DataParallel
  apply' :: ForwardConcurrentlyF -> (model, input) -> Concurrently output Source #
forwardConcurrently' :: forall devices' device' device model input output models inputs outputs. ('Just device ~ GetDevice model, 'Just device ~ GetDevice input, HasScatter devices' device input inputs, HasReplicate devices' device model models, HZipWithM Concurrently ForwardConcurrentlyF models inputs outputs, HasGather device' devices' outputs output) => model -> input -> IO output Source #
forwardConcurrentlyStoch' :: forall devices' device' device model input output models inputs outputs. ('Just device ~ GetDevice model, 'Just device ~ GetDevice input, HasScatter devices' device input inputs, HasReplicate devices' device model models, HZipWithM Concurrently ForwardConcurrentlyF models inputs outputs, HasGather device' devices' outputs output) => model -> input -> IO output Source #
forwardConcurrently :: forall models inputs outputs. HZipWithM Concurrently ForwardConcurrentlyF models inputs outputs => HList models -> HList inputs -> Concurrently (HList outputs) Source #
forwardConcurrentlyStoch :: forall models inputs outputs. HZipWithM Concurrently ForwardConcurrentlyF models inputs outputs => HList models -> HList inputs -> Concurrently (HList outputs) Source #
class HasGradConcurrently device' devices parameters losses gradients | device' devices parameters losses -> gradients where Source #
gradConcurrently :: HList parameters -> HList losses -> Concurrently (HList gradients) Source #
Instances
(HZipWithM Concurrently GradConcurrentlyF parameters losses gradients', ReduceGradients device' devices gradients' gradients) => HasGradConcurrently (device' :: (DeviceType, Nat)) (devices :: [(DeviceType, Nat)]) (parameters :: [k1]) (losses :: [k1]) (gradients :: [k2]) Source #
Defined in Torch.Typed.NN.DataParallel
  gradConcurrently :: HList parameters -> HList losses -> Concurrently (HList gradients) Source #
data GradConcurrentlyF Source #
Instances
(HasGrad (HList parameters) (HList gradients), Castable (HList gradients) [ATenTensor]) => Apply' GradConcurrentlyF (HList parameters, Loss device dtype) (Concurrently (HList gradients)) Source #
Defined in Torch.Typed.NN.DataParallel
  apply' :: GradConcurrentlyF -> (HList parameters, Loss device dtype) -> Concurrently (HList gradients) Source #
class ReduceGradients (device' :: (DeviceType, Nat)) (devices :: [(DeviceType, Nat)]) xxs ys | device' devices xxs -> ys where Source #
reduceGradients :: HList xxs -> HList ys Source #
Instances
HasToDevice device' device (HList xs) (HList ys) => ReduceGradients device' '[device] ('[HList xs] :: [Type]) (ys :: [k]) Source #
Defined in Torch.Typed.NN.DataParallel
(HasToDevice device' device (HList xs) (HList ys), ReduceGradients device' devices xxs ys, HZipWith SumF ys ys ys, 1 <= ListLength xxs) => ReduceGradients device' (device ': devices) (HList xs ': xxs :: [Type]) (ys :: [k]) Source #
Defined in Torch.Typed.NN.DataParallel