{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}

module Torch.Jit where

import Torch.Script
import Torch.Tensor
import Torch.NN
import Control.Concurrent.STM.TVar
import Control.Concurrent.STM (atomically)
import System.IO.Unsafe (unsafePerformIO)

-- | A shareable, mutable slot for the compiled 'ScriptModule' used by 'jitIO'
-- and 'jit'.
--
-- 'newScriptCache' creates the slot holding 'Nothing'; 'jitIO' fills it with
-- @Just@ the traced module on first use and reuses that module afterwards.
newtype ScriptCache = ScriptCache
  { unScriptCache :: TVar (Maybe ScriptModule)
  -- ^ 'Nothing' until a module has been traced and compiled.
  }

-- | Allocate an empty 'ScriptCache' (no module traced yet).
--
-- Each cache should be dedicated to a single function, because 'jitIO' stops
-- consulting its function argument once the cache has been filled.
newScriptCache :: IO ScriptCache
newScriptCache = ScriptCache <$> newTVarIO Nothing

-- | Run @func@ on @input@ through a TorchScript module, tracing and caching
-- the module on first use.
--
-- If the cache is empty, @func@ is traced with @input@ as the example
-- arguments (module name @\"MyModule\"@, method @\"forward\"@). The traced
-- module is then converted to a 'ScriptModule' and stored in the cache.
-- Every call, including the first, then runs the cached module on @input@
-- via 'forwardStoch'.
--
-- Once the cache is populated, @func@ is not consulted again, so a cache must
-- only ever be used with one function.
-- NOTE(review): a traced module presumably replays only the operations
-- recorded for the example input, so data-dependent control flow in @func@
-- is not re-evaluated on later inputs. Confirm before relying on it.
--
-- Throws an 'IOError' (via 'fail') if the module's result is not an
-- 'IVTensorList'.
jitIO :: ScriptCache -> ([Tensor] -> IO [Tensor]) -> [Tensor] -> IO [Tensor]
jitIO (ScriptCache cache) func input = do
  script <- readTVarIO cache >>= maybe compile pure
  result <- forwardStoch script (map IVTensor input)
  case result of
    IVTensorList outputs -> pure outputs
    _ ->
      fail "Torch.Jit.jitIO: scripted module did not return an IVTensorList"
  where
    -- Trace and compile func, then publish the module. The cache is
    -- re-checked inside the transaction, so if another thread filled it
    -- while we were tracing, its module is kept and reused rather than
    -- overwritten.
    compile = do
      raw <- trace "MyModule" "forward" func input
      fresh <- toScriptModule raw
      atomically $
        readTVar cache >>= \case
          Just existing -> pure existing
          Nothing -> fresh <$ writeTVar cache (Just fresh)

-- | Pure-looking wrapper around 'jitIO' for a pure tensor function.
--
-- Why 'unsafePerformIO' is tolerable here: the effects performed are those
-- of 'jitIO'. It reads the cache, traces, compiles and stores a module when
-- the cache is empty, and then runs the module. The result is only
-- referentially transparent under these conditions:
--
--   * @cache@ is used exclusively with this @func@. After the first call the
--     cached module is run and @func@ is ignored, so sharing a cache between
--     different functions silently returns results of the wrong function.
--   * Running the module is deterministic. NOTE(review): modules are invoked
--     through 'forwardStoch', so stochastic ops would make repeated calls
--     with the same input disagree. Confirm for your use.
--
-- Any 'IOError' raised by 'jitIO' surfaces when the result is forced.
jit :: ScriptCache -> ([Tensor] -> [Tensor]) -> [Tensor] -> [Tensor]
jit cache func input = unsafePerformIO $ jitIO cache (pure . func) input