{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}

-- | Trace-and-cache helpers.
--
-- A @[Tensor] -> [Tensor]@ function is passed through 'trace' once. The
-- resulting 'ScriptModule' is stored in a 'ScriptCache', and every later
-- call runs that cached module instead of the original function.
module Torch.Jit where

import Control.Concurrent.STM (atomically)
import Control.Concurrent.STM.TVar
import System.IO.Unsafe (unsafePerformIO)
import Torch.NN
import Torch.Script
import Torch.Tensor

-- | Mutable slot holding the traced module.
--
-- It is 'Nothing' until the first 'jitIO' / 'jit' call fills it. The slot
-- does not record which function or inputs produced the module, so use one
-- cache per function.
newtype ScriptCache = ScriptCache {unScriptCache :: TVar (Maybe ScriptModule)}

-- | Allocate an empty 'ScriptCache'.
newScriptCache :: IO ScriptCache
newScriptCache = ScriptCache <$> newTVarIO Nothing

-- | Run @func@ on @input@ through a cached traced 'ScriptModule'.
--
-- * If the cache is empty, @func@ is traced with @input@ (module name
--   @\"MyModule\"@, method @\"forward\"@). The trace is converted with
--   'toScriptModule' and stored in the cache.
-- * If the cache is already full, @func@ is not used at all. The cached
--   module is run even if @func@ or the shape of @input@ differ from the
--   first call.
--
-- The input tensors are wrapped as 'IVTensor' and passed to 'forwardStoch'.
-- The tensors inside the resulting 'IVTensorList' are returned.
--
-- Throws an 'IOError' (via 'fail') if the module's output is not an
-- 'IVTensorList'.
--
-- NOTE(review): the cache read and the cache write are separate STM
-- transactions. Two threads making the first call concurrently may each
-- trace @func@, and the later write replaces the earlier one.
jitIO :: ScriptCache -> ([Tensor] -> IO [Tensor]) -> [Tensor] -> IO [Tensor]
jitIO (ScriptCache cache) func input = do
  script <-
    readTVarIO cache >>= \case
      Just cached -> pure cached
      Nothing -> do
        raw <- trace "MyModule" "forward" func input
        traced <- toScriptModule raw
        atomically $ writeTVar cache (Just traced)
        pure traced
  out <- forwardStoch script (map IVTensor input)
  -- Match explicitly rather than with a partial do-pattern. Failure stays an
  -- IOError, but the message now names the actual problem.
  case out of
    IVTensorList r0 -> pure r0
    _ ->
      fail
        "Torch.Jit.jitIO: expected the traced module's forward to return an IVTensorList"

-- | Pure wrapper around 'jitIO' for a pure @func@.
--
-- NOTE(review): this uses 'unsafePerformIO'. It is only referentially
-- transparent when all of the following hold:
--
-- * @func@ is pure;
-- * @cache@ is dedicated to this @func@;
-- * the traced module computes the same result as @func@ for every input
--   it is later given.
--
-- After the first call, the cached module is used and @func@ is ignored.
jit :: ScriptCache -> ([Tensor] -> [Tensor]) -> [Tensor] -> [Tensor]
jit cache func input = unsafePerformIO $ jitIO cache (pure . func) input