{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}

module Torch.Vision where

import qualified Codec.Picture as I
import Control.Exception.Safe
  ( SomeException (..),
    throwIO,
    try,
  )
import Control.Monad
  ( MonadPlus,
    forM_,
    when,
  )
import qualified Data.ByteString as BS
import qualified Data.ByteString.Internal as BSI
import Data.Int
import qualified Data.Vector.Storable as V
import Data.Word
import qualified Foreign.ForeignPtr as F
import qualified Foreign.Ptr as F
import GHC.Exts (IsList (fromList))
import qualified Language.C.Inline as C
import Pipes
import System.IO.Unsafe
import System.Random (mkStdGen, randoms)
import qualified Torch.DType as D
import Torch.Data.Pipeline
import Torch.Data.StreamedPipeline
import Torch.Functional hiding (take)
import qualified Torch.Functional as D
import Torch.Internal.Cast
import qualified Torch.Internal.Managed.TensorFactories as LibTorch
import Torch.NN
import Torch.Tensor
import qualified Torch.Tensor as D
import qualified Torch.TensorOptions as D
import qualified Torch.Typed.Vision as I
import Prelude hiding (max, min)
import qualified Prelude as P

C.include "<stdint.h>"

-- | MNIST dataset wrapper that serves the data in fixed-size minibatches.
--
-- The type parameter @m@ is the monad that the 'Datastream' and 'Dataset'
-- instances run in; no field uses it.
data MNIST (m :: * -> *) = MNIST
  { -- | Number of samples per minibatch.
    batchSize :: Int,
    -- | The underlying MNIST images and labels.
    mnistData :: I.MnistData
  }

-- | Streams the whole dataset once as consecutive, non-overlapping
-- minibatches of @(images, labels)@ tensors.
--
-- Batch @k@ (1-based) covers the sample indices
-- @[(k - 1) * batchSize .. k * batchSize - 1]@. Trailing samples that do not
-- fill a complete batch are dropped. The seed argument is ignored, so the
-- order is always sequential.
instance Monad m => Datastream m Int (MNIST m) (Tensor, Tensor) where
  streamSamples MNIST {..} _seed =
    Select $
      for (each [1 .. numIters]) $ \iter -> do
        let from = (iter - 1) * batchSize
            to = iter * batchSize - 1
            indexes = [from .. to]
            target = getLabels' batchSize mnistData indexes
            -- 784 bytes per image (presumably 28x28 pixels, cf. dispImage).
            input = getImages' batchSize 784 mnistData indexes
        yield (input, target)
    where
      -- Number of complete minibatches in the dataset.
      numIters = I.length mnistData `Prelude.div` batchSize

-- | Random access to minibatches.
--
-- Key @ix@ selects the sample indices
-- @[ix * batchSize .. (ix + 1) * batchSize - 1]@ and yields the
-- @(images, labels)@ tensors for them.
instance Applicative m => Dataset m (MNIST m) Int (Tensor, Tensor) where
  getItem MNIST {..} ix = pure (imgs, labels)
    where
      indexes = [ix * batchSize .. (ix + 1) * batchSize - 1]
      imgs = getImages' batchSize 784 mnistData indexes
      labels = getLabels' batchSize mnistData indexes

  -- One key per complete minibatch; a trailing partial batch gets no key.
  keys MNIST {..} = fromList [0 .. I.length mnistData `Prelude.div` batchSize - 1]

-- | Look up the labels of (at most) the first @n@ indices in @imageIdxs@
-- and pack them into a tensor.
getLabels' :: Int -> I.MnistData -> [Int] -> Tensor
getLabels' n mnist imageIdxs = asTensor . map (I.getLabel mnist) $ take n imageIdxs

-- | Copy a minibatch of raw MNIST images into an @[n, dataDim]@ tensor.
--
-- Row @i@ of the result holds the @dataDim@ bytes of the @i@-th index in
-- @imageIdxs@. The bytes are read straight out of the packed image
-- 'BS.ByteString', after its 16-byte header. The tensor is allocated as
-- UInt8 and converted to Float before being returned.
--
-- Only the first @n@ indices are used. NOTE(review): if fewer than @n@
-- indices are supplied, the remaining rows are never written, and they come
-- from an 'LibTorch.empty_lo' allocation. Callers should pass at least @n@
-- indices.
--
-- An index that would read outside the image buffer raises an 'IOError'
-- (when the result is forced) instead of performing an out-of-bounds copy.
getImages' ::
  Int -> -- number of observations in minibatch
  Int -> -- dimensionality of the data
  I.MnistData -> -- mnist data representation
  [Int] -> -- indices of the dataset
  Tensor
getImages' n dataDim mnist imageIdxs = unsafePerformIO $ do
  let (BSI.PS fptr off len) = I.images mnist
      -- Bytes preceding the first image (presumably the IDX file header).
      headerBytes = 16
      batch = zip [0 .. n - 1] imageIdxs
      outOfRange idx = idx < 0 || headerBytes + dataDim * (idx + 1) > len
  -- Validate every index before touching raw memory.
  forM_ batch $ \(_, idx) ->
    when (outOfRange idx) $
      throwIO . userError $
        "getImages': image index " ++ show idx ++ " is outside the image buffer"
  t <-
    (cast2 LibTorch.empty_lo :: [Int] -> D.TensorOptions -> IO D.Tensor)
      [n, dataDim]
      (D.withDType D.UInt8 D.defaultOpts)
  D.withTensor t $ \ptr1 ->
    F.withForeignPtr fptr $ \ptr2 ->
      forM_ batch $ \(i, idx) ->
        BSI.memcpy
          (F.plusPtr ptr1 (dataDim * i))
          (F.plusPtr ptr2 (off + headerBytes + dataDim * idx))
          dataDim
  return $ D.toType D.Float t

-- | Ten-level ASCII grey ramp, ordered from darkest (space) to brightest
-- (@\'\@\'@).
--
-- Source: http://paulbourke.net/dataformats/asciiart/
grayScale10 :: String
grayScale10 = " .:-=+*#%@"

-- | Seventy-level ASCII grey ramp from the same source as 'grayScale10'.
--
-- The literal is written densest-first and reversed here, so the result runs
-- from space to @\'$\'@, in the same order as 'grayScale10'.
grayScale70 :: String
grayScale70 = reverse "$@B%8&WM#*oahkbdpqwmZO0QLCJUYXzcvunxrjft/\\|()1{}[]?-_+~<>i!lI;:,\"^`'. "

-- | Display an MNIST image tensor as ASCII text on stdout.
--
-- The tensor is reshaped to @[28, 28]@ (so it must hold 784 elements) and
-- converted to Float. Every second row and column is sampled. Each sampled
-- pixel is min-max normalised over the whole image and mapped onto the
-- 'grayScale10' palette.
--
-- A constant image has no range to normalise over and prints as blanks;
-- previously it produced NaN palette indices and crashed in '!!'.
dispImage :: Tensor -> IO ()
dispImage img =
  forM_ sampled $ \row -> do
    forM_ sampled $ \col ->
      putChar (grayScale !! paletteIndex (pixels !! row !! col))
    putStrLn ""
  where
    downSamp = 2 :: Int
    sampled = [0, downSamp .. 27]
    grayScale = grayScale10
    lastIdx = length grayScale - 1
    pixels :: [[Float]]
    pixels = asValue (D.toType D.Float (reshape [28, 28] img))
    flat = P.concat pixels
    lo = P.minimum flat
    hi = P.maximum flat
    -- Map a pixel value to a palette position, always inside [0, lastIdx].
    paletteIndex :: Float -> Int
    paletteIndex v
      | hi <= lo = 0
      | otherwise =
          P.max 0 . P.min lastIdx . P.floor $
            (v - lo) / (hi - lo) * fromIntegral lastIdx

-- | Pixel layout of a decoded image, as reported by 'pixelFormat'.
--
-- The constructor names appear to mirror JuicyPixels' 'I.DynamicImage'
-- variants (e.g. @ImageRGB8@ -> 'RGB8'). NOTE(review): confirm against
-- 'pixelFormat'.
data PixelFormat
  = Y8
  | YF
  | YA8
  | RGB8
  | RGBF
  | RGBA8
  | YCbCr8
  | CMYK8
  | CMYK16
  | RGBA16
  | RGB16
  | Y16
  | YA16
  | Y32
  deriving (Show, Eq)

-- | Decode an image file with JuicyPixels and convert it to a tensor.
--
-- Returns the tensor together with the pixel format the image was decoded
-- in. 'Left' carries the decoder's error message unchanged.
readImage :: FilePath -> IO (Either String (D.Tensor, PixelFormat))
readImage file = fmap toTensorAndFormat <$> I.readImage file
  where
    toTensorAndFormat img = (fromDynImage img, pixelFormat img)

-- | Decode an image file, convert it to 8-bit RGB, and return it as a tensor.
--
-- 'Left' carries the decoder's error message unchanged.
readImageAsRGB8 :: FilePath -> IO (Either String D.Tensor)
readImageAsRGB8 file =
  fmap (fromDynImage . I.ImageRGB8 . I.convertRGB8) <$> I.readImage file

-- | Decode an image file, convert it to 8-bit RGB, and resize it to
-- @width@ x @height@ with 'resizeRGB8'.
--
-- Returns both the resized image and its tensor form. When
-- @keepAspectRatio@ is set, the source is letterboxed into the target size
-- (see 'resizeRGB8'). 'Left' carries the decoder's error message unchanged.
readImageAsRGB8WithScaling ::
  FilePath ->
  Int ->
  Int ->
  Bool ->
  IO (Either String (I.Image I.PixelRGB8, D.Tensor))
readImageAsRGB8WithScaling file width height keepAspectRatio =
  fmap prepare <$> I.readImage file
  where
    prepare dyn =
      let img = resizeRGB8 width height keepAspectRatio (I.convertRGB8 dyn)
       in (img, fromDynImage (I.ImageRGB8 img))

-- | Cut the central @width@ x @height@ window out of an RGB8 image.
--
-- The window's top-left corner sits at @((ow - width) / 2, (oh - height) / 2)@
-- in the source (C integer division). Destination pixels that fall outside
-- the source, which happens when the window is larger than the input, stay
-- black.
centerCrop :: Int -> Int -> I.Image I.PixelRGB8 -> I.Image I.PixelRGB8
centerCrop width height input = unsafePerformIO $ do
  -- The destination is allocated inside IO so every call gets its own
  -- buffer. A pure 'I.generateImage' value could be shared between calls by
  -- GHC (CSE / let-floating), and it is overwritten in place below.
  img@(I.Image w h vec) <- I.withImage width height (\_ _ -> pure (I.PixelRGB8 0 0 0))
  let channel = 3 :: Int
      (I.Image orgW orgH orgVec) = input
      (orgFptr, _) = V.unsafeToForeignPtr0 orgVec
      (fptr, _) = V.unsafeToForeignPtr0 vec
  F.withForeignPtr orgFptr $ \ptr1 -> F.withForeignPtr fptr $ \ptr2 -> do
    let src = F.castPtr ptr1
        dst = F.castPtr ptr2
        iw = fromIntegral w
        ih = fromIntegral h
        iorg_w = fromIntegral orgW
        iorg_h = fromIntegral orgH
        ichannel = fromIntegral channel
    [C.block| void {
        uint8_t* src = $(uint8_t* src);
        uint8_t* dst = $(uint8_t* dst);
        int w = $(int iw);
        int h = $(int ih);
        int channel = $(int ichannel);
        int ow = $(int iorg_w);
        int oh = $(int iorg_h);
        int offsetx = (ow - w)/2;
        int offsety = (oh - h)/2;
        for(int y=0;y<h;y++){
          for(int x=0;x<w;x++){
            for(int c=0;c<channel;c++){
              int sy = y + offsety;
              int sx = x + offsetx;
              if(sx >= 0 && sx < ow &&
                 sy >= 0 && sy < oh){
                 dst[(y*w+x)*channel+c] = src[(sy*ow+sx)*channel+c];
              }
            }
          }
        }
    } |]
  return img

-- | Draw a straight line of colour @(r, g, b)@ from @(x0, y0)@ toward
-- @(x1, y1)@.
--
-- WARNING: this writes straight into @input@'s pixel buffer through a raw
-- pointer, so the (nominally immutable) image is modified in place and every
-- value sharing that buffer sees the change.
--
-- The loop steps along the longer axis and stops when that coordinate
-- reaches @x1@ (or @y1@). The end point is therefore not drawn, and a
-- zero-length line draws nothing. Pixels outside the image are skipped.
drawLine :: Int -> Int -> Int -> Int -> (Int, Int, Int) -> I.Image I.PixelRGB8 -> IO ()
drawLine x0 y0 x1 y1 (r, g, b) input = do
  let (I.Image w h vec) = input
      (fptr, _) = V.unsafeToForeignPtr0 vec
  F.withForeignPtr fptr $ \ptr2 -> do
    let dst = F.castPtr ptr2
        iw = fromIntegral w
        ih = fromIntegral h
        ix0 = fromIntegral x0
        iy0 = fromIntegral y0
        ix1 = fromIntegral x1
        iy1 = fromIntegral y1
        ir = fromIntegral r
        ig = fromIntegral g
        ib = fromIntegral b
    [C.block| void {
        uint8_t* dst = $(uint8_t* dst);
        int w = $(int iw);
        int h = $(int ih);
        int x0 = $(int ix0);
        int y0 = $(int iy0);
        int x1 = $(int ix1);
        int y1 = $(int iy1);
        int r = $(int ir);
        int g = $(int ig);
        int b = $(int ib);
        int channel = 3;
        int sign_x =  x1 - x0 >= 0 ? 1 : -1;
        int sign_y =  y1 - y0 >= 0 ? 1 : -1;
        int abs_x =  x1 - x0 >= 0 ? x1 - x0 : x0 - x1;
        int abs_y =  y1 - y0 >= 0 ? y1 - y0 : y0 - y1;
        if(abs_x>=abs_y){
          for(int x=x0;x!=x1;x+=sign_x){
            int y = (x-x0) * (y1-y0) / (x1-x0) + y0;
            if(y >=0 && y < h &&
               x >=0 && x < w) {
              dst[(y*w+x)*channel+0] = r;
              dst[(y*w+x)*channel+1] = g;
              dst[(y*w+x)*channel+2] = b;
            }
          }
        } else {
          for(int y=y0;y!=y1;y+=sign_y){
            int x = (y-y0) * (x1-x0) / (y1-y0) + x0;
            if(y >=0 && y < h &&
               x >=0 && x < w) {
              dst[(y*w+x)*channel+0] = r;
              dst[(y*w+x)*channel+1] = g;
              dst[(y*w+x)*channel+2] = b;
            }
          }
        }
    } |]

-- | Draw the outline of the rectangle with corners @(x0, y0)@ and
-- @(x1, y1)@ in colour @(r, g, b)@, modifying @input@ in place.
--
-- 'drawLine' excludes its end point, so each edge is extended by one pixel
-- to make the corners at @x1@ / @y1@ inclusive. This assumes @x0 <= x1@ and
-- @y0 <= y1@. Edges are drawn in the order top, left, bottom, right.
drawRect :: Int -> Int -> Int -> Int -> (Int, Int, Int) -> I.Image I.PixelRGB8 -> IO ()
drawRect x0 y0 x1 y1 (r, g, b) input =
  mapM_
    (\(xa, ya, xb, yb) -> drawLine xa ya xb yb (r, g, b) input)
    [ (x0, y0, x1 + 1, y0),
      (x0, y0, x0, y1 + 1),
      (x0, y1, x1 + 1, y1),
      (x1, y0, x1, y1 + 1)
    ]

-- | Render @text@ on one line starting at @(x0, y0)@, modifying @input@ in
-- place.
--
-- Each character is drawn with 'drawChar' in an 8-pixel-wide cell. The
-- foreground colour is used for glyph pixels and the background colour for
-- the rest of the cell.
drawString :: String -> Int -> Int -> (Int, Int, Int) -> (Int, Int, Int) -> I.Image I.PixelRGB8 -> IO ()
drawString text x0 y0 foreground background input =
  forM_ (zip [0 ..] text) $ \(i, ch) ->
    drawChar (fromEnum ch) (x0 + i * 8) y0 foreground background input

-- | Render one 8x8 glyph with its top-left corner at @(x0, y0)@.
--
-- WARNING: this writes straight into @input@'s pixel buffer through a raw
-- pointer, so the image is modified in place.
--
-- * Glyphs come from a built-in bitmap font covering codes 0x20..0x7e. Bit
--   @dx@ (LSB first) of row @dy@ selects column @dx@.
-- * Set bits are painted @(r, g, b)@ and clear bits @(br, bg, bb)@.
-- * Codes outside 0x21..0x7e render as a solid background cell.
-- * Pixels outside the image are skipped.
drawChar :: Int -> Int -> Int -> (Int, Int, Int) -> (Int, Int, Int) -> I.Image I.PixelRGB8 -> IO ()
drawChar ascii_code x0 y0 (r, g, b) (br, bg, bb) input = do
  let (I.Image w h vec) = input
      (fptr, _) = V.unsafeToForeignPtr0 vec
  F.withForeignPtr fptr $ \ptr2 -> do
    let dst = F.castPtr ptr2
        iw = fromIntegral w
        ih = fromIntegral h
        ix0 = fromIntegral x0
        iy0 = fromIntegral y0
        ir = fromIntegral r
        ig = fromIntegral g
        ib = fromIntegral b
        ibr = fromIntegral br
        ibg = fromIntegral bg
        ibb = fromIntegral bb
        iascii_code = fromIntegral ascii_code
    [C.block| void {
        uint8_t* dst = $(uint8_t* dst);
        int w = $(int iw);
        int h = $(int ih);
        int x0 = $(int ix0);
        int y0 = $(int iy0);
        int r = $(int ir);
        int g = $(int ig);
        int b = $(int ib);
        int br = $(int ibr);
        int bg = $(int ibg);
        int bb = $(int ibb);
        int ascii_code = $(int iascii_code);
        int channel = 3;
        int char_width = 8;
        int char_height = 8;
        /* static const: initialised once rather than on every call, and
           unsigned so entries such as 0xFF do not overflow a signed char. */
        static const unsigned char fonts[95][8] = { // 0x20 to 0x7e
            { 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
            { 0x18, 0x3C, 0x3C, 0x18, 0x18, 0x00, 0x18, 0x00},
            { 0x36, 0x36, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
            { 0x36, 0x36, 0x7F, 0x36, 0x7F, 0x36, 0x36, 0x00},
            { 0x0C, 0x3E, 0x03, 0x1E, 0x30, 0x1F, 0x0C, 0x00},
            { 0x00, 0x63, 0x33, 0x18, 0x0C, 0x66, 0x63, 0x00},
            { 0x1C, 0x36, 0x1C, 0x6E, 0x3B, 0x33, 0x6E, 0x00},
            { 0x06, 0x06, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00},
            { 0x18, 0x0C, 0x06, 0x06, 0x06, 0x0C, 0x18, 0x00},
            { 0x06, 0x0C, 0x18, 0x18, 0x18, 0x0C, 0x06, 0x00},
            { 0x00, 0x66, 0x3C, 0xFF, 0x3C, 0x66, 0x00, 0x00},
            { 0x00, 0x0C, 0x0C, 0x3F, 0x0C, 0x0C, 0x00, 0x00},
            { 0x00, 0x00, 0x00, 0x00, 0x00, 0x0C, 0x0C, 0x06},
            { 0x00, 0x00, 0x00, 0x3F, 0x00, 0x00, 0x00, 0x00},
            { 0x00, 0x00, 0x00, 0x00, 0x00, 0x0C, 0x0C, 0x00},
            { 0x60, 0x30, 0x18, 0x0C, 0x06, 0x03, 0x01, 0x00},
            { 0x3E, 0x63, 0x73, 0x7B, 0x6F, 0x67, 0x3E, 0x00},
            { 0x0C, 0x0E, 0x0C, 0x0C, 0x0C, 0x0C, 0x3F, 0x00},
            { 0x1E, 0x33, 0x30, 0x1C, 0x06, 0x33, 0x3F, 0x00},
            { 0x1E, 0x33, 0x30, 0x1C, 0x30, 0x33, 0x1E, 0x00},
            { 0x38, 0x3C, 0x36, 0x33, 0x7F, 0x30, 0x78, 0x00},
            { 0x3F, 0x03, 0x1F, 0x30, 0x30, 0x33, 0x1E, 0x00},
            { 0x1C, 0x06, 0x03, 0x1F, 0x33, 0x33, 0x1E, 0x00},
            { 0x3F, 0x33, 0x30, 0x18, 0x0C, 0x0C, 0x0C, 0x00},
            { 0x1E, 0x33, 0x33, 0x1E, 0x33, 0x33, 0x1E, 0x00},
            { 0x1E, 0x33, 0x33, 0x3E, 0x30, 0x18, 0x0E, 0x00},
            { 0x00, 0x0C, 0x0C, 0x00, 0x00, 0x0C, 0x0C, 0x00},
            { 0x00, 0x0C, 0x0C, 0x00, 0x00, 0x0C, 0x0C, 0x06},
            { 0x18, 0x0C, 0x06, 0x03, 0x06, 0x0C, 0x18, 0x00},
            { 0x00, 0x00, 0x3F, 0x00, 0x00, 0x3F, 0x00, 0x00},
            { 0x06, 0x0C, 0x18, 0x30, 0x18, 0x0C, 0x06, 0x00},
            { 0x1E, 0x33, 0x30, 0x18, 0x0C, 0x00, 0x0C, 0x00},
            { 0x3E, 0x63, 0x7B, 0x7B, 0x7B, 0x03, 0x1E, 0x00},
            { 0x0C, 0x1E, 0x33, 0x33, 0x3F, 0x33, 0x33, 0x00},
            { 0x3F, 0x66, 0x66, 0x3E, 0x66, 0x66, 0x3F, 0x00},
            { 0x3C, 0x66, 0x03, 0x03, 0x03, 0x66, 0x3C, 0x00},
            { 0x1F, 0x36, 0x66, 0x66, 0x66, 0x36, 0x1F, 0x00},
            { 0x7F, 0x46, 0x16, 0x1E, 0x16, 0x46, 0x7F, 0x00},
            { 0x7F, 0x46, 0x16, 0x1E, 0x16, 0x06, 0x0F, 0x00},
            { 0x3C, 0x66, 0x03, 0x03, 0x73, 0x66, 0x7C, 0x00},
            { 0x33, 0x33, 0x33, 0x3F, 0x33, 0x33, 0x33, 0x00},
            { 0x1E, 0x0C, 0x0C, 0x0C, 0x0C, 0x0C, 0x1E, 0x00},
            { 0x78, 0x30, 0x30, 0x30, 0x33, 0x33, 0x1E, 0x00},
            { 0x67, 0x66, 0x36, 0x1E, 0x36, 0x66, 0x67, 0x00},
            { 0x0F, 0x06, 0x06, 0x06, 0x46, 0x66, 0x7F, 0x00},
            { 0x63, 0x77, 0x7F, 0x7F, 0x6B, 0x63, 0x63, 0x00},
            { 0x63, 0x67, 0x6F, 0x7B, 0x73, 0x63, 0x63, 0x00},
            { 0x1C, 0x36, 0x63, 0x63, 0x63, 0x36, 0x1C, 0x00},
            { 0x3F, 0x66, 0x66, 0x3E, 0x06, 0x06, 0x0F, 0x00},
            { 0x1E, 0x33, 0x33, 0x33, 0x3B, 0x1E, 0x38, 0x00},
            { 0x3F, 0x66, 0x66, 0x3E, 0x36, 0x66, 0x67, 0x00},
            { 0x1E, 0x33, 0x07, 0x0E, 0x38, 0x33, 0x1E, 0x00},
            { 0x3F, 0x2D, 0x0C, 0x0C, 0x0C, 0x0C, 0x1E, 0x00},
            { 0x33, 0x33, 0x33, 0x33, 0x33, 0x33, 0x3F, 0x00},
            { 0x33, 0x33, 0x33, 0x33, 0x33, 0x1E, 0x0C, 0x00},
            { 0x63, 0x63, 0x63, 0x6B, 0x7F, 0x77, 0x63, 0x00},
            { 0x63, 0x63, 0x36, 0x1C, 0x1C, 0x36, 0x63, 0x00},
            { 0x33, 0x33, 0x33, 0x1E, 0x0C, 0x0C, 0x1E, 0x00},
            { 0x7F, 0x63, 0x31, 0x18, 0x4C, 0x66, 0x7F, 0x00},
            { 0x1E, 0x06, 0x06, 0x06, 0x06, 0x06, 0x1E, 0x00},
            { 0x03, 0x06, 0x0C, 0x18, 0x30, 0x60, 0x40, 0x00},
            { 0x1E, 0x18, 0x18, 0x18, 0x18, 0x18, 0x1E, 0x00},
            { 0x08, 0x1C, 0x36, 0x63, 0x00, 0x00, 0x00, 0x00},
            { 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xFF},
            { 0x0C, 0x0C, 0x18, 0x00, 0x00, 0x00, 0x00, 0x00},
            { 0x00, 0x00, 0x1E, 0x30, 0x3E, 0x33, 0x6E, 0x00},
            { 0x07, 0x06, 0x06, 0x3E, 0x66, 0x66, 0x3B, 0x00},
            { 0x00, 0x00, 0x1E, 0x33, 0x03, 0x33, 0x1E, 0x00},
            { 0x38, 0x30, 0x30, 0x3e, 0x33, 0x33, 0x6E, 0x00},
            { 0x00, 0x00, 0x1E, 0x33, 0x3f, 0x03, 0x1E, 0x00},
            { 0x1C, 0x36, 0x06, 0x0f, 0x06, 0x06, 0x0F, 0x00},
            { 0x00, 0x00, 0x6E, 0x33, 0x33, 0x3E, 0x30, 0x1F},
            { 0x07, 0x06, 0x36, 0x6E, 0x66, 0x66, 0x67, 0x00},
            { 0x0C, 0x00, 0x0E, 0x0C, 0x0C, 0x0C, 0x1E, 0x00},
            { 0x30, 0x00, 0x30, 0x30, 0x30, 0x33, 0x33, 0x1E},
            { 0x07, 0x06, 0x66, 0x36, 0x1E, 0x36, 0x67, 0x00},
            { 0x0E, 0x0C, 0x0C, 0x0C, 0x0C, 0x0C, 0x1E, 0x00},
            { 0x00, 0x00, 0x33, 0x7F, 0x7F, 0x6B, 0x63, 0x00},
            { 0x00, 0x00, 0x1F, 0x33, 0x33, 0x33, 0x33, 0x00},
            { 0x00, 0x00, 0x1E, 0x33, 0x33, 0x33, 0x1E, 0x00},
            { 0x00, 0x00, 0x3B, 0x66, 0x66, 0x3E, 0x06, 0x0F},
            { 0x00, 0x00, 0x6E, 0x33, 0x33, 0x3E, 0x30, 0x78},
            { 0x00, 0x00, 0x3B, 0x6E, 0x66, 0x06, 0x0F, 0x00},
            { 0x00, 0x00, 0x3E, 0x03, 0x1E, 0x30, 0x1F, 0x00},
            { 0x08, 0x0C, 0x3E, 0x0C, 0x0C, 0x2C, 0x18, 0x00},
            { 0x00, 0x00, 0x33, 0x33, 0x33, 0x33, 0x6E, 0x00},
            { 0x00, 0x00, 0x33, 0x33, 0x33, 0x1E, 0x0C, 0x00},
            { 0x00, 0x00, 0x63, 0x6B, 0x7F, 0x7F, 0x36, 0x00},
            { 0x00, 0x00, 0x63, 0x36, 0x1C, 0x36, 0x63, 0x00},
            { 0x00, 0x00, 0x33, 0x33, 0x33, 0x3E, 0x30, 0x1F},
            { 0x00, 0x00, 0x3F, 0x19, 0x0C, 0x26, 0x3F, 0x00},
            { 0x38, 0x0C, 0x0C, 0x07, 0x0C, 0x0C, 0x38, 0x00},
            { 0x18, 0x18, 0x18, 0x00, 0x18, 0x18, 0x18, 0x00},
            { 0x07, 0x0C, 0x0C, 0x38, 0x0C, 0x0C, 0x07, 0x00},
            { 0x6E, 0x3B, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}
          };
        for(int y=y0;y<y0+char_height;y++){
          for(int x=x0;x<x0+char_width;x++){
            if(y >=0 && y < h &&
               x >=0 && x < w) {
              int dx = x-x0;
              int dy = y-y0;
              int bit =
                ascii_code > 0x20 && ascii_code < 0x7f ?
                fonts[ascii_code-0x20][dy] & (0x1 << dx) :
                0;
              if (bit) {
                dst[(y*w+x)*channel+0] = r;
                dst[(y*w+x)*channel+1] = g;
                dst[(y*w+x)*channel+2] = b;
              } else {
                dst[(y*w+x)*channel+0] = br;
                dst[(y*w+x)*channel+1] = bg;
                dst[(y*w+x)*channel+2] = bb;
              }
            }
          }
        }
    } |]

-- | Resize an RGB8 image to @width@ x @height@ by nearest-neighbour sampling.
--
-- * With @keepAspectRatio = False@ the source is stretched to fill the
--   target.
-- * With @keepAspectRatio = True@ the source is scaled uniformly to fit and
--   centred. The unused border (top/bottom or left/right) stays black.
-- * An empty (0-width or 0-height) source yields an all-black image. The C
--   kernel would otherwise divide by zero or read out of bounds.
resizeRGB8 :: Int -> Int -> Bool -> I.Image I.PixelRGB8 -> I.Image I.PixelRGB8
resizeRGB8 width height keepAspectRatio input = unsafePerformIO $ do
  -- The destination is allocated inside IO so every call gets its own
  -- buffer. A pure 'I.generateImage' value could be shared between calls by
  -- GHC (CSE / let-floating), and it is overwritten in place below.
  img@(I.Image w h vec) <- I.withImage width height (\_ _ -> pure (I.PixelRGB8 0 0 0))
  let channel = 3 :: Int
      (I.Image orgW orgH orgVec) = input
      (orgFptr, _) = V.unsafeToForeignPtr0 orgVec
      (fptr, _) = V.unsafeToForeignPtr0 vec
  when (orgW > 0 && orgH > 0) $
    F.withForeignPtr orgFptr $ \ptr1 -> F.withForeignPtr fptr $ \ptr2 -> do
      let src = F.castPtr ptr1
          dst = F.castPtr ptr2
          iw = fromIntegral w
          ih = fromIntegral h
          iorg_w = fromIntegral orgW
          iorg_h = fromIntegral orgH
          ichannel = fromIntegral channel
          ckeepAspectRatio = if keepAspectRatio then 1 else 0
      [C.block| void {
          uint8_t* src = $(uint8_t* src);
          uint8_t* dst = $(uint8_t* dst);
          int w = $(int iw);
          int h = $(int ih);
          int channel = $(int ichannel);
          int ow = $(int iorg_w);
          int oh = $(int iorg_h);
          int keepAspectRatio = $(int ckeepAspectRatio);
          if(keepAspectRatio){
            /* Width the source would have if scaled to the target height. */
            int t0w = ow * h / oh;
            if (t0w > w) {
              /* Source is relatively wider: fit to width, pad top/bottom. */
              int offset = (h - (oh * w / ow))/2;
              for(int y=offset;y<h-offset;y++){
                for(int x=0;x<w;x++){
                  for(int c=0;c<channel;c++){
                    int sy = (y-offset) * ow / w;
                    int sx = x * ow / w;
                    if(sy >= 0 && sy < oh){
                      dst[(y*w+x)*channel+c] = src[(sy*ow+sx)*channel+c];
                    }
                  }
                }
              }
            } else {
              /* Source is relatively taller: fit to height, pad left/right. */
              int offset = (w - (ow * h / oh))/2;
              for(int y=0;y<h;y++){
                for(int x=offset;x<w-offset;x++){
                  for(int c=0;c<channel;c++){
                    int sy = y * oh / h;
                    int sx = (x-offset) * oh / h;
                    if(sx >= 0 && sx < ow){
                      dst[(y*w+x)*channel+c] = src[(sy*ow+sx)*channel+c];
                    }
                  }
                }
              }
            }
          } else {
            for(int y=0;y<h;y++){
              for(int x=0;x<w;x++){
                for(int c=0;c<channel;c++){
                  int sy = y * oh / h;
                  int sx = x * ow / w;
                  dst[(y*w+x)*channel+c] = src[(sy*ow+sx)*channel+c];
                }
              }
            }
          }
      } |]
  return img

-- | Name the pixel layout of a decoded JuicyPixels image.
--
-- Each 'I.DynamicImage' alternative maps to the 'PixelFormat' constructor of
-- the same name (e.g. 'I.ImageRGB8' to 'RGB8'); the pixel data is not inspected.
pixelFormat :: I.DynamicImage -> PixelFormat
pixelFormat = \case
  I.ImageY8 {} -> Y8
  I.ImageYF {} -> YF
  I.ImageYA8 {} -> YA8
  I.ImageRGB8 {} -> RGB8
  I.ImageRGBF {} -> RGBF
  I.ImageRGBA8 {} -> RGBA8
  I.ImageYCbCr8 {} -> YCbCr8
  I.ImageCMYK8 {} -> CMYK8
  I.ImageCMYK16 {} -> CMYK16
  I.ImageRGBA16 {} -> RGBA16
  I.ImageRGB16 {} -> RGB16
  I.ImageY16 {} -> Y16
  I.ImageYA16 {} -> YA16
  I.ImageY32 {} -> Y32

-- | Convert a decoded JuicyPixels image into a tensor of shape
-- @[1, height, width, channel]@ (a batch of one, channels last).
--
-- * 8-bit images become 'D.UInt8' tensors and floating-point images become
--   'D.Float' tensors; their pixel buffers are copied byte-for-byte.
-- * 16-bit images are widened element-wise to 'D.Int32', and 32-bit grey
--   images to 'D.Int64', so that unsigned values stay non-negative.
-- * No colour-space conversion is done: YCbCr and CMYK images yield their raw
--   YCbCr / CMYK components.
--
-- The tensor is allocated and filled in 'IO' and the result is exposed purely
-- through 'unsafePerformIO'.
--
-- If the image's pixel vector does not hold exactly
-- @width * height * channel@ components, an 'IOError' is thrown (when the
-- result is forced) instead of reading past the end of the buffer.
fromDynImage :: I.DynamicImage -> D.Tensor
fromDynImage image = unsafePerformIO $ case image of
  I.ImageY8 (I.Image width height vec) -> createTensor width height 1 D.UInt8 1 vec
  I.ImageYF (I.Image width height vec) -> createTensor width height 1 D.Float 4 vec
  I.ImageYA8 (I.Image width height vec) -> createTensor width height 2 D.UInt8 1 vec
  I.ImageRGB8 (I.Image width height vec) -> createTensor width height 3 D.UInt8 1 vec
  I.ImageRGBF (I.Image width height vec) -> createTensor width height 3 D.Float 4 vec
  I.ImageRGBA8 (I.Image width height vec) -> createTensor width height 4 D.UInt8 1 vec
  I.ImageYCbCr8 (I.Image width height vec) -> createTensor width height 3 D.UInt8 1 vec
  I.ImageCMYK8 (I.Image width height vec) -> createTensor width height 4 D.UInt8 1 vec
  I.ImageCMYK16 (I.Image width height vec) -> createTensorU16to32 width height 4 D.Int32 vec
  I.ImageRGBA16 (I.Image width height vec) -> createTensorU16to32 width height 4 D.Int32 vec
  I.ImageRGB16 (I.Image width height vec) -> createTensorU16to32 width height 3 D.Int32 vec
  I.ImageY16 (I.Image width height vec) -> createTensorU16to32 width height 1 D.Int32 vec
  I.ImageYA16 (I.Image width height vec) -> createTensorU16to32 width height 2 D.Int32 vec
  I.ImageY32 (I.Image width height vec) -> createTensorU32to64 width height 1 D.Int64 vec
  where
    -- Allocate an uninitialised [1, height, width, channel] tensor of `dtype`.
    emptyNHWC width height channel dtype =
      ((cast2 LibTorch.empty_lo) :: [Int] -> D.TensorOptions -> IO D.Tensor)
        [1, height, width, channel]
        (D.withDType dtype D.defaultOpts)

    -- Refuse to copy from a pixel vector whose component count does not match
    -- the image dimensions: every copy below reads width*height*channel
    -- components from it.
    checkComponentCount width height channel len =
      when (len /= width * height * channel) $
        throwIO $
          userError $
            "fromDynImage: pixel vector has "
              ++ show len
              ++ " components, expected width * height * channel = "
              ++ show (width * height * channel)

    -- Raw byte copy for component types whose layout already matches `dtype`;
    -- `dtype_size` is the size in bytes of one component.
    createTensor width height channel dtype dtype_size vec = do
      let (fptr, len) = V.unsafeToForeignPtr0 vec
          whc = width * height * channel * dtype_size
      checkComponentCount width height channel len
      t <- emptyNHWC width height channel dtype
      D.withTensor t $ \ptr1 ->
        F.withForeignPtr fptr $ \ptr2 -> do
          BSI.memcpy (F.castPtr ptr1) (F.castPtr ptr2) whc
          return t

    -- Widen each uint16_t component to int32_t while copying.
    createTensorU16to32 width height channel dtype vec = do
      let (fptr, len) = V.unsafeToForeignPtr0 vec
      checkComponentCount width height channel len
      t <- emptyNHWC width height channel dtype
      D.withTensor t $ \ptr1 ->
        F.withForeignPtr fptr $ \ptr2 -> do
          let src = F.castPtr ptr2
              dst = F.castPtr ptr1
              whc = fromIntegral $ width * height * channel
          [C.block| void {
              uint16_t* src = $(uint16_t* src);
              int32_t* dst = $(int32_t* dst);
              for(int i=0;i<$(int whc);i++){
                 dst[i] = src[i];
              }
          } |]
          return t

    -- Widen each uint32_t component to int64_t while copying.
    createTensorU32to64 width height channel dtype vec = do
      let (fptr, len) = V.unsafeToForeignPtr0 vec
      checkComponentCount width height channel len
      t <- emptyNHWC width height channel dtype
      D.withTensor t $ \ptr1 ->
        F.withForeignPtr fptr $ \ptr2 -> do
          let src = F.castPtr ptr2
              dst = F.castPtr ptr1
              whc = fromIntegral $ width * height * channel
          [C.block| void {
              uint32_t* src = $(uint32_t* src);
              int64_t* dst = $(int64_t* dst);
              for(int i=0;i<$(int whc);i++){
                 dst[i] = src[i];
              }
          } |]
          return t

-- | Stack RGB8 images into a single 'D.UInt8' tensor of shape
-- @[num_images, height, width, 3]@ (channels last).
--
-- Every image must have the same width and height as the first one.
-- Throws an 'IOError' if the list is empty, if an image's width or height
-- differs from the first image's, or if an image's pixel vector does not hold
-- exactly @width * height * 3@ components.
fromImages :: [I.Image I.PixelRGB8] -> IO D.Tensor
fromImages [] =
  throwIO $ userError "The number of images should be greater than 0."
fromImages imgs@(I.Image width height _ : _) = do
  let num_imgs = length imgs
      channel = 3
      -- size of one image slot in the tensor, in bytes (one byte per component)
      whc = width * height * channel
  t <-
    ((cast2 LibTorch.empty_lo) :: [Int] -> D.TensorOptions -> IO D.Tensor)
      [num_imgs, height, width, channel]
      (D.withDType D.UInt8 D.defaultOpts)
  D.withTensor t $ \ptr1 ->
    forM_ (zip [0 ..] imgs) $ \(idx, I.Image width' height' vec) -> do
      -- Check the dimensions first, so a size mismatch is reported as such
      -- rather than as the generic length mismatch below.
      when (width /= width') $
        throwIO $ userError "image's width is not the same as first image's one"
      when (height /= height') $
        throwIO $ userError "image's height is not the same as first image's one"
      let (fptr, len) = V.unsafeToForeignPtr0 vec
      -- Guards the memcpy, which writes len bytes into a slot of whc bytes.
      when (len /= whc) $
        throwIO $ userError "vector's length is not the same as tensor' one."
      F.withForeignPtr fptr $ \ptr2 ->
        BSI.memcpy (F.plusPtr (F.castPtr ptr1) (whc * idx)) ptr2 len
  return t

-- | Copy a tensor's raw data into a freshly generated JuicyPixels image of
-- @width@ x @height@ pixels with @channel@ components per pixel.
--
-- @pixel@ fixes the pixel type; its value only pre-fills the image buffer,
-- which is then overwritten. The copy is byte-for-byte, so the tensor's dtype
-- must match the pixel's component type (e.g. 'D.UInt8' for 'I.Pixel8' and
-- 'I.PixelRGB8'). The tensor must also hold at least as many bytes as the
-- image buffer. NOTE(review): neither requirement is checked here.
--
-- Throws an 'IOError' if @width * height * channel@ differs from the number
-- of components in the image buffer, i.e. when @channel@ disagrees with the
-- pixel type.
writeImage :: forall p. I.Pixel p => Int -> Int -> Int -> p -> D.Tensor -> IO (I.Image p)
writeImage width height channel pixel tensor = do
  let img@(I.Image _ _ vec) = I.generateImage (\_ _ -> pixel) width height :: I.Image p
      (fptr, len) = V.unsafeToForeignPtr0 vec
      whc = width * height * channel
      -- `len` counts components, not bytes. Components wider than one byte
      -- (16-bit, float, ...) need len * sizeOf component bytes copied, which
      -- is the length of the same buffer viewed as bytes.
      nbytes = V.length (V.unsafeCast vec :: V.Vector I.Pixel8)
  when (len /= whc) $
    throwIO $ userError $ "vector's length(" ++ show len ++ ") is not the same as tensor' one."
  -- This writes in place into the buffer of `img`, which was generated just
  -- above and has not been handed out yet.
  D.withTensor tensor $ \ptr1 ->
    F.withForeignPtr fptr $ \ptr2 -> do
      BSI.memcpy (F.castPtr ptr2) (F.castPtr ptr1) nbytes
      return img

-- | Encode a 'D.UInt8' tensor as a BMP file at @file@.
--
-- Shapes are tried in this order: @[1,h,w,1]@, @[1,h,w]@, @[1,h,w,3]@,
-- @[h,w,1]@, @[h,w]@, @[h,w,3]@. Single-channel and channel-less shapes are
-- written as 8-bit greyscale; three-channel shapes as RGB8. Any other
-- shape/dtype combination raises an 'IOError'.
writeBitmap :: FilePath -> D.Tensor -> IO ()
writeBitmap file tensor =
  case (D.shape tensor, D.dtype tensor) of
    ([1, height, width, 1], D.UInt8) -> asGrey width height
    ([1, height, width], D.UInt8) -> asGrey width height
    ([1, height, width, 3], D.UInt8) -> asRGB width height
    ([height, width, 1], D.UInt8) -> asGrey width height
    ([height, width], D.UInt8) -> asGrey width height
    ([height, width, 3], D.UInt8) -> asRGB width height
    format ->
      throwIO . userError $
        "Can not write a image. " ++ show format ++ " is unsupported format."
  where
    -- one 8-bit component per pixel
    asGrey w h = writeImage w h 1 (0 :: I.Pixel8) tensor >>= I.writeBitmap file
    -- three 8-bit components per pixel
    asRGB w h = writeImage w h 3 (I.PixelRGB8 0 0 0) tensor >>= I.writeBitmap file

-- | Encode a 'D.UInt8' tensor as a PNG file at @file@.
--
-- Shapes are tried in this order: @[1,h,w,1]@, @[1,h,w]@, @[1,h,w,3]@,
-- @[h,w,1]@, @[h,w]@, @[h,w,3]@. Single-channel and channel-less shapes are
-- written as 8-bit greyscale; three-channel shapes as RGB8. Any other
-- shape/dtype combination raises an 'IOError'.
writePng :: FilePath -> D.Tensor -> IO ()
writePng file tensor =
  case (D.shape tensor, D.dtype tensor) of
    ([1, height, width, 1], D.UInt8) -> savePng8 width height
    ([1, height, width], D.UInt8) -> savePng8 width height
    ([1, height, width, 3], D.UInt8) -> savePngRGB width height
    ([height, width, 1], D.UInt8) -> savePng8 width height
    ([height, width], D.UInt8) -> savePng8 width height
    ([height, width, 3], D.UInt8) -> savePngRGB width height
    unsupported ->
      throwIO . userError $
        "Can not write a image. " ++ show unsupported ++ " is unsupported format."
  where
    -- greyscale: one 8-bit component per pixel
    savePng8 w h = writeImage w h 1 (0 :: I.Pixel8) tensor >>= I.writePng file
    -- colour: three 8-bit components per pixel
    savePngRGB w h = writeImage w h 3 (I.PixelRGB8 0 0 0) tensor >>= I.writePng file

-- | Move the channel axis in front of the spatial axes:
-- @[batch, height, width, channel]@ becomes @[batch, channel, height, width]@.
hwc2chw :: D.Tensor -> D.Tensor
hwc2chw = D.permute toChannelsFirst
  where
    -- result axis k is taken from source axis (toChannelsFirst at position k)
    toChannelsFirst = [0, 3, 1, 2]

-- | Move the channel axis behind the spatial axes:
-- @[batch, channel, height, width]@ becomes @[batch, height, width, channel]@.
chw2hwc :: D.Tensor -> D.Tensor
chw2hwc = D.permute toChannelsLast
  where
    -- result axis k is taken from source axis (toChannelsLast at position k)
    toChannelsLast = [0, 2, 3, 1]

-- | An infinite stream of pseudo-random indexes, each reduced with
-- @`mod` size@ (so in @[0, size)@ for positive @size@).
--
-- The generator is seeded with the fixed value 123, so every call returns the
-- same sequence.
randomIndexes :: Int -> [Int]
randomIndexes size = map toIndex (randoms generator)
  where
    generator = mkStdGen 123
    toIndex n = n `mod` size