{-# LANGUAGE CPP #-}
{-# LANGUAGE EmptyDataDecls #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE ForeignFunctionInterface #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeSynonymInstances #-}
module Torch.Internal.GC where
import Control.Concurrent (threadDelay)
import Control.Concurrent.Async
import Control.Exception.Safe (Exception, MonadThrow, Typeable, catch, throwIO, throwM)
import Control.Monad (when)
import Data.List (isPrefixOf)
import Foreign.C.Types
import GHC.ExecutionStack
import Language.C.Inline.Cpp.Exceptions
import System.Environment (lookupEnv)
import System.IO (hPutStrLn, stderr)
import System.IO.Unsafe (unsafePerformIO)
import System.Mem (performGC)
import System.SysInfo
-- | Raw binding to the C function @showWeakPtrList@ declared in
-- hasktorch_finalizer.h. Prefer 'dumpLibtorchObjects', which takes an 'Int'.
--
-- Imported as an @unsafe@ foreign call, so the C side must not call back
-- into Haskell. NOTE(review): presumably prints the weak pointers tracked for
-- libtorch objects; confirm against the C source.
foreign import ccall unsafe "hasktorch_finalizer.h showWeakPtrList"
  c_showWeakPtrList :: CInt -> IO ()
#ifdef ENABLE_DUMMY_MALLOC_TRIM
-- | No-op stand-in for glibc @malloc_trim@, selected when the CPP flag
-- ENABLE_DUMMY_MALLOC_TRIM is defined (presumably for platforms whose libc
-- does not provide it).
mallocTrim :: CInt -> IO ()
mallocTrim _ = return ()
#else
-- | Ask glibc to return free heap memory to the operating system via
-- @malloc_trim@.
--
-- The argument is the amount of free space (in bytes) to leave untrimmed at
-- the top of the heap; negative values are clamped to 0 instead of wrapping
-- around to a huge unsigned size. The C return value (1 if memory was
-- released, 0 otherwise) is ignored.
mallocTrim :: CInt -> IO ()
mallocTrim pad = do
  _ <- c_mallocTrim (fromIntegral (max 0 pad))
  return ()

-- The C prototype is @int malloc_trim(size_t pad)@. Binding it with the exact
-- argument type avoids passing a 32-bit int where a size_t is expected, which
-- leaves the upper half of the register unspecified on 64-bit ABIs.
foreign import ccall unsafe "malloc.h malloc_trim"
  c_mallocTrim :: CSize -> IO CInt
#endif
-- | Ask the C finalizer layer to print its list of tracked objects.
--
-- @age@ is converted to 'CInt' and passed unchanged to @showWeakPtrList@.
-- NOTE(review): its exact meaning (presumably a minimum object age used as a
-- filter) is defined on the C side; confirm against hasktorch_finalizer.h.
dumpLibtorchObjects :: Int -> IO ()
dumpLibtorchObjects age = c_showWeakPtrList (fromIntegral age)
-- | Haskell-side exception carrying the message of a C++ standard exception
-- raised by libtorch. 'unsafeThrowableIO' produces it from a
-- 'CppStdException'.
newtype HasktorchException = HasktorchException String
  deriving (Show)

instance Exception HasktorchException
-- | Run an 'IO' action outside of 'IO' and re-raise libtorch C++ exceptions
-- in an arbitrary 'MonadThrow' context.
--
-- * A 'CppStdException' becomes a 'HasktorchException' with the same
--   message, thrown with 'throwM'.
-- * Any other 'CppException' is re-thrown unchanged with 'throwM'. The
--   previous partial lambda pattern turned these into a PatternMatchFail
--   that hid the original error.
-- * Exceptions that are not 'CppException's are not caught here.
--
-- The action runs through 'unsafePerformIO' when the result is demanded, so
-- the usual caveats apply: it runs at an unspecified time, and its effects
-- must be safe to perform that way.
unsafeThrowableIO :: forall a m. MonadThrow m => IO a -> m a
unsafeThrowableIO a =
  unsafePerformIO $
    (pure <$> a) `catch` \e ->
      pure $ case e of
        CppStdException msg -> throwM (HasktorchException msg)
        other -> throwM other
-- | Run an 'IO' action and, if it throws a 'CppStdException', report it on
-- stderr before re-throwing the very same exception.
--
-- Reporting is on by default. It is suppressed only when the environment
-- variable @HASKTORCH_DEBUG@ is set to exactly @"0"@. The report has two
-- parts:
--
-- * the Haskell execution stack from 'showStackTrace', or a placeholder
--   line when none is available;
-- * the C++ exception message.
--
-- Other 'CppException' constructors are re-thrown untouched, with no report.
prettyException :: IO a -> IO a
prettyException func =
  func `catch` \e -> case e of
    CppStdException message -> do
      flag <- lookupEnv "HASKTORCH_DEBUG"
      when (flag /= Just "0") $ do
        mst <- showStackTrace
        hPutStrLn stderr $ case mst of
          Just st -> st
          Nothing -> "Cannot show stacktrace"
        hPutStrLn stderr message
      throwIO e
    -- Re-throw non-std C++ exceptions as-is instead of failing the pattern
    -- match (which would replace them with a PatternMatchFail).
    _ -> throwIO e
{-# INLINE prettyException #-}
-- | Run an action, retrying it after a garbage collection whenever it fails
-- with a CUDA out-of-memory 'CppStdException'.
--
-- A failure counts as out-of-memory when its message starts with
-- @"Exception: CUDA out of memory."@. On each such failure this function:
--
-- 1. runs 'performGC' (presumably so that finalizers of unreachable tensors
--    release device memory);
-- 2. calls @mallocTrim 0@;
-- 3. waits 1 ms;
-- 4. runs the action again.
--
-- It does this at most @count@ more times. When the budget is exhausted it
-- throws a 'userError' containing the last message. Every other exception,
-- including non-std 'CppException's, is re-thrown unchanged.
retryWithGC' :: Int -> IO a -> IO a
retryWithGC' count func = go count
  where
    -- Local worker: GHC never inlines a self-recursive binding, so keeping the
    -- recursion here lets the INLINE pragma on the wrapper take effect.
    go remaining =
      func `catch` \e -> case e of
        CppStdException message
          | msgOutOfMemory `isPrefixOf` message ->
              if remaining <= 0
                then
                  throwIO $ userError $
                    "Too many calls to performGC, " ++ message
                else do
                  performGC
                  mallocTrim 0
                  threadDelay 1000
                  go (remaining - 1)
        _ -> throwIO e

    msgOutOfMemory :: String
    msgOutOfMemory = "Exception: CUDA out of memory."
{-# INLINE retryWithGC' #-}
-- | Run an action with CUDA out-of-memory recovery and error reporting.
--
-- Out-of-memory failures are retried up to 10 times by 'retryWithGC''. Any
-- 'CppStdException' that still escapes is reported on stderr by
-- 'prettyException' (unless @HASKTORCH_DEBUG@ is @"0"@) and then re-thrown.
retryWithGC :: IO a -> IO a
retryWithGC func = prettyException $ retryWithGC' 10 func
{-# INLINE retryWithGC #-}
-- | Poll system memory forever. Whenever the fraction of free RAM reported by
-- 'sysInfo' is at most 0.5, force a GC and call @mallocTrim 0@.
--
-- The check repeats every 500 ms and never returns normally, so this is
-- meant to run alongside real work (see 'monitorMemory'). If 'sysInfo'
-- reports an error, that round is skipped.
checkOSMemoryWithGC :: IO ()
checkOSMemoryWithGC = do
  info <- sysInfo
  case info of
    Left _ -> return ()
    Right stat ->
      when (freeRatio stat <= 0.5) $ do
        performGC
        mallocTrim 0
  threadDelay (500 * 1000)
  checkOSMemoryWithGC
  where
    -- Free RAM as a fraction of total RAM. Both fields share the same unit,
    -- so the unit cancels out.
    freeRatio :: SysInfo -> Double
    freeRatio stat =
      fromIntegral (freeram stat) / fromIntegral (totalram stat)
-- | Run @func@ while 'checkOSMemoryWithGC' watches system memory in a
-- concurrent thread.
--
-- The two actions are raced. Because the monitor never finishes normally,
-- this returns when @func@ completes, and the monitor thread is then
-- cancelled. An exception from either side is re-thrown here.
monitorMemory :: IO () -> IO ()
monitorMemory func = race_ func checkOSMemoryWithGC