{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE PartialTypeSignatures #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeInType #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE OverloadedRecordDot #-}
{-# LANGUAGE DuplicateRecordFields #-}
{-# OPTIONS_GHC -Wno-partial-type-signatures #-}

module Torch.Typed.NN.Convolution where

import Data.Proxy
import GHC.Generics
import GHC.TypeLits
import qualified Torch.DType as D
import qualified Torch.Device as D
import Torch.NN (HasForward (..), Randomizable (..))
import Torch.Typed.Auxiliary
import Torch.Typed.Factories
import Torch.Typed.Functional
import Torch.Typed.Parameter
import Torch.Typed.Tensor

-- | Type-level specification of a 'Conv1d' layer, consumed by the
-- 'Randomizable' instance to sample a fresh layer.
--
-- The value carries no runtime data. The input channel count, output channel
-- count, kernel length, element 'D.DType' and device are all type parameters.
data
  Conv1dSpec
    (inputChannelSize :: Nat)
    (outputChannelSize :: Nat)
    (kernelSize :: Nat)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  = Conv1dSpec
  deriving (Show, Eq)

-- | A one-dimensional convolution layer: a kernel plus one bias entry per
-- output channel, both stored as 'Parameter's.
--
-- 'Parameterized' is derived (through 'Generic' and @DeriveAnyClass@). That
-- derivation supplies 'flattenParameters' and 'replaceParameters' for this
-- layer.
data
  Conv1d
    (inputChannelSize :: Nat)
    (outputChannelSize :: Nat)
    (kernelSize :: Nat)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  where
  Conv1d ::
    forall inputChannelSize outputChannelSize kernelSize dtype device.
    { -- Kernel, shape [outputChannelSize, inputChannelSize, kernelSize].
      weight :: Parameter device dtype '[outputChannelSize, inputChannelSize, kernelSize],
      -- Bias, shape [outputChannelSize].
      bias :: Parameter device dtype '[outputChannelSize]
    } ->
    Conv1d
      inputChannelSize
      outputChannelSize
      kernelSize
      dtype
      device
  deriving (Show, Generic, Parameterized)

-- | Apply a 'Conv1d' layer to a batch of inputs.
--
-- @stride@ and @padding@ are type-level naturals. They must be supplied with
-- type applications and are forwarded unchanged to 'conv1d', together with
-- the layer's weight and bias. The output length is
-- @div (inputSize + 2 * padding - kernelSize) stride + 1@.
--
-- The constraints inherited from 'conv1d' (shape arithmetic plus the
-- 'ConvSideCheck' obligations) are very involved. The signature is therefore
-- partial and GHC infers them.
conv1dForward ::
  forall stride padding.
  _ =>
  Conv1d _ _ _ _ _ ->
  Tensor _ _ _ ->
  Tensor _ _ _
conv1dForward Conv1d {weight = w, bias = b} input =
  conv1d @stride @padding (toDependent w) (toDependent b) input

-- | Forward pass of a 'Conv1d' layer.
--
-- The input is a tuple of three parts:
--
-- * the @[batchSize, inputChannelSize, inputSize]@ tensor;
-- * a 'Proxy' that only fixes the @stride@ type parameter;
-- * a 'Proxy' that only fixes the @padding@ type parameter.
--
-- 'ConvSideCheck' is the constraint relating @inputSize@, @kernelSize@,
-- @stride@ and @padding@ to @outputSize@. 'forwardStoch' simply wraps
-- 'forward' in 'pure'.
instance
  ( All
      KnownNat
      '[ stride,
         padding,
         inputChannelSize,
         outputChannelSize,
         kernelSize,
         inputSize,
         batchSize,
         outputSize
       ],
    ConvSideCheck inputSize kernelSize stride padding outputSize
  ) =>
  HasForward
    (Conv1d inputChannelSize outputChannelSize kernelSize dtype device)
    (Tensor device dtype '[batchSize, inputChannelSize, inputSize], Proxy stride, Proxy padding)
    (Tensor device dtype '[batchSize, outputChannelSize, outputSize])
  where
  forward model (input, Proxy, Proxy) = conv1dForward @stride @padding model input
  forwardStoch model args = pure (forward model args)

-- | Sample a new 'Conv1d' layer from its spec.
--
-- The weight and then the bias are each created with 'randn' and wrapped
-- into a 'Parameter' by 'makeIndependent'. Their shapes come entirely from
-- the spec's type parameters.
instance
  ( KnownNat inputChannelSize,
    KnownNat outputChannelSize,
    KnownNat kernelSize,
    KnownDType dtype,
    KnownDevice device,
    RandDTypeIsValid device dtype
  ) =>
  Randomizable
    (Conv1dSpec inputChannelSize outputChannelSize kernelSize dtype device)
    (Conv1d inputChannelSize outputChannelSize kernelSize dtype device)
  where
  sample Conv1dSpec =
    Conv1d
      <$> (makeIndependent =<< randn)
      <*> (makeIndependent =<< randn)

-- | Type-level specification of a 'Conv2d' layer, consumed by 'sample' to
-- create a fresh layer.
--
-- The value carries no runtime data. The input channel count, output channel
-- count, the two kernel extents (@kernelSize0@, @kernelSize1@), the element
-- 'D.DType' and the device are all type parameters.
data
  Conv2dSpec
    (inputChannelSize :: Nat)
    (outputChannelSize :: Nat)
    (kernelSize0 :: Nat)
    (kernelSize1 :: Nat)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  = Conv2dSpec
  deriving (Show, Eq)

-- | A two-dimensional convolution layer: a kernel plus one bias entry per
-- output channel, both stored as 'Parameter's.
--
-- 'Parameterized' is derived (through 'Generic' and @DeriveAnyClass@). That
-- derivation supplies 'flattenParameters' and 'replaceParameters' for this
-- layer.
data
  Conv2d
    (inputChannelSize :: Nat)
    (outputChannelSize :: Nat)
    (kernelSize0 :: Nat)
    (kernelSize1 :: Nat)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  where
  Conv2d ::
    forall inputChannelSize outputChannelSize kernelSize0 kernelSize1 dtype device.
    { -- Kernel, shape [outputChannelSize, inputChannelSize, kernelSize0, kernelSize1].
      weight :: Parameter device dtype '[outputChannelSize, inputChannelSize, kernelSize0, kernelSize1],
      -- Bias, shape [outputChannelSize].
      bias :: Parameter device dtype '[outputChannelSize]
    } ->
    Conv2d
      inputChannelSize
      outputChannelSize
      kernelSize0
      kernelSize1
      dtype
      device
  deriving (Show, Generic, Parameterized)

-- | Apply a 'Conv2d' layer to a batch of inputs.
--
-- @stride@ and @padding@ are type-level pairs of naturals. The first
-- component applies to the first spatial dimension and the second to the
-- second. They must be supplied with type applications and are forwarded
-- unchanged to 'conv2d', together with the layer's weight and bias.
--
-- For spatial dimension @i@ the output extent is
-- @div (inputSize_i + 2 * padding_i - kernelSize_i) stride_i + 1@.
--
-- The constraints inherited from 'conv2d' (shape arithmetic plus two
-- 'ConvSideCheck' obligations) are very involved. The signature is therefore
-- partial and GHC infers them.
conv2dForward ::
  forall stride padding.
  _ =>
  Conv2d _ _ _ _ _ _ ->
  Tensor _ _ _ ->
  Tensor _ _ _
conv2dForward Conv2d {weight = w, bias = b} input =
  conv2d @stride @padding (toDependent w) (toDependent b) input

-- | Run a 'Conv2d' layer. The stride and padding are selected through the two
-- 'Proxy' components of the input tuple; 'ConvSideCheck' ties the output
-- spatial sizes to the input size, kernel size, stride and padding.
instance
  ( All
      KnownNat
      '[ Torch.Typed.Auxiliary.Fst stride,
         Torch.Typed.Auxiliary.Snd stride,
         Torch.Typed.Auxiliary.Fst padding,
         Torch.Typed.Auxiliary.Snd padding,
         inputChannelSize,
         outputChannelSize,
         kernelSize0,
         kernelSize1,
         inputSize0,
         inputSize1,
         batchSize,
         outputSize0,
         outputSize1
       ],
    ConvSideCheck inputSize0 kernelSize0 (Torch.Typed.Auxiliary.Fst stride) (Torch.Typed.Auxiliary.Fst padding) outputSize0,
    ConvSideCheck inputSize1 kernelSize1 (Torch.Typed.Auxiliary.Snd stride) (Torch.Typed.Auxiliary.Snd padding) outputSize1
  ) =>
  HasForward (Conv2d inputChannelSize outputChannelSize kernelSize0 kernelSize1 dtype device) (Tensor device dtype '[batchSize, inputChannelSize, inputSize0, inputSize1], Proxy stride, Proxy padding) (Tensor device dtype '[batchSize, outputChannelSize, outputSize0, outputSize1])
  where
  -- The proxies carry no runtime data; they only fix the type-level
  -- @stride@ and @padding@ that are forwarded by type application.
  forward model (input, Proxy, Proxy) =
    conv2dForward @stride @padding model input

  -- The layer has no stochastic behaviour here: lift the pure pass into 'IO'.
  forwardStoch model = pure . forward model

-- | Sample a fresh 'Conv2d' layer from its (data-less) 'Conv2dSpec'.
--
-- Both the weight and the bias are drawn with 'randn' and wrapped as
-- 'Parameter's via 'makeIndependent'; their shapes come from the types.
instance
  ( KnownNat inputChannelSize,
    KnownNat outputChannelSize,
    KnownNat kernelSize0,
    KnownNat kernelSize1,
    KnownDType dtype,
    KnownDevice device,
    RandDTypeIsValid device dtype
  ) =>
  Randomizable
    (Conv2dSpec inputChannelSize outputChannelSize kernelSize0 kernelSize1 dtype device)
    (Conv2d inputChannelSize outputChannelSize kernelSize0 kernelSize1 dtype device)
  where
  sample Conv2dSpec =
    Conv2d
      <$> (makeIndependent =<< randn) -- weight
      <*> (makeIndependent =<< randn) -- bias

-- | Specification used to sample a 'Conv3d' layer.
--
-- The constructor carries no runtime data: channel counts, kernel sizes,
-- dtype and device are all fixed by the type parameters.
data
  Conv3dSpec
    (inputChannelSize :: Nat)
    (outputChannelSize :: Nat)
    (kernelSize0 :: Nat)
    (kernelSize1 :: Nat)
    (kernelSize2 :: Nat)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  = Conv3dSpec
  deriving (Show, Eq)

-- | A 3-D convolution layer: a kernel and a per-output-channel bias, both
-- held as 'Parameter's whose shapes are tracked at the type level.
data
  Conv3d
    (inputChannelSize :: Nat)
    (outputChannelSize :: Nat)
    (kernelSize0 :: Nat)
    (kernelSize1 :: Nat)
    (kernelSize2 :: Nat)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  where
  Conv3d ::
    forall inputChannelSize outputChannelSize kernelSize0 kernelSize1 kernelSize2 dtype device.
    { -- | Convolution kernel of shape
      -- @[outputChannelSize, inputChannelSize, kernelSize0, kernelSize1, kernelSize2]@.
      weight :: Parameter device dtype '[outputChannelSize, inputChannelSize, kernelSize0, kernelSize1, kernelSize2],
      -- | Bias, one entry per output channel.
      bias :: Parameter device dtype '[outputChannelSize]
    } ->
    Conv3d inputChannelSize outputChannelSize kernelSize0 kernelSize1 kernelSize2 dtype device
  deriving (Show, Generic, Parameterized)

-- | conv3d: apply a 'Conv3d' layer to a batch of 3-D inputs.
--
-- @stride@ and @padding@ are type-level @(Nat, Nat, Nat)@ triples that do not
-- occur in the argument types, so callers must supply them by type
-- application, e.g. @conv3dForward \@'(1, 1, 1) \@'(0, 0, 0) model input@.
--
-- For each spatial dimension the inferred output size is
-- @div (inputSize + 2 * padding - kernelSize) stride + 1@.
--
-- The constraints on this one are _very_ involved, so the partial signature
-- lets GHC infer them, which makes the code significantly cleaner.
conv3dForward ::
  forall stride padding.
  _ =>
  Conv3d _ _ _ _ _ _ _ ->
  Tensor _ _ _ ->
  Tensor _ _ _
conv3dForward Conv3d {weight = w, bias = b} input =
  -- Unwrap the parameters to plain tensors before calling the functional op.
  conv3d @stride @padding (toDependent w) (toDependent b) input

-- | Run a 'Conv3d' layer. The stride and padding are selected through the two
-- 'Proxy' components of the input tuple; 'ConvSideCheck' ties each output
-- spatial size to the input size, kernel size, stride and padding.
instance
  ( All
      KnownNat
      '[ Fst3 stride,
         Snd3 stride,
         Trd3 stride,
         Fst3 padding,
         Snd3 padding,
         Trd3 padding,
         inputChannelSize,
         outputChannelSize,
         kernelSize0,
         kernelSize1,
         kernelSize2,
         inputSize0,
         inputSize1,
         inputSize2,
         batchSize
       ],
    ConvSideCheck inputSize0 kernelSize0 (Fst3 stride) (Fst3 padding) outputSize0,
    ConvSideCheck inputSize1 kernelSize1 (Snd3 stride) (Snd3 padding) outputSize1,
    ConvSideCheck inputSize2 kernelSize2 (Trd3 stride) (Trd3 padding) outputSize2
  ) =>
  HasForward (Conv3d inputChannelSize outputChannelSize kernelSize0 kernelSize1 kernelSize2 dtype device) (Tensor device dtype '[batchSize, inputChannelSize, inputSize0, inputSize1, inputSize2], Proxy stride, Proxy padding) (Tensor device dtype '[batchSize, outputChannelSize, outputSize0, outputSize1, outputSize2])
  where
  -- The proxies carry no runtime data; they only fix the type-level
  -- @stride@ and @padding@ that are forwarded by type application.
  forward model (input, Proxy, Proxy) =
    conv3dForward @stride @padding model input

  -- The layer has no stochastic behaviour here: lift the pure pass into 'IO'.
  forwardStoch model = pure . forward model

-- | Sample a fresh 'Conv3d' layer from its (data-less) 'Conv3dSpec'.
--
-- Both the weight and the bias are drawn with 'randn' and wrapped as
-- 'Parameter's via 'makeIndependent'; their shapes come from the types.
instance
  ( KnownNat inputChannelSize,
    KnownNat outputChannelSize,
    KnownNat kernelSize0,
    KnownNat kernelSize1,
    KnownNat kernelSize2,
    KnownDType dtype,
    KnownDevice device,
    RandDTypeIsValid device dtype
  ) =>
  Randomizable
    (Conv3dSpec inputChannelSize outputChannelSize kernelSize0 kernelSize1 kernelSize2 dtype device)
    (Conv3d inputChannelSize outputChannelSize kernelSize0 kernelSize1 kernelSize2 dtype device)
  where
  sample Conv3dSpec =
    Conv3d
      <$> (makeIndependent =<< randn) -- weight
      <*> (makeIndependent =<< randn) -- bias

-- | Specification for a 1-D transposed convolution layer.
--
-- The spec has a single nullary constructor and carries no runtime data:
-- the channel counts, kernel size, dtype and device are all fixed by the
-- type parameters. It is the input to 'sample' when building a
-- 'ConvTranspose1d'.
data
  ConvTranspose1dSpec
    (inputChannelSize :: Nat)
    (outputChannelSize :: Nat)
    (kernelSize :: Nat)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  = ConvTranspose1dSpec
  deriving (Show, Eq)

-- | A 1-D transposed convolution layer with a learnable weight and bias.
--
-- * @weight@ has shape @[inputChannelSize, outputChannelSize, kernelSize]@
--   (input channels first), matching the weight argument of
--   'convTranspose1d'.
-- * @bias@ has one entry per output channel.
--
-- 'Parameterized' is derived through 'Generic' (DeriveAnyClass), so the two
-- parameters are exposed in field order: weight, then bias.
data
  ConvTranspose1d
    (inputChannelSize :: Nat)
    (outputChannelSize :: Nat)
    (kernelSize :: Nat)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  where
  ConvTranspose1d ::
    forall inputChannelSize outputChannelSize kernelSize dtype device.
    { weight :: Parameter device dtype '[inputChannelSize, outputChannelSize, kernelSize],
      bias :: Parameter device dtype '[outputChannelSize]
    } ->
    ConvTranspose1d
      inputChannelSize
      outputChannelSize
      kernelSize
      dtype
      device
  deriving (Show, Generic, Parameterized)

-- | convTranspose1d
-- Run a 'ConvTranspose1d' layer on a batch of inputs.
--
-- The caller picks @stride@ and @padding@ with type applications, e.g.
-- @convTranspose1dForward \@1 \@0 layer input@. The layer's weight and bias
-- are unwrapped with 'toDependent' and passed to 'convTranspose1d'.
--
-- The constraints on this one are _very_ involved, so the partial signatures
-- make the code significantly cleaner.
--
-- NOTE(review): the inferred output length is
-- @Div ((inputSize + 2 * padding) - kernelSize) stride + 1@ (through
-- 'ConvSideCheck'). That is the forward-convolution formula, not the usual
-- transposed-convolution one. Confirm against 'convTranspose1d' before
-- relying on the output shape.
convTranspose1dForward ::
  forall stride padding.
  _ =>
  ConvTranspose1d _ _ _ _ _ ->
  Tensor _ _ _ ->
  Tensor _ _ _
convTranspose1dForward ConvTranspose1d {..} input =
  convTranspose1d @stride @padding
    (toDependent weight)
    (toDependent bias)
    input

-- | Forward pass for 'ConvTranspose1d'.
--
-- The input is a triple: the batch tensor plus two proxies that fix the
-- @stride@ and @padding@ type parameters. The proxies carry no data; they
-- only select the type arguments for 'convTranspose1dForward'.
instance
  ( All
      KnownNat
      '[ stride,
         padding,
         inputChannelSize,
         outputChannelSize,
         kernelSize,
         inputSize,
         batchSize,
         outputSize
       ],
    ConvSideCheck inputSize kernelSize stride padding outputSize
  ) =>
  HasForward
    (ConvTranspose1d inputChannelSize outputChannelSize kernelSize dtype device)
    (Tensor device dtype '[batchSize, inputChannelSize, inputSize], Proxy stride, Proxy padding)
    (Tensor device dtype '[batchSize, outputChannelSize, outputSize])
  where
  forward model (input, Proxy, Proxy) =
    convTranspose1dForward @stride @padding model input

  -- No randomness is involved, so the stochastic pass is just 'forward'
  -- lifted into 'IO'.
  forwardStoch = (pure .) . forward

-- | Build a 'ConvTranspose1d' from its spec with random parameters.
--
-- Both the weight and the bias are drawn with 'randn' and wrapped as
-- independent parameters with 'makeIndependent'.
--
-- NOTE(review): this is plain 'randn' initialisation, with no fan-in
-- scaling. Confirm that is intended.
instance
  ( KnownNat inputChannelSize,
    KnownNat outputChannelSize,
    KnownNat kernelSize,
    KnownDType dtype,
    KnownDevice device,
    RandDTypeIsValid device dtype
  ) =>
  Randomizable
    (ConvTranspose1dSpec inputChannelSize outputChannelSize kernelSize dtype device)
    (ConvTranspose1d inputChannelSize outputChannelSize kernelSize dtype device)
  where
  sample ConvTranspose1dSpec =
    ConvTranspose1d
      <$> (makeIndependent =<< randn)
      <*> (makeIndependent =<< randn)

-- | Specification for a 2-D transposed convolution layer.
--
-- The spec has a single nullary constructor and carries no runtime data:
-- the channel counts, the two kernel extents, dtype and device are all
-- fixed by the type parameters. It is the input to 'sample' when building a
-- 'ConvTranspose2d'.
data
  ConvTranspose2dSpec
    (inputChannelSize :: Nat)
    (outputChannelSize :: Nat)
    (kernelSize0 :: Nat)
    (kernelSize1 :: Nat)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  = ConvTranspose2dSpec
  deriving (Show, Eq)

-- | A 2-D transposed convolution layer with a learnable weight and bias.
--
-- * @weight@ has shape
--   @[inputChannelSize, outputChannelSize, kernelSize0, kernelSize1]@
--   (input channels first), matching the weight argument of
--   'convTranspose2d'.
-- * @bias@ has one entry per output channel.
--
-- 'Parameterized' is derived through 'Generic' (DeriveAnyClass), so the two
-- parameters are exposed in field order: weight, then bias.
data
  ConvTranspose2d
    (inputChannelSize :: Nat)
    (outputChannelSize :: Nat)
    (kernelSize0 :: Nat)
    (kernelSize1 :: Nat)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  where
  ConvTranspose2d ::
    forall inputChannelSize outputChannelSize kernelSize0 kernelSize1 dtype device.
    { weight :: Parameter device dtype '[inputChannelSize, outputChannelSize, kernelSize0, kernelSize1],
      bias :: Parameter device dtype '[outputChannelSize]
    } ->
    ConvTranspose2d
      inputChannelSize
      outputChannelSize
      kernelSize0
      kernelSize1
      dtype
      device
  deriving (Show, Generic, Parameterized)

-- | Apply a 'ConvTranspose2d' layer to a batch of 2-D feature maps.
--
-- @stride@ and @padding@ are type-level pairs @(Nat, Nat)@ and are not
-- inferable from the arguments, so they must be given with type
-- applications, e.g. @convTranspose2dForward \@'(1, 1) \@'(0, 0) layer input@.
-- The layer's weight and bias parameters are turned into plain tensors with
-- 'toDependent' and handed to 'convTranspose2d' together with the input.
--
-- The constraints on this one are /very/ involved (they are exactly those
-- of 'convTranspose2d'), so the partial signature lets GHC infer them and
-- keeps the code significantly cleaner.
--
-- NOTE(review): the inferred output spatial size per side is governed by
-- 'ConvSideCheck', i.e. @(inputSize + 2 * padding - kernelSize) \`div\` stride + 1@,
-- which is the ordinary convolution formula rather than the usual
-- transposed-convolution size @(inputSize - 1) * stride - 2 * padding + kernelSize@.
-- Confirm against "Torch.Typed.Functional".
convTranspose2dForward ::
  forall stride padding.
  _ =>
  ConvTranspose2d _ _ _ _ _ _ ->
  Tensor _ _ _ ->
  Tensor _ _ _
convTranspose2dForward ConvTranspose2d {..} input =
  convTranspose2d @stride @padding (toDependent weight) (toDependent bias) input

-- | Forward pass for 'ConvTranspose2d'.
--
-- The input is a triple of the batch tensor and two 'Proxy' values. The
-- proxies carry no runtime data; they only fix the type-level @stride@ and
-- @padding@ pairs that are forwarded to 'convTranspose2dForward'.
instance
  ( All
      KnownNat
      '[ Torch.Typed.Auxiliary.Fst stride,
         Torch.Typed.Auxiliary.Snd stride,
         Torch.Typed.Auxiliary.Fst padding,
         Torch.Typed.Auxiliary.Snd padding,
         inputChannelSize,
         outputChannelSize,
         kernelSize0,
         kernelSize1,
         inputSize0,
         inputSize1,
         batchSize,
         outputSize0,
         outputSize1
       ],
    ConvSideCheck inputSize0 kernelSize0 (Torch.Typed.Auxiliary.Fst stride) (Torch.Typed.Auxiliary.Fst padding) outputSize0,
    ConvSideCheck inputSize1 kernelSize1 (Torch.Typed.Auxiliary.Snd stride) (Torch.Typed.Auxiliary.Snd padding) outputSize1
  ) =>
  HasForward (ConvTranspose2d inputChannelSize outputChannelSize kernelSize0 kernelSize1 dtype device) (Tensor device dtype '[batchSize, inputChannelSize, inputSize0, inputSize1], Proxy stride, Proxy padding) (Tensor device dtype '[batchSize, outputChannelSize, outputSize0, outputSize1])
  where
  -- Both proxies are matched only to fix @stride@ / @padding@ (scoped from
  -- the instance head) for the type applications below.
  forward model (input, Proxy, Proxy) =
    convTranspose2dForward @stride @padding model input

  -- No randomness is involved, so the stochastic variant just lifts
  -- 'forward' into 'IO'.
  forwardStoch = (pure .) . forward

-- | Sample a fresh 'ConvTranspose2d' layer from its data-free spec.
--
-- The weight and the bias are each drawn with 'randn' and wrapped as
-- 'Parameter's with 'makeIndependent'. The weight is sampled first, then the bias.
--
-- NOTE(review): no fan-in scaling is applied to the 'randn' draw; PyTorch's
-- default ConvTranspose initialisation is Kaiming-uniform. Confirm this is
-- intended.
instance
  ( KnownNat inputChannelSize,
    KnownNat outputChannelSize,
    KnownNat kernelSize0,
    KnownNat kernelSize1,
    KnownDType dtype,
    KnownDevice device,
    RandDTypeIsValid device dtype
  ) =>
  Randomizable
    (ConvTranspose2dSpec inputChannelSize outputChannelSize kernelSize0 kernelSize1 dtype device)
    (ConvTranspose2d inputChannelSize outputChannelSize kernelSize0 kernelSize1 dtype device)
  where
  sample ConvTranspose2dSpec =
    ConvTranspose2d
      <$> (makeIndependent =<< randn)
      <*> (makeIndependent =<< randn)

-- | Type-level specification of a 'ConvTranspose3d' layer.
--
-- The single constructor is nullary. The channel counts, the three kernel
-- extents, the dtype and the device are all carried by the type parameters.
-- The 'Randomizable' instance uses the spec to sample a concrete layer.
data
  ConvTranspose3dSpec
    (inputChannelSize :: Nat)
    (outputChannelSize :: Nat)
    (kernelSize0 :: Nat)
    (kernelSize1 :: Nat)
    (kernelSize2 :: Nat)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  = ConvTranspose3dSpec
  deriving (Show, Eq)

-- | A 3-D transposed-convolution layer: a weight 'Parameter' and a bias
-- 'Parameter' whose shapes are fixed by the type parameters.
--
-- 'Parameterized' is derived via @DeriveAnyClass@. It presumably uses the
-- 'Generic'-based default, which would expose the parameters in field order
-- (weight, then bias). Confirm before relying on that order.
data
  ConvTranspose3d
    (inputChannelSize :: Nat)
    (outputChannelSize :: Nat)
    (kernelSize0 :: Nat)
    (kernelSize1 :: Nat)
    (kernelSize2 :: Nat)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  where
  ConvTranspose3d ::
    forall inputChannelSize outputChannelSize kernelSize0 kernelSize1 kernelSize2 dtype device.
    { -- | Kernel of shape @[inputChannelSize, outputChannelSize, kernelSize0, kernelSize1, kernelSize2]@.
      weight :: Parameter device dtype '[inputChannelSize, outputChannelSize, kernelSize0, kernelSize1, kernelSize2],
      -- | One bias entry per output channel.
      bias :: Parameter device dtype '[outputChannelSize]
    } ->
    ConvTranspose3d inputChannelSize outputChannelSize kernelSize0 kernelSize1 kernelSize2 dtype device
  deriving (Show, Generic, Parameterized)

-- | Apply a 'ConvTranspose3d' layer to a batch of 3-D feature volumes.
--
-- @stride@ and @padding@ are type-level triples @(Nat, Nat, Nat)@ and are
-- not inferable from the arguments, so they must be given with type
-- applications, e.g. @convTranspose3dForward \@'(1, 1, 1) \@'(0, 0, 0) layer input@.
-- The layer's weight and bias parameters are turned into plain tensors with
-- 'toDependent' and handed to 'convTranspose3d' together with the input.
--
-- The constraints on this one are /very/ involved (they are exactly those
-- of 'convTranspose3d'), so the partial signature lets GHC infer them and
-- keeps the code significantly cleaner.
--
-- NOTE(review): the inferred output spatial size per side is governed by
-- 'ConvSideCheck', i.e. @(inputSize + 2 * padding - kernelSize) \`div\` stride + 1@,
-- which is the ordinary convolution formula rather than the usual
-- transposed-convolution size @(inputSize - 1) * stride - 2 * padding + kernelSize@.
-- Confirm against "Torch.Typed.Functional".
convTranspose3dForward ::
  forall stride padding.
  _ =>
  ConvTranspose3d _ _ _ _ _ _ _ ->
  Tensor _ _ _ ->
  Tensor _ _ _
convTranspose3dForward ConvTranspose3d {..} input =
  convTranspose3d @stride @padding (toDependent weight) (toDependent bias) input

-- | Forward pass for 'ConvTranspose3d'.
--
-- The input is a triple of the batch tensor and two 'Proxy' values. The
-- proxies carry no runtime data; they only fix the type-level @stride@ and
-- @padding@ triples that are forwarded to 'convTranspose3dForward'.
instance
  ( All
      KnownNat
      '[ Fst3 stride,
         Snd3 stride,
         Trd3 stride,
         Fst3 padding,
         Snd3 padding,
         Trd3 padding,
         inputChannelSize,
         outputChannelSize,
         kernelSize0,
         kernelSize1,
         kernelSize2,
         inputSize0,
         inputSize1,
         inputSize2,
         batchSize
       ],
    ConvSideCheck inputSize0 kernelSize0 (Fst3 stride) (Fst3 padding) outputSize0,
    ConvSideCheck inputSize1 kernelSize1 (Snd3 stride) (Snd3 padding) outputSize1,
    ConvSideCheck inputSize2 kernelSize2 (Trd3 stride) (Trd3 padding) outputSize2
  ) =>
  HasForward (ConvTranspose3d inputChannelSize outputChannelSize kernelSize0 kernelSize1 kernelSize2 dtype device) (Tensor device dtype '[batchSize, inputChannelSize, inputSize0, inputSize1, inputSize2], Proxy stride, Proxy padding) (Tensor device dtype '[batchSize, outputChannelSize, outputSize0, outputSize1, outputSize2])
  where
  -- Both proxies are matched only to fix @stride@ / @padding@ (scoped from
  -- the instance head) for the type applications below.
  forward model (input, Proxy, Proxy) =
    convTranspose3dForward @stride @padding model input

  -- No randomness is involved, so the stochastic variant just lifts
  -- 'forward' into 'IO'.
  forwardStoch = (pure .) . forward

-- | Sample a fresh 'ConvTranspose3d' layer from its data-free spec.
--
-- The weight and the bias are each drawn with 'randn' and wrapped as
-- 'Parameter's with 'makeIndependent'. The weight is sampled first, then the bias.
--
-- NOTE(review): no fan-in scaling is applied to the 'randn' draw; PyTorch's
-- default ConvTranspose initialisation is Kaiming-uniform. Confirm this is
-- intended.
instance
  ( KnownNat inputChannelSize,
    KnownNat outputChannelSize,
    KnownNat kernelSize0,
    KnownNat kernelSize1,
    KnownNat kernelSize2,
    KnownDType dtype,
    KnownDevice device,
    RandDTypeIsValid device dtype
  ) =>
  Randomizable
    (ConvTranspose3dSpec inputChannelSize outputChannelSize kernelSize0 kernelSize1 kernelSize2 dtype device)
    (ConvTranspose3d inputChannelSize outputChannelSize kernelSize0 kernelSize1 kernelSize2 dtype device)
  where
  sample ConvTranspose3dSpec =
    ConvTranspose3d
      <$> (makeIndependent =<< randn)
      <*> (makeIndependent =<< randn)