{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE PartialTypeSignatures #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeInType #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE OverloadedRecordDot #-}
{-# LANGUAGE DuplicateRecordFields #-}
{-# OPTIONS_GHC -Wno-partial-type-signatures #-}
module Torch.Typed.NN.Convolution where
import Data.Proxy
import GHC.Generics
import GHC.TypeLits
import qualified Torch.DType as D
import qualified Torch.Device as D
import Torch.NN (HasForward (..), Randomizable (..))
import Torch.Typed.Auxiliary
import Torch.Typed.Factories
import Torch.Typed.Functional
import Torch.Typed.Parameter
import Torch.Typed.Tensor
-- | Type-level specification of a 1-D convolution layer.
--
-- Every hyper-parameter (input/output channel counts, kernel size, dtype and
-- device) is a type index, so the single constructor carries no runtime
-- data. A 'Conv1d' with matching indices is obtained from it via the
-- 'Randomizable' instance ('sample').
data
  Conv1dSpec
    (inputChannelSize :: Nat)
    (outputChannelSize :: Nat)
    (kernelSize :: Nat)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  = Conv1dSpec
  deriving (Show, Eq)
-- | A 1-D convolution layer with learnable parameters.
--
-- * @weight@: shape @[outputChannelSize, inputChannelSize, kernelSize]@.
-- * @bias@: shape @[outputChannelSize]@, i.e. one entry per output channel.
--
-- 'Parameterized' is derived with DeriveAnyClass (relying on the derived
-- 'Generic' instance) and provides 'flattenParameters' /
-- 'replaceParameters' for the two fields.
data
  Conv1d
    (inputChannelSize :: Nat)
    (outputChannelSize :: Nat)
    (kernelSize :: Nat)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  where
  Conv1d ::
    forall inputChannelSize outputChannelSize kernelSize dtype device.
    { weight :: Parameter device dtype '[outputChannelSize, inputChannelSize, kernelSize],
      bias :: Parameter device dtype '[outputChannelSize]
    } ->
    Conv1d
      inputChannelSize
      outputChannelSize
      kernelSize
      dtype
      device
  deriving (Show, Generic, Parameterized)
-- | Apply a 'Conv1d' layer to a batch of inputs.
--
-- @stride@ and @padding@ are type-level 'Nat's supplied by type
-- application, e.g. @conv1dForward \@1 \@0 layer input@. For an input of
-- shape @[batchSize, inputChannelSize, inputSize]@ the result has shape
-- @[batchSize, outputChannelSize, Div (inputSize + 2 * padding - kernelSize) stride + 1]@.
--
-- The constraint context is an extra-constraints wildcard (@_ =>@), so GHC
-- infers the 'KnownNat' and size-ordering constraints demanded by 'conv1d'
-- instead of having them spelled out by hand.
conv1dForward ::
  forall stride padding.
  _ =>
  Conv1d _ _ _ _ _ ->
  Tensor _ _ _ ->
  Tensor _ _ _
conv1dForward Conv1d {..} input =
  conv1d @stride @padding
    (toDependent weight)
    (toDependent bias)
    input
-- | 'forward' for 'Conv1d'.
--
-- The input tensor is paired with @Proxy stride@ and @Proxy padding@ so the
-- convolution hyper-parameters are chosen at the call site.
-- NOTE(review): 'ConvSideCheck' is presumably what relates @outputSize@ to
-- @inputSize@, @kernelSize@, @stride@ and @padding@ — its definition is not
-- visible here.
instance
  ( All
      KnownNat
      '[ stride,
         padding,
         inputChannelSize,
         outputChannelSize,
         kernelSize,
         inputSize,
         batchSize,
         outputSize
       ],
    ConvSideCheck inputSize kernelSize stride padding outputSize
  ) =>
  HasForward
    (Conv1d inputChannelSize outputChannelSize kernelSize dtype device)
    (Tensor device dtype '[batchSize, inputChannelSize, inputSize], Proxy stride, Proxy padding)
    (Tensor device dtype '[batchSize, outputChannelSize, outputSize])
  where
  forward model (input, Proxy, Proxy) = conv1dForward @stride @padding model input

  -- This layer has no stochastic behaviour: the IO variant just lifts the
  -- result of 'forward'.
  forwardStoch model args = pure (forward model args)
-- | Sample a fresh 'Conv1d' from its (field-less) 'Conv1dSpec'.
--
-- @weight@ and @bias@ are each drawn with 'randn' and wrapped as
-- 'Parameter's with 'makeIndependent'; the weight is sampled first.
--
-- NOTE(review): the sampled values are used as-is, with no fan-in scaling —
-- confirm this initialisation is intended.
instance
  ( KnownNat inputChannelSize,
    KnownNat outputChannelSize,
    KnownNat kernelSize,
    KnownDType dtype,
    KnownDevice device,
    RandDTypeIsValid device dtype
  ) =>
  Randomizable
    (Conv1dSpec inputChannelSize outputChannelSize kernelSize dtype device)
    (Conv1d inputChannelSize outputChannelSize kernelSize dtype device)
  where
  sample Conv1dSpec =
    Conv1d
      <$> (makeIndependent =<< randn)
      <*> (makeIndependent =<< randn)
-- | Type-level specification of a 2-D convolution layer.
--
-- Channel counts, the two kernel dimensions (@kernelSize0@, @kernelSize1@),
-- dtype and device are all type indices, so the single constructor carries
-- no runtime data. A 'Conv2d' with matching indices is obtained from it via
-- its 'Randomizable' instance ('sample').
data
  Conv2dSpec
    (inputChannelSize :: Nat)
    (outputChannelSize :: Nat)
    (kernelSize0 :: Nat)
    (kernelSize1 :: Nat)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  = Conv2dSpec
  deriving (Show, Eq)
-- | A 2-D convolution layer with learnable parameters.
--
-- * @weight@: shape
--   @[outputChannelSize, inputChannelSize, kernelSize0, kernelSize1]@.
-- * @bias@: shape @[outputChannelSize]@, i.e. one entry per output channel.
--
-- 'Parameterized' is derived with DeriveAnyClass (relying on the derived
-- 'Generic' instance) and provides 'flattenParameters' /
-- 'replaceParameters' for the two fields.
data
  Conv2d
    (inputChannelSize :: Nat)
    (outputChannelSize :: Nat)
    (kernelSize0 :: Nat)
    (kernelSize1 :: Nat)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  where
  Conv2d ::
    forall inputChannelSize outputChannelSize kernelSize0 kernelSize1 dtype device.
    { weight :: Parameter device dtype '[outputChannelSize, inputChannelSize, kernelSize0, kernelSize1],
      bias :: Parameter device dtype '[outputChannelSize]
    } ->
    Conv2d
      inputChannelSize
      outputChannelSize
      kernelSize0
      kernelSize1
      dtype
      device
  deriving (Show, Generic, Parameterized)
-- | Apply a 'Conv2d' layer to a batch of inputs.
--
-- @stride@ and @padding@ are type-level pairs @(Nat, Nat)@ supplied by type
-- application. The first component ('Fst') applies to dimension 0
-- (@inputSize0@ / @kernelSize0@), the second ('Snd') to dimension 1. For an
-- input of shape @[batchSize, inputChannelSize, inputSize0, inputSize1]@ the
-- result has shape @[batchSize, outputChannelSize, o0, o1]@, where
-- @o0 = Div (inputSize0 + 2 * Fst padding - kernelSize0) (Fst stride) + 1@
-- and @o1@ is defined analogously with 'Snd'.
--
-- The constraint context is an extra-constraints wildcard (@_ =>@), so GHC
-- infers the 'KnownNat' and size-ordering constraints demanded by 'conv2d'.
conv2dForward ::
  forall stride padding.
  _ =>
  Conv2d _ _ _ _ _ _ ->
  Tensor _ _ _ ->
  Tensor _ _ _
conv2dForward Conv2d {..} input =
  conv2d @stride @padding
    (toDependent weight)
    (toDependent bias)
    input
-- | 'forward' for 'Conv2d'.
--
-- The input tensor is paired with @Proxy stride@ and @Proxy padding@. Both
-- are type-level @(Nat, Nat)@ pairs whose components are checked
-- per spatial dimension.
-- NOTE(review): 'ConvSideCheck' is presumably what relates
-- @outputSize0@ / @outputSize1@ to the input, kernel, stride and padding
-- sizes — its definition is not visible here.
instance
  ( All
      KnownNat
      '[ Torch.Typed.Auxiliary.Fst stride,
         Torch.Typed.Auxiliary.Snd stride,
         Torch.Typed.Auxiliary.Fst padding,
         Torch.Typed.Auxiliary.Snd padding,
         inputChannelSize,
         outputChannelSize,
         kernelSize0,
         kernelSize1,
         inputSize0,
         inputSize1,
         batchSize,
         outputSize0,
         outputSize1
       ],
    ConvSideCheck inputSize0 kernelSize0 (Torch.Typed.Auxiliary.Fst stride) (Torch.Typed.Auxiliary.Fst padding) outputSize0,
    ConvSideCheck inputSize1 kernelSize1 (Torch.Typed.Auxiliary.Snd stride) (Torch.Typed.Auxiliary.Snd padding) outputSize1
  ) =>
  HasForward
    (Conv2d inputChannelSize outputChannelSize kernelSize0 kernelSize1 dtype device)
    (Tensor device dtype '[batchSize, inputChannelSize, inputSize0, inputSize1], Proxy stride, Proxy padding)
    (Tensor device dtype '[batchSize, outputChannelSize, outputSize0, outputSize1])
  where
  forward model (input, Proxy, Proxy) = conv2dForward @stride @padding model input

  -- This layer has no stochastic behaviour: the IO variant just lifts the
  -- result of 'forward'.
  forwardStoch model args = pure (forward model args)
-- | Sample a freshly initialised 'Conv2d' layer from its field-less spec.
--
-- The weight (shape @[outputChannelSize, inputChannelSize, kernelSize0, kernelSize1]@)
-- and the bias (shape @[outputChannelSize]@) are each drawn with 'randn' and
-- wrapped as trainable parameters with 'makeIndependent'.
--
-- NOTE(review): the 'randn' samples are used unscaled (no fan-in based
-- scaling) — confirm this is the intended initialisation scheme.
instance
  ( KnownNat inputChannelSize,
    KnownNat outputChannelSize,
    KnownNat kernelSize0,
    KnownNat kernelSize1,
    KnownDType dtype,
    KnownDevice device,
    RandDTypeIsValid device dtype
  ) =>
  Randomizable
    (Conv2dSpec inputChannelSize outputChannelSize kernelSize0 kernelSize1 dtype device)
    (Conv2d inputChannelSize outputChannelSize kernelSize0 kernelSize1 dtype device)
  where
  sample Conv2dSpec =
    Conv2d
      <$> (makeIndependent =<< randn) -- weight
      <*> (makeIndependent =<< randn) -- bias
-- | Specification of a 3-D convolution layer ('Conv3d').
--
-- The constructor carries no runtime data. Channel counts, the three kernel
-- extents, dtype and device all live in the type parameters. A value of this
-- type is what the 'Randomizable' instance consumes to sample a 'Conv3d'.
data
  Conv3dSpec
    (inputChannelSize :: Nat)
    (outputChannelSize :: Nat)
    (kernelSize0 :: Nat)
    (kernelSize1 :: Nat)
    (kernelSize2 :: Nat)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  = Conv3dSpec
  deriving (Show, Eq)
-- | A 3-D convolution layer with a learnable weight and bias.
--
-- All sizes are tracked at the type level. 'Parameterized' is derived via
-- 'Generic', so 'flattenParameters' / 'replaceParameters' see both fields.
data
  Conv3d
    (inputChannelSize :: Nat)
    (outputChannelSize :: Nat)
    (kernelSize0 :: Nat)
    (kernelSize1 :: Nat)
    (kernelSize2 :: Nat)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  where
  Conv3d ::
    forall inputChannelSize outputChannelSize kernelSize0 kernelSize1 kernelSize2 dtype device.
    { -- Convolution kernels, laid out as
      -- [outputChannelSize, inputChannelSize, kernelSize0, kernelSize1, kernelSize2].
      weight :: Parameter device dtype '[outputChannelSize, inputChannelSize, kernelSize0, kernelSize1, kernelSize2],
      -- Bias, one entry per output channel.
      bias :: Parameter device dtype '[outputChannelSize]
    } ->
    Conv3d inputChannelSize outputChannelSize kernelSize0 kernelSize1 kernelSize2 dtype device
  deriving (Show, Generic, Parameterized)
-- | Apply a 'Conv3d' layer to a batched input of shape
-- @[batchSize, inputChannelSize, inputSize0, inputSize1, inputSize2]@.
--
-- @stride@ and @padding@ are type-level triples of kind @(Nat, Nat, Nat)@.
-- They are passed with TypeApplications, stride first, then padding.
-- Along each spatial axis @i@ the output extent is
-- @Div ((inputSize_i + 2 * padding_i) - kernelSize_i) stride_i + 1@.
--
-- The full constraint set is long, so the signature is deliberately partial
-- (@_ =>@ and @_@ holes) and GHC infers the rest. The module disables the
-- partial-type-signature warning for this reason.
conv3dForward ::
  forall stride padding.
  _ =>
  Conv3d _ _ _ _ _ _ _ ->
  Tensor _ _ _ ->
  Tensor _ _ _
conv3dForward Conv3d {..} input =
  conv3d @stride @padding
    (toDependent weight)
    (toDependent bias)
    input
-- | Forward pass for 'Conv3d'.
--
-- The two 'Proxy' components of the input tuple only carry the type-level
-- @stride@ and @padding@ triples. They are matched and discarded, and the
-- work is delegated to 'conv3dForward'. Each @ConvSideCheck@ constraint
-- (defined outside this chunk) relates one output extent to the matching
-- input extent, kernel size, stride and padding. 'forwardStoch' is
-- deterministic: it simply lifts 'forward' into 'IO'.
instance
  ( All
      KnownNat
      '[ Fst3 stride,
         Snd3 stride,
         Trd3 stride,
         Fst3 padding,
         Snd3 padding,
         Trd3 padding,
         inputChannelSize,
         outputChannelSize,
         kernelSize0,
         kernelSize1,
         kernelSize2,
         inputSize0,
         inputSize1,
         inputSize2,
         batchSize
       ],
    ConvSideCheck inputSize0 kernelSize0 (Fst3 stride) (Fst3 padding) outputSize0,
    ConvSideCheck inputSize1 kernelSize1 (Snd3 stride) (Snd3 padding) outputSize1,
    ConvSideCheck inputSize2 kernelSize2 (Trd3 stride) (Trd3 padding) outputSize2
  ) =>
  HasForward
    (Conv3d inputChannelSize outputChannelSize kernelSize0 kernelSize1 kernelSize2 dtype device)
    (Tensor device dtype '[batchSize, inputChannelSize, inputSize0, inputSize1, inputSize2], Proxy stride, Proxy padding)
    (Tensor device dtype '[batchSize, outputChannelSize, outputSize0, outputSize1, outputSize2])
  where
  forward model (input, Proxy, Proxy) = conv3dForward @stride @padding model input
  forwardStoch = (pure .) . forward
-- | Sample a freshly initialised 'Conv3d' layer from its field-less spec.
--
-- The weight (shape
-- @[outputChannelSize, inputChannelSize, kernelSize0, kernelSize1, kernelSize2]@)
-- and the bias (shape @[outputChannelSize]@) are each drawn with 'randn' and
-- wrapped as trainable parameters with 'makeIndependent'.
--
-- NOTE(review): the 'randn' samples are used unscaled (no fan-in based
-- scaling) — confirm this is the intended initialisation scheme.
instance
  ( KnownNat inputChannelSize,
    KnownNat outputChannelSize,
    KnownNat kernelSize0,
    KnownNat kernelSize1,
    KnownNat kernelSize2,
    KnownDType dtype,
    KnownDevice device,
    RandDTypeIsValid device dtype
  ) =>
  Randomizable
    (Conv3dSpec inputChannelSize outputChannelSize kernelSize0 kernelSize1 kernelSize2 dtype device)
    (Conv3d inputChannelSize outputChannelSize kernelSize0 kernelSize1 kernelSize2 dtype device)
  where
  sample Conv3dSpec =
    Conv3d
      <$> (makeIndependent =<< randn) -- weight
      <*> (makeIndependent =<< randn) -- bias
-- | Specification of a 1-D transposed convolution layer ('ConvTranspose1d').
--
-- The constructor carries no runtime data. Channel counts, kernel size,
-- dtype and device all live in the type parameters.
data
  ConvTranspose1dSpec
    (inputChannelSize :: Nat)
    (outputChannelSize :: Nat)
    (kernelSize :: Nat)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  = ConvTranspose1dSpec
  deriving (Show, Eq)
-- | A 1-D transposed convolution layer with a learnable weight and bias.
--
-- Note that the weight layout puts the /input/ channel dimension first. This
-- is the reverse of the ordinary convolution layers in this module.
-- 'Parameterized' is derived via 'Generic'.
data
  ConvTranspose1d
    (inputChannelSize :: Nat)
    (outputChannelSize :: Nat)
    (kernelSize :: Nat)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  where
  ConvTranspose1d ::
    forall inputChannelSize outputChannelSize kernelSize dtype device.
    { -- Kernels, laid out as [inputChannelSize, outputChannelSize, kernelSize].
      weight :: Parameter device dtype '[inputChannelSize, outputChannelSize, kernelSize],
      -- Bias, one entry per output channel.
      bias :: Parameter device dtype '[outputChannelSize]
    } ->
    ConvTranspose1d
      inputChannelSize
      outputChannelSize
      kernelSize
      dtype
      device
  deriving (Show, Generic, Parameterized)
-- | Apply a 'ConvTranspose1d' layer to a batch of 1-D inputs.
--
-- @stride@ and @padding@ are type-level naturals and must be supplied with
-- type applications, e.g. @convTranspose1dForward \@1 \@0 layer input@.
-- Everything else (channel sizes, kernel size, dtype, device, spatial
-- sizes) is left to the partial type signature; the shape constraints are
-- those required by 'convTranspose1d'.
--
-- The layer's 'weight' and 'bias' parameters are turned into plain tensors
-- with 'toDependent' and passed straight to 'convTranspose1d'.
--
-- NOTE(review): the inferred output length is
-- @Div ((inputSize + 2 * padding) - kernelSize) stride + 1@, which comes from
-- the 'ConvSideCheck' constraint of 'convTranspose1d'. That is the
-- forward-convolution size relation. Confirm it is intended for a
-- transposed convolution.
convTranspose1dForward ::
  forall stride padding.
  _ =>
  ConvTranspose1d _ _ _ _ _ ->
  Tensor _ _ _ ->
  Tensor _ _ _
convTranspose1dForward ConvTranspose1d {..} input =
  convTranspose1d @stride @padding
    (toDependent weight)
    (toDependent bias)
    input
-- | Run a 'ConvTranspose1d' on @(input, Proxy \@stride, Proxy \@padding)@.
--
-- The two proxies carry no runtime data. They only fix the type-level
-- @stride@ and @padding@ that are forwarded to 'convTranspose1dForward'.
instance
  ( All
      KnownNat
      '[ stride,
         padding,
         inputChannelSize,
         outputChannelSize,
         kernelSize,
         inputSize,
         batchSize,
         outputSize
       ],
    ConvSideCheck inputSize kernelSize stride padding outputSize
  ) =>
  HasForward
    (ConvTranspose1d inputChannelSize outputChannelSize kernelSize dtype device)
    (Tensor device dtype '[batchSize, inputChannelSize, inputSize], Proxy stride, Proxy padding)
    (Tensor device dtype '[batchSize, outputChannelSize, outputSize])
  where
  forward model (input, Proxy, Proxy) =
    convTranspose1dForward @stride @padding model input

  -- This layer has no stochastic behaviour, so the IO variant just wraps
  -- the pure 'forward' result.
  forwardStoch model x = pure (forward model x)
-- | Sample a fresh 'ConvTranspose1d' from its (data-less) spec.
--
-- Both 'weight' and 'bias' are drawn with 'randn' and wrapped as trainable
-- parameters with 'makeIndependent'. No further scaling is applied.
-- Tensor shapes come entirely from the result type.
instance
  ( KnownNat inputChannelSize,
    KnownNat outputChannelSize,
    KnownNat kernelSize,
    KnownDType dtype,
    KnownDevice device,
    RandDTypeIsValid device dtype
  ) =>
  Randomizable
    (ConvTranspose1dSpec inputChannelSize outputChannelSize kernelSize dtype device)
    (ConvTranspose1d inputChannelSize outputChannelSize kernelSize dtype device)
  where
  sample ConvTranspose1dSpec =
    ConvTranspose1d
      <$> (makeIndependent =<< randn) -- weight
      <*> (makeIndependent =<< randn) -- bias
-- | Specification used to 'sample' a 'ConvTranspose2d' layer.
--
-- All configuration lives at the type level: input/output channel counts,
-- the two kernel dimensions, the dtype and the device. The single
-- constructor carries no runtime data.
data
  ConvTranspose2dSpec
    (inputChannelSize :: Nat)
    (outputChannelSize :: Nat)
    (kernelSize0 :: Nat)
    (kernelSize1 :: Nat)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  = ConvTranspose2dSpec
  deriving (Show, Eq)
-- | A 2-D transposed-convolution layer with trainable weight and bias.
--
-- The weight tensor has shape
-- @[inputChannelSize, outputChannelSize, kernelSize0, kernelSize1]@ and the
-- bias has shape @[outputChannelSize]@. 'Parameterized' is derived
-- generically, so both fields are exposed as the layer's parameters.
data
  ConvTranspose2d
    (inputChannelSize :: Nat)
    (outputChannelSize :: Nat)
    (kernelSize0 :: Nat)
    (kernelSize1 :: Nat)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  where
  ConvTranspose2d ::
    forall inputChannelSize outputChannelSize kernelSize0 kernelSize1 dtype device.
    { -- | Convolution kernel.
      weight :: Parameter device dtype '[inputChannelSize, outputChannelSize, kernelSize0, kernelSize1],
      -- | Per-output-channel bias.
      bias :: Parameter device dtype '[outputChannelSize]
    } ->
    ConvTranspose2d
      inputChannelSize
      outputChannelSize
      kernelSize0
      kernelSize1
      dtype
      device
  deriving (Show, Generic, Parameterized)
-- | Apply a 'ConvTranspose2d' layer to a batch of 2-D inputs.
--
-- @stride@ and @padding@ are type-level pairs @(Nat, Nat)@, one entry per
-- spatial dimension. They must be supplied with type applications, e.g.
-- @convTranspose2dForward \@'(1, 1) \@'(0, 0) layer input@. The remaining
-- types are inferred via the partial signature, under the constraints
-- demanded by 'convTranspose2d'.
--
-- The layer's 'weight' and 'bias' are converted with 'toDependent' and
-- passed unchanged to 'convTranspose2d'.
convTranspose2dForward ::
  forall stride padding.
  _ =>
  ConvTranspose2d _ _ _ _ _ _ ->
  Tensor _ _ _ ->
  Tensor _ _ _
convTranspose2dForward ConvTranspose2d {..} input =
  convTranspose2d @stride @padding
    (toDependent weight)
    (toDependent bias)
    input
-- | Run a 'ConvTranspose2d' on @(input, Proxy \@stride, Proxy \@padding)@.
--
-- @stride@ and @padding@ are type-level pairs. Their components
-- ('Torch.Typed.Auxiliary.Fst' / 'Torch.Typed.Auxiliary.Snd') drive the
-- per-dimension 'ConvSideCheck' constraints. The proxies carry no runtime
-- data.
instance
  ( All
      KnownNat
      '[ Torch.Typed.Auxiliary.Fst stride,
         Torch.Typed.Auxiliary.Snd stride,
         Torch.Typed.Auxiliary.Fst padding,
         Torch.Typed.Auxiliary.Snd padding,
         inputChannelSize,
         outputChannelSize,
         kernelSize0,
         kernelSize1,
         inputSize0,
         inputSize1,
         batchSize,
         outputSize0,
         outputSize1
       ],
    ConvSideCheck inputSize0 kernelSize0 (Torch.Typed.Auxiliary.Fst stride) (Torch.Typed.Auxiliary.Fst padding) outputSize0,
    ConvSideCheck inputSize1 kernelSize1 (Torch.Typed.Auxiliary.Snd stride) (Torch.Typed.Auxiliary.Snd padding) outputSize1
  ) =>
  HasForward
    (ConvTranspose2d inputChannelSize outputChannelSize kernelSize0 kernelSize1 dtype device)
    (Tensor device dtype '[batchSize, inputChannelSize, inputSize0, inputSize1], Proxy stride, Proxy padding)
    (Tensor device dtype '[batchSize, outputChannelSize, outputSize0, outputSize1])
  where
  forward model (input, Proxy, Proxy) =
    convTranspose2dForward @stride @padding model input

  -- This layer has no stochastic behaviour, so the IO variant just wraps
  -- the pure 'forward' result.
  forwardStoch model x = pure (forward model x)
-- | Sample a fresh 'ConvTranspose2d' from its (data-less) spec.
--
-- Both 'weight' and 'bias' are drawn with 'randn' and wrapped as trainable
-- parameters with 'makeIndependent'. No further scaling is applied.
-- Tensor shapes come entirely from the result type.
instance
  ( KnownNat inputChannelSize,
    KnownNat outputChannelSize,
    KnownNat kernelSize0,
    KnownNat kernelSize1,
    KnownDType dtype,
    KnownDevice device,
    RandDTypeIsValid device dtype
  ) =>
  Randomizable
    (ConvTranspose2dSpec inputChannelSize outputChannelSize kernelSize0 kernelSize1 dtype device)
    (ConvTranspose2d inputChannelSize outputChannelSize kernelSize0 kernelSize1 dtype device)
  where
  sample ConvTranspose2dSpec =
    ConvTranspose2d
      <$> (makeIndependent =<< randn) -- weight
      <*> (makeIndependent =<< randn) -- bias
-- | Specification used to 'sample' a 3-D transposed-convolution layer.
--
-- All configuration lives at the type level: input/output channel counts,
-- the three kernel dimensions, the dtype and the device. The single
-- constructor carries no runtime data.
data
  ConvTranspose3dSpec
    (inputChannelSize :: Nat)
    (outputChannelSize :: Nat)
    (kernelSize0 :: Nat)
    (kernelSize1 :: Nat)
    (kernelSize2 :: Nat)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  = ConvTranspose3dSpec
  deriving (Show, Eq)
-- | A 3-D transposed-convolution layer whose channel counts, kernel extents,
-- dtype and device are all tracked at the type level.
--
-- Values are normally produced from a 'ConvTranspose3dSpec' through the
-- 'Randomizable' instance and applied with 'convTranspose3dForward' or
-- 'forward'.
data
  ConvTranspose3d
    (inputChannelSize :: Nat)
    (outputChannelSize :: Nat)
    (kernelSize0 :: Nat)
    (kernelSize1 :: Nat)
    (kernelSize2 :: Nat)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  where
  ConvTranspose3d ::
    forall inputChannelSize outputChannelSize kernelSize0 kernelSize1 kernelSize2 dtype device.
    { -- | Filter bank with shape
      -- @[inputChannelSize, outputChannelSize, kernelSize0, kernelSize1, kernelSize2]@
      -- (input channels first, as passed to 'convTranspose3d').
      weight :: Parameter device dtype '[inputChannelSize, outputChannelSize, kernelSize0, kernelSize1, kernelSize2],
      -- | Bias vector with one entry per output channel.
      bias :: Parameter device dtype '[outputChannelSize]
    } ->
    ConvTranspose3d inputChannelSize outputChannelSize kernelSize0 kernelSize1 kernelSize2 dtype device
  -- 'Parameterized' is derived through DeriveAnyClass using the 'Generic'
  -- representation, so the parameter list is (weight, bias) in field order.
  deriving (Show, Generic, Parameterized)
-- | Apply a 'ConvTranspose3d' layer to a batch of 3-D inputs.
--
-- The type-level @stride@ and @padding@ triples are supplied by type
-- application, stride first and padding second. The layer's weight and bias
-- are unwrapped with 'toDependent' and handed to 'convTranspose3d'.
--
-- The constraint set demanded by 'convTranspose3d' is long. The
-- extra-constraints wildcard (@_ =>@) lets GHC infer it, together with the
-- concrete dtype, device and shapes. The module silences
-- @-Wpartial-type-signatures@ for this purpose.
convTranspose3dForward ::
  forall stride padding.
  _ =>
  ConvTranspose3d _ _ _ _ _ _ _ ->
  Tensor _ _ _ ->
  Tensor _ _ _
convTranspose3dForward ConvTranspose3d {..} input =
  convTranspose3d @stride @padding
    (toDependent weight)
    (toDependent bias)
    input
-- | Run a 'ConvTranspose3d' layer on an @(input, Proxy stride, Proxy padding)@
-- triple. The two proxies carry no data; they only fix the type-level stride
-- and padding that are forwarded to 'convTranspose3dForward'.
--
-- 'forwardStoch' adds no randomness of its own: it is 'forward' lifted into
-- 'IO'.
--
-- NOTE(review): the output sizes are tied to the inputs through
-- 'ConvSideCheck'. The shape inferred for 'convTranspose3dForward' is
-- @Div ((inputSize + 2 * padding) - kernelSize) stride + 1@, which is the
-- size rule of an ordinary convolution. PyTorch's ConvTranspose3d instead
-- yields @(inputSize - 1) * stride - 2 * padding + kernelSize@ (no dilation
-- or output padding). Confirm against 'convTranspose3d' in
-- "Torch.Typed.Functional" before relying on these shapes. Changing the
-- constraints here alone would not type-check.
instance
  ( All
      KnownNat
      '[ Fst3 stride,
         Snd3 stride,
         Trd3 stride,
         Fst3 padding,
         Snd3 padding,
         Trd3 padding,
         inputChannelSize,
         outputChannelSize,
         kernelSize0,
         kernelSize1,
         kernelSize2,
         inputSize0,
         inputSize1,
         inputSize2,
         batchSize
       ],
    ConvSideCheck inputSize0 kernelSize0 (Fst3 stride) (Fst3 padding) outputSize0,
    ConvSideCheck inputSize1 kernelSize1 (Snd3 stride) (Snd3 padding) outputSize1,
    ConvSideCheck inputSize2 kernelSize2 (Trd3 stride) (Trd3 padding) outputSize2
  ) =>
  HasForward
    (ConvTranspose3d inputChannelSize outputChannelSize kernelSize0 kernelSize1 kernelSize2 dtype device)
    (Tensor device dtype '[batchSize, inputChannelSize, inputSize0, inputSize1, inputSize2], Proxy stride, Proxy padding)
    (Tensor device dtype '[batchSize, outputChannelSize, outputSize0, outputSize1, outputSize2])
  where
  -- 'stride' and 'padding' are bound by the instance head and are in scope
  -- here thanks to ScopedTypeVariables.
  forward model (input, Proxy, Proxy) =
    convTranspose3dForward @stride @padding model input
  forwardStoch model = pure . forward model
-- | Build a fresh 'ConvTranspose3d' from its (data-free) spec.
--
-- The weight and the bias are each drawn with 'randn' and wrapped with
-- 'makeIndependent' to obtain a 'Parameter'. Shapes, dtype and device all
-- come from the result type. 'randn' is presumably a standard-normal draw;
-- no fan-in scaling is applied here.
instance
  ( KnownNat inputChannelSize,
    KnownNat outputChannelSize,
    KnownNat kernelSize0,
    KnownNat kernelSize1,
    KnownNat kernelSize2,
    KnownDType dtype,
    KnownDevice device,
    RandDTypeIsValid device dtype
  ) =>
  Randomizable
    (ConvTranspose3dSpec inputChannelSize outputChannelSize kernelSize0 kernelSize1 kernelSize2 dtype device)
    (ConvTranspose3d inputChannelSize outputChannelSize kernelSize0 kernelSize1 kernelSize2 dtype device)
  where
  sample ConvTranspose3dSpec =
    ConvTranspose3d
      <$> (makeIndependent =<< randn) -- weight
      <*> (makeIndependent =<< randn) -- bias