{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE OverloadedRecordDot #-}
{-# LANGUAGE DuplicateRecordFields #-}

module Torch.NN where

import Control.Applicative (Applicative (liftA2))
import Control.Monad.State.Strict
import Data.Foldable (toList)
import Data.Kind
import GHC.Generics
import System.IO.Unsafe (unsafePerformIO)
import Torch.Autograd
import Torch.Device
import Torch.Functional
import Torch.Initializers
import Torch.Internal.Cast (cast3)
import qualified Torch.Internal.Managed.Native as ATen
import qualified Torch.Internal.Managed.Type.Tensor as ATen
import Torch.Scalar
import Torch.Tensor
import Torch.TensorFactories (ones', randIO', randnIO', zeros')

-- | A trainable parameter of a model; an alias for 'IndependentTensor'.
type Parameter = IndependentTensor

-- | A 'State' computation whose state is the list of parameters still
-- available to be consumed (see 'nextParameter').
type ParamStream a = State [Parameter] a

-- | Pop the next parameter off the front of the stream.
--
-- Parameters are handed out in list order. If the stream is already empty,
-- the caller supplied fewer parameters than the structure needs, and this
-- calls 'error'.
nextParameter :: ParamStream Parameter
nextParameter = do
  params <- get
  case params of
    [] -> error "Not enough parameters supplied to replaceParameters"
    (p : rest) -> do
      put rest
      return p

-- | Models @f@ that can be applied to inputs @a@, producing outputs @b@.
--
-- The model and input types determine the output type (functional
-- dependency @f a -> b@). Both methods have 'Generic' defaults that walk
-- the generic representations of model, input and output in lock-step via
-- 'GHasForward', applying 'forward' / 'forwardStoch' to matching fields.
class HasForward f a b | f a -> b where
  -- | Pure forward pass.
  forward :: f -> a -> b
  default forward ::
    ( Generic f,
      Generic a,
      Generic b,
      GHasForward (Rep f) (Rep a) (Rep b)
    ) =>
    f ->
    a ->
    b
  forward f a = to $ gForward (from f) (from a)

  -- | Forward pass in 'IO'.
  --
  -- NOTE(review): the name suggests it is meant for stochastic layers; the
  -- 'Linear' instance simply wraps 'forward' in 'pure'.
  forwardStoch :: f -> a -> IO b
  default forwardStoch ::
    ( Generic f,
      Generic a,
      Generic b,
      GHasForward (Rep f) (Rep a) (Rep b)
    ) =>
    f ->
    a ->
    IO b
  forwardStoch f a = to <$> gForwardStoch (from f) (from a)

-- | Generic-representation counterpart of 'HasForward'.
--
-- The three parameters are the representations of the model, the input and
-- the output; the model and input representations determine the output one.
class
  GHasForward
    (f :: Type -> Type)
    (a :: Type -> Type)
    (b :: Type -> Type)
    | f a -> b
  where
  -- | Structural pure forward pass over representations.
  gForward :: forall c c' c''. f c -> a c' -> b c''

  -- | Structural forward pass in 'IO' over representations.
  gForwardStoch :: forall c c' c''. f c -> a c' -> IO (b c)

-- | Field-less constructors: nothing to compute, the output is also 'U1'.
instance GHasForward U1 U1 U1 where
  gForward U1 U1 = U1
  gForwardStoch U1 U1 = return U1

-- | Sum types: the model and the input must use the same constructor, and
-- the output is built with that same constructor.
--
-- A mismatch (e.g. model in 'L1', input in 'R1') has no meaningful result.
-- It fails with a descriptive message instead of an anonymous
-- pattern-match failure.
instance
  ( GHasForward f a b,
    GHasForward g a' b',
    b'' ~ (b :+: b')
  ) =>
  GHasForward (f :+: g) (a :+: a') b''
  where
  gForward (L1 f) (L1 a) = L1 $ gForward f a
  gForward (R1 g) (R1 a') = R1 $ gForward g a'
  gForward _ _ =
    error "Torch.NN.gForward: model and input use different constructors of a sum type"
  gForwardStoch (L1 f) (L1 a) = L1 <$> gForwardStoch f a
  gForwardStoch (R1 g) (R1 a') = R1 <$> gForwardStoch g a'
  gForwardStoch _ _ =
    ioError
      (userError "Torch.NN.gForwardStoch: model and input use different constructors of a sum type")

-- | Product types: run the left model component on the left input component
-- and the right on the right, and pair the results.
--
-- In 'gForwardStoch' the left effect runs before the right one ('liftA2').
instance
  ( GHasForward f a b,
    GHasForward g a' b',
    b'' ~ (b :*: b')
  ) =>
  GHasForward (f :*: g) (a :*: a') b''
  where
  gForward (f :*: g) (a :*: a') = gForward f a :*: gForward g a'
  gForwardStoch (f :*: g) (a :*: a') =
    liftA2 (:*:) (gForwardStoch f a) (gForwardStoch g a')

-- | Individual fields: delegate to the field's own 'HasForward' instance.
instance
  (HasForward f a b) =>
  GHasForward (K1 i f) (K1 i a) (K1 i b)
  where
  gForward (K1 f) (K1 a) = K1 $ forward f a
  gForwardStoch (K1 f) (K1 a) = K1 <$> forwardStoch f a

-- | Metadata wrappers: unwrap, recurse, and rewrap using the input's
-- metadata (@t'@) for the output.
instance
  (GHasForward f a b) =>
  GHasForward (M1 i t f) (M1 i t' a) (M1 i t' b)
  where
  gForward (M1 f) (M1 a) = M1 $ gForward f a
  gForwardStoch (M1 f) (M1 a) = M1 <$> gForwardStoch f a

-- | Structures that contain trainable 'Parameter's.
--
-- Both methods have 'Generic' defaults (via 'GParameterized'), so record
-- types can derive the instance with @deriving (Generic, Parameterized)@.
class Parameterized f where
  -- | Collect every parameter in the structure, in traversal order.
  flattenParameters :: f -> [Parameter]
  default flattenParameters :: (Generic f, GParameterized (Rep f)) => f -> [Parameter]
  flattenParameters f = gFlattenParameters (from f)

  -- | Rebuild the structure, taking replacement parameters from the stream
  -- in the same order 'flattenParameters' produces them.
  _replaceParameters :: f -> ParamStream f
  default _replaceParameters :: (Generic f, GParameterized (Rep f)) => f -> ParamStream f
  _replaceParameters f = to <$> _gReplaceParameters (from f)

-- | Replace every parameter in a structure with those from the given list,
-- consumed in 'flattenParameters' order.
--
-- Calls 'error' if the list is too short (via 'nextParameter') or if any
-- parameters are left over once the structure has been rebuilt.
replaceParameters :: Parameterized f => f -> [Parameter] -> f
replaceParameters f params =
  case runState (_replaceParameters f) params of
    (f', []) -> f'
    _ -> error "Some parameters in a call to replaceParameters haven't been consumed!"

-- | A plain 'Tensor' is not a parameter. It contributes nothing and is
-- passed through unchanged.
instance Parameterized Tensor where
  flattenParameters _ = []
  _replaceParameters = return

-- | A bare 'Parameter' flattens to itself and is replaced by the next
-- parameter taken from the stream.
instance Parameterized Parameter where
  flattenParameters p = [p]
  _replaceParameters _ = nextParameter

-- | 'Scalar' values (hyper-parameters, constants) hold no parameters and
-- are passed through unchanged.
instance {-# OVERLAPS #-} (Scalar a) => Parameterized a where
  flattenParameters _ = []
  _replaceParameters = return

-- | Pairs: parameters of the first component come before those of the
-- second, both when flattening and when replacing.
instance {-# OVERLAPS #-} (Parameterized a, Parameterized b) => Parameterized (a, b) where
  flattenParameters (a, b) = flattenParameters a ++ flattenParameters b
  -- 'State' sequences applicative effects left to right, so @a@ consumes
  -- from the stream before @b@.
  _replaceParameters (a, b) = (,) <$> _replaceParameters a <*> _replaceParameters b

-- | Triples: components are handled left to right, both when flattening and
-- when replacing.
instance {-# OVERLAPS #-} (Parameterized a, Parameterized b, Parameterized c) => Parameterized (a, b, c) where
  flattenParameters (a, b, c) =
    flattenParameters a ++ flattenParameters b ++ flattenParameters c
  -- Applicative sequencing in 'State' runs left to right, matching the
  -- flatten order.
  _replaceParameters (a, b, c) =
    (,,)
      <$> _replaceParameters a
      <*> _replaceParameters b
      <*> _replaceParameters c

-- | Traversable containers (lists, 'Maybe', ...): elements are handled in
-- their 'toList' / 'traverse' order.
--
-- 'Foldable' is a superclass of 'Traversable', so it is not repeated in the
-- context.
instance {-# OVERLAPS #-} (Traversable t, Parameterized a) => Parameterized (t a) where
  flattenParameters = concatMap flattenParameters . toList
  _replaceParameters = traverse _replaceParameters

-- | Endofunctions (e.g. activation functions stored in a model) hold no
-- parameters and are passed through unchanged.
instance Parameterized (a -> a) where
  flattenParameters _ = []
  _replaceParameters = return

-- | Generic-representation counterpart of 'Parameterized'. Its instances
-- walk a 'Rep' structure so that 'Parameterized' can be derived.
class GParameterized f where
  -- | Collect parameters from a representation, left to right.
  gFlattenParameters :: f a -> [Parameter]

  -- | Rebuild a representation, drawing parameters from the stream.
  _gReplaceParameters :: f a -> ParamStream (f a)

-- | Field-less constructors hold no parameters.
instance GParameterized U1 where
  gFlattenParameters U1 = []
  _gReplaceParameters U1 = return U1

-- | Sum types: recurse into whichever constructor is present and keep it.
instance (GParameterized f, GParameterized g) => GParameterized (f :+: g) where
  gFlattenParameters (L1 x) = gFlattenParameters x
  gFlattenParameters (R1 x) = gFlattenParameters x
  _gReplaceParameters (L1 x) = L1 <$> _gReplaceParameters x
  _gReplaceParameters (R1 x) = R1 <$> _gReplaceParameters x

-- | Product types: left component first, then right, for both flattening
-- and replacement, so the two orders agree.
instance (GParameterized f, GParameterized g) => GParameterized (f :*: g) where
  gFlattenParameters (x :*: y) = gFlattenParameters x ++ gFlattenParameters y
  _gReplaceParameters (x :*: y) =
    (:*:) <$> _gReplaceParameters x <*> _gReplaceParameters y

-- | Individual fields: delegate to the field type's 'Parameterized' instance.
instance (Parameterized c) => GParameterized (K1 i c) where
  gFlattenParameters (K1 x) = flattenParameters x
  _gReplaceParameters (K1 x) = K1 <$> _replaceParameters x

-- | Metadata wrappers: unwrap, recurse, and rewrap.
instance (GParameterized f) => GParameterized (M1 i t f) where
  gFlattenParameters (M1 x) = gFlattenParameters x
  _gReplaceParameters (M1 x) = M1 <$> _gReplaceParameters x

-- | Layer specifications that can be turned into freshly initialised
-- layers. The spec type determines the resulting layer type.
class Randomizable spec model | spec -> model where
  -- | Build a layer from its spec, drawing initial values in 'IO'.
  sample :: spec -> IO model

--
-- Linear FC Layer
--

-- | Sizes needed to 'sample' a 'Linear' layer.
data LinearSpec = LinearSpec
  { -- | Number of input features; the sampled weight has shape
    -- @[out_features, in_features]@.
    in_features :: Int,
    -- | Number of output features; also the length of the sampled bias.
    out_features :: Int
  }
  deriving (Show, Eq)

-- | A fully connected layer.
--
-- 'Parameterized' is derived generically, so 'flattenParameters' yields
-- @[weight, bias]@ in field order.
data Linear = Linear
  { weight :: Parameter,
    bias :: Parameter
  }
  deriving (Show, Generic, Parameterized)

-- | Apply a 'Linear' layer to an input tensor by calling ATen's
-- @linear_ttt@ with the layer's weight and bias.
linear :: Linear -> Tensor -> Tensor
linear layer input =
  -- NOTE(review): 'unsafePerformIO' treats the ATen call as pure; this
  -- assumes linear_ttt has no observable effect beyond allocating its
  -- result tensor — TODO confirm.
  unsafePerformIO $ cast3 ATen.linear_ttt input w b
  where
    w = toDependent (layer.weight)
    b = toDependent (layer.bias)

-- | Same as 'linear'.
linearForward :: Linear -> Tensor -> Tensor
linearForward = linear -- temporary alias until dependencies are updated

-- | A 'Linear' layer maps a 'Tensor' to a 'Tensor'. It draws no random
-- values, so 'forwardStoch' just wraps the pure result.
instance HasForward Linear Tensor Tensor where
  forward = linearForward
  forwardStoch m = pure . linearForward m

-- | Initialise a 'Linear' layer.
--
-- * Weight: shape @[out_features, in_features]@, drawn with
--   'kaimingUniform' using 'FanIn' and @'LeakyRelu' (sqrt 5)@.
-- * Bias: shape @[out_features]@, computed as
--   @subScalar bound (mulScalar (2 * bound) u)@, where @u@ comes from
--   'randIO'' and @bound = 1 / sqrt fanIn@ (fan-in from 'calculateFan' on
--   the weight shape).
--
-- NOTE(review): if 'randIO'' samples uniformly from [0, 1) and
-- @subScalar s t@ computes @t - s@, the bias is uniform on
-- [-bound, bound) — TODO confirm.
instance Randomizable LinearSpec Linear where
  sample LinearSpec {..} = do
    let weightShape = [out_features, in_features]
    w <-
      makeIndependent
        =<< kaimingUniform FanIn (LeakyRelu $ Prelude.sqrt (5.0 :: Float)) weightShape
    unitNoise <- randIO' [out_features]
    let bound :: Float
        bound =
          1 / Prelude.sqrt (fromIntegral (getter FanIn (calculateFan weightShape)))
    b <- makeIndependent (subScalar bound (mulScalar (bound * 2.0) unitNoise))
    return $ Linear w b

--
-- Conv1d
--
-- | Sizes needed to 'sample' a 'Conv1d' layer. The sampled weight has shape
-- @[outputChannelSize1d, inputChannelSize1d, kernelSize]@.
data Conv1dSpec = Conv1dSpec
  { -- | Number of input channels.
    inputChannelSize1d :: Int,
    -- | Number of output channels; also the length of the sampled bias.
    outputChannelSize1d :: Int,
    -- | Kernel length.
    kernelSize :: Int
  }
  deriving (Show, Eq)

-- | A 1-D convolution layer.
--
-- 'Parameterized' is derived generically, so parameters are listed as
-- @[weight, bias]@.
data Conv1d = Conv1d
  { weight :: Parameter,
    bias :: Parameter
  }
  deriving (Show, Generic, Parameterized)

-- | Apply a 'Conv1d' layer by passing its weight and bias to
-- 'Torch.Functional.conv1d'' together with the stride, padding and input.
conv1dForward ::
  -- | layer
  Conv1d ->
  -- | stride
  Int ->
  -- | padding
  Int ->
  -- | input
  Tensor ->
  -- | output
  Tensor
conv1dForward layer = Torch.Functional.conv1d' w b
  where
    w = toDependent (layer.weight)
    b = toDependent (layer.bias)

-- | Initialise a 'Conv1d' layer.
--
-- * Weight: shape @[outputChannelSize1d, inputChannelSize1d, kernelSize]@,
--   drawn with 'kaimingUniform' using 'FanIn' and @'LeakyRelu' (sqrt 5)@.
-- * Bias: shape @[outputChannelSize1d]@, computed as
--   @subScalar bound (mulScalar (2 * bound) u)@, where @u@ comes from
--   'randIO'' and @bound = 1 / sqrt fanIn@ (fan-in from 'calculateFan' on
--   the weight shape).
--
-- NOTE(review): if 'randIO'' samples uniformly from [0, 1) and
-- @subScalar s t@ computes @t - s@, the bias is uniform on
-- [-bound, bound) — TODO confirm.
instance Randomizable Conv1dSpec Conv1d where
  sample Conv1dSpec {..} = do
    let weightShape = [outputChannelSize1d, inputChannelSize1d, kernelSize]
    w <-
      makeIndependent
        =<< kaimingUniform FanIn (LeakyRelu $ Prelude.sqrt (5.0 :: Float)) weightShape
    unitNoise <- randIO' [outputChannelSize1d]
    let bound :: Float
        bound =
          1 / Prelude.sqrt (fromIntegral (getter FanIn (calculateFan weightShape)))
    b <- makeIndependent (subScalar bound (mulScalar (bound * 2.0) unitNoise))
    return $ Conv1d w b

--
-- Conv2d
--

-- | Sizes needed to 'sample' a 'Conv2d' layer. The sampled weight has shape
-- @[outputChannelSize2d, inputChannelSize2d, kernelHeight2d, kernelWidth2d]@.
data Conv2dSpec = Conv2dSpec
  { -- | Number of input channels.
    inputChannelSize2d :: Int,
    -- | Number of output channels; also the length of the sampled bias.
    outputChannelSize2d :: Int,
    -- | Kernel height.
    kernelHeight2d :: Int,
    -- | Kernel width.
    kernelWidth2d :: Int
  }
  deriving (Show, Eq)

-- | A 2-D convolution layer.
--
-- 'Parameterized' is derived generically, so parameters are listed as
-- @[weight, bias]@.
data Conv2d = Conv2d
  { weight :: Parameter,
    bias :: Parameter
  }
  deriving (Show, Generic, Parameterized)

-- | Apply a 'Conv2d' layer by passing its weight and bias to
-- 'Torch.Functional.conv2d'' together with the stride, padding and input.
conv2dForward ::
  -- | layer
  Conv2d ->
  -- | stride
  (Int, Int) ->
  -- | padding
  (Int, Int) ->
  -- | input
  Tensor ->
  -- | output
  Tensor
conv2dForward layer = Torch.Functional.conv2d' w b
  where
    w = toDependent (layer.weight)
    b = toDependent (layer.bias)

-- | Initialise a 'Conv2d' layer.
--
-- * Weight: shape
--   @[outputChannelSize2d, inputChannelSize2d, kernelHeight2d, kernelWidth2d]@,
--   drawn with 'kaimingUniform' using 'FanIn' and @'LeakyRelu' (sqrt 5)@.
-- * Bias: shape @[outputChannelSize2d]@, computed as
--   @subScalar bound (mulScalar (2 * bound) u)@, where @u@ comes from
--   'randIO'' and @bound = 1 / sqrt fanIn@ (fan-in from 'calculateFan' on
--   the weight shape).
--
-- NOTE(review): if 'randIO'' samples uniformly from [0, 1) and
-- @subScalar s t@ computes @t - s@, the bias is uniform on
-- [-bound, bound) — TODO confirm.
instance Randomizable Conv2dSpec Conv2d where
  sample Conv2dSpec {..} = do
    let weightShape =
          [ outputChannelSize2d,
            inputChannelSize2d,
            kernelHeight2d,
            kernelWidth2d
          ]
    w <-
      makeIndependent
        =<< kaimingUniform FanIn (LeakyRelu $ Prelude.sqrt (5.0 :: Float)) weightShape
    unitNoise <- randIO' [outputChannelSize2d]
    let bound :: Float
        bound =
          1 / Prelude.sqrt (fromIntegral (getter FanIn (calculateFan weightShape)))
    b <- makeIndependent (subScalar bound (mulScalar (bound * 2.0) unitNoise))
    return $ Conv2d w b

--
-- Conv3d
--

-- | Sizes needed to 'sample' a 'Conv3d' layer. The sampled weight has shape
-- @[outputChannelSize3d, inputChannelSize3d, kernelHeight3d, kernelWidth3d,
-- kernelDepth3d]@.
data Conv3dSpec = Conv3dSpec
  { -- | Number of input channels.
    inputChannelSize3d :: Int,
    -- | Number of output channels; also the length of the sampled bias.
    outputChannelSize3d :: Int,
    -- | Kernel height.
    kernelHeight3d :: Int,
    -- | Kernel width.
    kernelWidth3d :: Int,
    -- | Kernel depth.
    kernelDepth3d :: Int
  }
  deriving (Show, Eq)

-- | A 3-D convolution layer.
--
-- 'Parameterized' is derived generically, so parameters are listed as
-- @[weight, bias]@.
data Conv3d = Conv3d
  { weight :: Parameter,
    bias :: Parameter
  }
  deriving (Show, Generic, Parameterized)

-- | Apply a 'Conv3d' layer by passing its weight and bias to
-- 'Torch.Functional.conv3d'' together with the stride, padding and input.
conv3dForward ::
  -- | layer
  Conv3d ->
  -- | stride
  (Int, Int, Int) ->
  -- | padding
  (Int, Int, Int) ->
  -- | input
  Tensor ->
  -- | output
  Tensor
conv3dForward layer = Torch.Functional.conv3d' w b
  where
    w = toDependent (layer.weight)
    b = toDependent (layer.bias)

-- | Samples a fresh 'Conv3d'.
--
-- * The kernel is drawn with 'kaimingUniform' ('FanIn', leaky-ReLU slope
--   @sqrt 5@) over @kernelShape@.
-- * The bias has @outputChannelSize3d@ entries. They are obtained by mapping
--   'randIO'' draws (presumably U[0,1) — confirm) through
--   @x -> 2 * bound * x - bound@, where @bound = 1 / sqrt fanIn@ and @fanIn@ is
--   taken from 'calculateFan' of the same kernel shape.
instance Randomizable Conv3dSpec Conv3d where
  sample Conv3dSpec {..} = do
    kernel <-
      makeIndependent
        =<< kaimingUniform FanIn (LeakyRelu $ Prelude.sqrt (5.0 :: Float)) kernelShape
    unitDraws <- randIO' [outputChannelSize3d]
    let fanIn = getter FanIn (calculateFan kernelShape)
        bound :: Float
        bound = 1 / Prelude.sqrt (fromIntegral fanIn)
    offset <- makeIndependent (subScalar bound (mulScalar (bound * 2.0) unitDraws))
    pure (Conv3d kernel offset)
    where
      -- NOTE(review): the spatial axes are listed as (height, width, depth),
      -- whereas ATen's conv3d kernel layout is (out, in, kD, kH, kW). For
      -- non-cubic kernels the depth/height/width labels may therefore not match
      -- the axes the convolution actually uses — confirm intent.
      kernelShape =
        [ outputChannelSize3d
        , inputChannelSize3d
        , kernelHeight3d
        , kernelWidth3d
        , kernelDepth3d
        ]

--
-- ConvTranspose1d
--

-- | Hyper-parameters used to sample a 'ConvTranspose1d' layer.
data ConvTranspose1dSpec = ConvTranspose1dSpec
  { trInputChannelSize1d :: Int -- ^ input channel count (kernel dim 0)
  , trOutputChannelSize1d :: Int -- ^ output channel count (kernel dim 1, bias length)
  , trKernelSize :: Int -- ^ kernel length (kernel dim 2)
  }
  deriving (Show, Eq)

-- | A 1-D transposed convolution layer: a learnable kernel together with a
-- learnable bias, both consumed by 'convTranspose1dForward'.
data ConvTranspose1d = ConvTranspose1d
  { weight :: Parameter -- ^ transposed-convolution kernel
  , bias :: Parameter -- ^ bias (sampled with one entry per output channel)
  }
  deriving (Show, Generic, Parameterized)

-- | Runs a 'ConvTranspose1d' layer on an input tensor by handing the layer's
-- kernel and bias, together with the requested stride and padding, to
-- 'convTranspose1d''.
convTranspose1dForward ::
  -- | layer
  ConvTranspose1d ->
  -- | stride
  Int ->
  -- | padding
  Int ->
  -- | input
  Tensor ->
  -- | output
  Tensor
convTranspose1dForward layer stride padding input =
  convTranspose1d' kernel offset stride padding input
  where
    kernel = toDependent layer.weight
    offset = toDependent layer.bias

-- | Samples a fresh 'ConvTranspose1d'.
--
-- * The kernel, shaped @[in, out, kernel]@, is drawn with 'kaimingUniform'
--   ('FanIn', leaky-ReLU slope @sqrt 5@).
-- * The bias has @trOutputChannelSize1d@ entries. They are obtained by mapping
--   'randIO'' draws (presumably U[0,1) — confirm) through
--   @x -> 2 * bound * x - bound@, where @bound = 1 / sqrt fanIn@ and @fanIn@ is
--   taken from 'calculateFan' of the same kernel shape.
instance Randomizable ConvTranspose1dSpec ConvTranspose1d where
  sample ConvTranspose1dSpec {..} = do
    kernel <-
      makeIndependent
        =<< kaimingUniform FanIn (LeakyRelu $ Prelude.sqrt (5.0 :: Float)) kernelShape
    unitDraws <- randIO' [trOutputChannelSize1d]
    let fanIn = getter FanIn (calculateFan kernelShape)
        bound :: Float
        bound = 1 / Prelude.sqrt (fromIntegral fanIn)
    offset <- makeIndependent (subScalar bound (mulScalar (bound * 2.0) unitDraws))
    pure (ConvTranspose1d kernel offset)
    where
      kernelShape =
        [ trInputChannelSize1d
        , trOutputChannelSize1d
        , trKernelSize
        ]

--
-- ConvTranspose2d
--

-- | Hyper-parameters used to sample a 'ConvTranspose2d' layer.
data ConvTranspose2dSpec = ConvTranspose2dSpec
  { trInputChannelSize2d :: Int -- ^ input channel count (kernel dim 0)
  , trOutputChannelSize2d :: Int -- ^ output channel count (kernel dim 1, bias length)
  , trKernelHeight2d :: Int -- ^ kernel height (kernel dim 2)
  , trKernelWidth2d :: Int -- ^ kernel width (kernel dim 3)
  }
  deriving (Show, Eq)

-- | A 2-D transposed convolution layer: a learnable kernel together with a
-- learnable bias, both consumed by 'convTranspose2dForward'.
data ConvTranspose2d = ConvTranspose2d
  { weight :: Parameter -- ^ transposed-convolution kernel
  , bias :: Parameter -- ^ bias (sampled with one entry per output channel)
  }
  deriving (Show, Generic, Parameterized)

-- | Runs a 'ConvTranspose2d' layer on an input tensor by handing the layer's
-- kernel and bias, together with the requested stride and padding, to
-- 'convTranspose2d''.
convTranspose2dForward ::
  -- | layer
  ConvTranspose2d ->
  -- | stride
  (Int, Int) ->
  -- | padding
  (Int, Int) ->
  -- | input
  Tensor ->
  -- | output
  Tensor
convTranspose2dForward layer stride padding input =
  convTranspose2d' kernel offset stride padding input
  where
    kernel = toDependent layer.weight
    offset = toDependent layer.bias

-- | Samples a fresh 'ConvTranspose2d'.
--
-- * The kernel, shaped @[in, out, kH, kW]@, is drawn with 'kaimingUniform'
--   ('FanIn', leaky-ReLU slope @sqrt 5@).
-- * The bias has @trOutputChannelSize2d@ entries. They are obtained by mapping
--   'randIO'' draws (presumably U[0,1) — confirm) through
--   @x -> 2 * bound * x - bound@, where @bound = 1 / sqrt fanIn@ and @fanIn@ is
--   taken from 'calculateFan' of the same kernel shape.
instance Randomizable ConvTranspose2dSpec ConvTranspose2d where
  sample ConvTranspose2dSpec {..} = do
    kernel <-
      makeIndependent
        =<< kaimingUniform FanIn (LeakyRelu $ Prelude.sqrt (5.0 :: Float)) kernelShape
    unitDraws <- randIO' [trOutputChannelSize2d]
    let fanIn = getter FanIn (calculateFan kernelShape)
        bound :: Float
        bound = 1 / Prelude.sqrt (fromIntegral fanIn)
    offset <- makeIndependent (subScalar bound (mulScalar (bound * 2.0) unitDraws))
    pure (ConvTranspose2d kernel offset)
    where
      kernelShape =
        [ trInputChannelSize2d
        , trOutputChannelSize2d
        , trKernelHeight2d
        , trKernelWidth2d
        ]

--
-- ConvTranspose3d
--

-- | Hyper-parameters used to sample a 'ConvTranspose3d' layer.
data ConvTranspose3dSpec = ConvTranspose3dSpec
  { trInputChannelSize3d :: Int -- ^ input channel count (kernel dim 0)
  , trOutputChannelSize3d :: Int -- ^ output channel count (kernel dim 1, bias length)
  , trKernelHeight3d :: Int -- ^ kernel height (kernel dim 2)
  , trKernelWidth3d :: Int -- ^ kernel width (kernel dim 3)
  , trKernelDepth3d :: Int -- ^ kernel depth (kernel dim 4)
  }
  deriving (Show, Eq)

-- | A 3-D transposed convolution layer: a learnable kernel together with a
-- learnable bias, both consumed by 'convTranspose3dForward'.
data ConvTranspose3d = ConvTranspose3d
  { weight :: Parameter -- ^ transposed-convolution kernel
  , bias :: Parameter -- ^ bias (sampled with one entry per output channel)
  }
  deriving (Show, Generic, Parameterized)

-- | Runs a 'ConvTranspose3d' layer on an input tensor by handing the layer's
-- kernel and bias, together with the requested stride and padding, to
-- 'convTranspose3d''.
convTranspose3dForward ::
  -- | layer
  ConvTranspose3d ->
  -- | stride
  (Int, Int, Int) ->
  -- | padding
  (Int, Int, Int) ->
  -- | input
  Tensor ->
  -- | output
  Tensor
convTranspose3dForward layer stride padding input =
  convTranspose3d' kernel offset stride padding input
  where
    kernel = toDependent layer.weight
    offset = toDependent layer.bias

-- | Samples a fresh 'ConvTranspose3d'.
--
-- * The kernel is drawn with 'kaimingUniform' ('FanIn', leaky-ReLU slope
--   @sqrt 5@) over @kernelShape@.
-- * The bias has @trOutputChannelSize3d@ entries. They are obtained by mapping
--   'randIO'' draws (presumably U[0,1) — confirm) through
--   @x -> 2 * bound * x - bound@, where @bound = 1 / sqrt fanIn@ and @fanIn@ is
--   taken from 'calculateFan' of the same kernel shape.
instance Randomizable ConvTranspose3dSpec ConvTranspose3d where
  sample ConvTranspose3dSpec {..} = do
    kernel <-
      makeIndependent
        =<< kaimingUniform FanIn (LeakyRelu $ Prelude.sqrt (5.0 :: Float)) kernelShape
    unitDraws <- randIO' [trOutputChannelSize3d]
    let fanIn = getter FanIn (calculateFan kernelShape)
        bound :: Float
        bound = 1 / Prelude.sqrt (fromIntegral fanIn)
    offset <- makeIndependent (subScalar bound (mulScalar (bound * 2.0) unitDraws))
    pure (ConvTranspose3d kernel offset)
    where
      -- NOTE(review): the spatial axes are listed as (height, width, depth),
      -- whereas ATen's conv_transpose3d kernel layout is (in, out, kD, kH, kW).
      -- For non-cubic kernels the labels may therefore not match the axes the
      -- convolution actually uses — confirm intent.
      kernelShape =
        [ trInputChannelSize3d
        , trOutputChannelSize3d
        , trKernelHeight3d
        , trKernelWidth3d
        , trKernelDepth3d
        ]

-- | Hyper-parameters used to sample a 'BatchNorm' layer.
data BatchNormSpec = BatchNormSpec
  { numFeatures :: Int -- ^ length of the weight, bias and running-statistic tensors
  }
  deriving (Show, Eq)

-- | Batch-normalisation layer state: learnable affine parameters plus two
-- mutable running-statistic buffers, all consumed by 'batchNormForwardIO'.
data BatchNorm = BatchNorm
  { weight :: Parameter -- ^ learnable scale
  , bias :: Parameter -- ^ learnable shift
  , runningMean :: MutableTensor -- ^ running-mean buffer
  , runningVar :: MutableTensor -- ^ running-variance buffer
  }
  deriving (Show, Generic)

-- | Applies batch normalisation via 'Torch.Functional.batchNormIO'.
--
-- The layer's scale and shift are passed as plain tensors. The running-mean
-- and running-variance buffers are passed as 'MutableTensor's, presumably so
-- that 'batchNormIO' can update them when @train@ is 'True' — confirm.
--
-- Arguments, in order: layer, training flag, momentum, epsilon, input.
batchNormForwardIO :: BatchNorm -> Bool -> Double -> Double -> Tensor -> IO Tensor
batchNormForwardIO params train momentum eps input =
  Torch.Functional.batchNormIO
    scale
    shift
    params.runningMean
    params.runningVar
    train
    momentum
    eps
    input
  where
    scale = toDependent params.weight
    shift = toDependent params.bias

-- | Builds a 'BatchNorm' in its initial state:
--
-- * scale = ones and shift = zeros, each of length @numFeatures@, as
--   independent parameters;
-- * running mean = zeros and running variance = ones, created with
--   requires-grad set to 'False' and wrapped as 'MutableTensor's.
instance Randomizable BatchNormSpec BatchNorm where
  sample BatchNormSpec {..} = do
    scale <- makeIndependent (ones' [numFeatures])
    shift <- makeIndependent (zeros' [numFeatures])
    meanBuffer <- frozenBuffer (zeros' [numFeatures])
    varBuffer <- frozenBuffer (ones' [numFeatures])
    pure (BatchNorm scale shift meanBuffer varBuffer)
    where
      -- wraps a tensor that does not require gradients as a mutable buffer
      frozenBuffer t =
        MutableTensor . toDependent <$> makeIndependentWithRequiresGrad t False

-- | Hyper-parameters used to sample an 'InstanceNorm' layer.
data InstanceNormSpec = InstanceNormSpec
  { numFeatures :: Int -- ^ length of the weight, bias and running-statistic tensors
  }
  deriving (Show, Eq)

-- | Instance-normalisation layer state: learnable affine parameters plus two
-- mutable running-statistic buffers, all consumed by 'instanceNormForwardIO'.
data InstanceNorm = InstanceNorm
  { weight :: Parameter -- ^ learnable scale
  , bias :: Parameter -- ^ learnable shift
  , runningMean :: MutableTensor -- ^ running-mean buffer
  , runningVar :: MutableTensor -- ^ running-variance buffer
  }
  deriving (Show, Generic)

-- | Applies instance normalisation via 'Torch.Functional.instanceNormIO'.
--
-- The layer's scale and shift are passed as plain tensors. The running-mean
-- and running-variance buffers are passed as 'MutableTensor's, presumably so
-- that 'instanceNormIO' can update them when @train@ is 'True' — confirm.
--
-- Arguments, in order: layer, training flag, momentum, epsilon, input.
instanceNormForwardIO :: InstanceNorm -> Bool -> Double -> Double -> Tensor -> IO Tensor
instanceNormForwardIO params train momentum eps input =
  Torch.Functional.instanceNormIO
    scale
    shift
    params.runningMean
    params.runningVar
    train
    momentum
    eps
    input
  where
    scale = toDependent params.weight
    shift = toDependent params.bias

-- | Builds an 'InstanceNorm' in its initial state:
--
-- * scale = ones and shift = zeros, each of length @numFeatures@, as
--   independent parameters;
-- * running mean = zeros and running variance = ones, created with
--   requires-grad set to 'False' and wrapped as 'MutableTensor's.
instance Randomizable InstanceNormSpec InstanceNorm where
  sample InstanceNormSpec {..} = do
    scale <- makeIndependent (ones' [numFeatures])
    shift <- makeIndependent (zeros' [numFeatures])
    meanBuffer <- frozenBuffer (zeros' [numFeatures])
    varBuffer <- frozenBuffer (ones' [numFeatures])
    pure (InstanceNorm scale shift meanBuffer varBuffer)
    where
      -- wraps a tensor that does not require gradients as a mutable buffer
      frozenBuffer t =
        MutableTensor . toDependent <$> makeIndependentWithRequiresGrad t False

-- | Configuration for an 'UpSample' layer.
data UpSampleSpec = UpSampleSpec
  { upsampleInputFilters :: Int -- ^ not read by the 'UpSample' forward pass in this module
  , upsampleStride :: Int -- ^ integer factor applied to both spatial axes
  }
  deriving (Show, Eq)

-- | An 'UpSampleSpec' holds only integers, so it contributes no parameters,
-- and parameter replacement returns it unchanged without consuming the stream.
instance Parameterized UpSampleSpec where
  flattenParameters = const []
  _replaceParameters spec = pure spec

-- | Nearest-neighbour upsampling layer. It carries no learnable state, only
-- its 'UpSampleSpec'.
data UpSample = UpSample
  { upsampleSpec :: UpSampleSpec -- ^ upsampling configuration
  }
  deriving (Show, Generic, Parameterized)

-- | Constructing an 'UpSample' involves no randomness: the spec is stored as-is.
instance Randomizable UpSampleSpec UpSample where
  sample = pure . UpSample

-- | Nearest-neighbour 2-D upsampling by an integer factor.
--
-- The last two axes of the input are treated as (height, width). Each is
-- multiplied by 'upsampleStride', and the stride is also passed as both the
-- height and the width scale factor. 'upsampleInputFilters' is not used here.
--
-- Calls 'error' if the input has fewer than two dimensions.
instance HasForward UpSample Tensor Tensor where
  forward (UpSample UpSampleSpec {..}) input =
    upsampleNearest2d
      (inputHeight * upsampleStride, inputWidth * upsampleStride)
      scale
      scale
      input
    where
      scale :: Double
      scale = fromIntegral upsampleStride
      -- The output size must be given height-first, matching ATen's
      -- upsample_nearest2d (output_size = [out_H, out_W]). 'reverse' puts the
      -- width first, so the two leading elements are swapped back here;
      -- passing them in reversed order would transpose the target size for
      -- non-square inputs.
      (inputHeight, inputWidth) = case reverse (shape input) of
        w : h : _ -> (h, w)
        _ ->
          error
            "Torch.NN.UpSample.forward: input must have at least 2 dimensions (..., H, W)"
  forwardStoch m x = pure $ forward m x