{-# LANGUAGE CPP #-}
{-# LANGUAGE EmptyDataDecls #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE ForeignFunctionInterface #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeSynonymInstances #-}

module Torch.Internal.GC where

import Control.Concurrent (threadDelay)
import Control.Concurrent.Async
import Control.Exception.Safe (Exception, MonadThrow, Typeable, catch, throwIO, throwM, try)
import Control.Monad (when)
import Data.List (isPrefixOf)
import Foreign.C.Types
import GHC.ExecutionStack
import Language.C.Inline.Cpp.Exceptions
import System.Environment (lookupEnv)
import System.IO (hPutStrLn, stderr)
import System.IO.Unsafe (unsafePerformIO)
import System.Mem (performGC)
import System.SysInfo

-- | FFI binding to @showWeakPtrList@, declared in @hasktorch_finalizer.h@.
-- 'dumpLibtorchObjects' calls it with its age threshold converted to 'CInt'.
-- NOTE(review): declared @unsafe@, so this assumes the C++ side returns
-- quickly and never calls back into Haskell — confirm.
foreign import ccall unsafe "hasktorch_finalizer.h showWeakPtrList"
  c_showWeakPtrList :: CInt -> IO ()

-- 'mallocTrim' wraps glibc's @malloc_trim(3)@, which asks the C allocator to
-- return free heap memory to the operating system. Its argument is the @pad@
-- (bytes to leave untrimmed at the top of the heap); the callers in this
-- module pass 0. The binding returns @IO ()@, so the C return value is
-- discarded.
--
-- malloc_trim is a glibc function and does not exist on macOS, so defining
-- ENABLE_DUMMY_MALLOC_TRIM substitutes a no-op with the same type.
#ifdef ENABLE_DUMMY_MALLOC_TRIM
-- | No-op stand-in for platforms without glibc's @malloc_trim@.
mallocTrim :: CInt -> IO ()
mallocTrim _ = return ()
#else
-- | Binding to glibc's @malloc_trim@ (return value ignored).
foreign import ccall unsafe "malloc.h malloc_trim"
  mallocTrim :: CInt -> IO ()
#endif

-- | Dumps the libtorch objects tracked by the C++ @showWeakPtrList@ helper.
--
-- Each time it is called, the age of the tracked objects increases by one.
-- Only objects whose age is greater than or equal to the argument are dumped.
-- NOTE(review): where the dump is written (stdout/stderr) is decided on the
-- C++ side and is not visible here.
dumpLibtorchObjects ::
  -- | minimum age of the objects to dump
  Int ->
  -- | the dump is produced as a side effect
  IO ()
dumpLibtorchObjects age = c_showWeakPtrList (toCInt age)
  where
    -- Clamp rather than wrap: a plain 'fromIntegral' would silently turn an
    -- 'Int' outside the 'CInt' range into an unrelated (possibly negative) age.
    toCInt :: Int -> CInt
    toCInt = fromIntegral . max lo . min hi
    lo, hi :: Int
    lo = fromIntegral (minBound :: CInt)
    hi = fromIntegral (maxBound :: CInt)

-- | Haskell-side exception carrying the message of a C++ @std::exception@
-- ('CppStdException'), as produced by 'unsafeThrowableIO'.
newtype HasktorchException = HasktorchException String
  deriving (Show)

instance Exception HasktorchException

-- | Run an 'IO' action via 'unsafePerformIO' and lift its result into any
-- 'MonadThrow' monad @m@.
--
-- If the action throws a C++ @std::exception@ ('CppStdException'), its
-- message is rethrown in @m@ as a 'HasktorchException' (via 'throwM').
-- Any other 'CppException' is rethrown in @m@ unchanged. Other exception
-- types are not caught.
--
-- Safety: the action runs inside 'unsafePerformIO' when the returned @m a@
-- is forced. It may therefore run at an unpredictable time, or not at all if
-- the result is never demanded. Only pass actions that are observably pure
-- (deterministic and free of visible side effects).
unsafeThrowableIO :: forall a m. MonadThrow m => IO a -> m a
unsafeThrowableIO a = unsafePerformIO $ (pure <$> a) `catch` toThrowable
  where
    toThrowable :: CppException -> IO (m a)
    toThrowable (CppStdException msg) = pure . throwM $ HasktorchException msg
    -- Non-std C++ exceptions keep their original type, so callers can still
    -- catch them as 'CppException' (instead of getting a PatternMatchFail).
    toThrowable other = pure (throwM other)

-- | Run an 'IO' action. If it throws a C++ @std::exception@
-- ('CppStdException'), report it on stderr and then rethrow it unchanged.
--
-- Reporting is on by default. It is skipped only when the environment
-- variable @HASKTORCH_DEBUG@ is set to exactly @"0"@. The report is the
-- Haskell stack trace from 'showStackTrace' (or a notice that none is
-- available), followed by the exception message.
--
-- Any other 'CppException' is rethrown unchanged without a report.
prettyException :: IO a -> IO a
prettyException func =
  func `catch` \(err :: CppException) -> do
    case err of
      CppStdException message -> report message
      -- Non-std C++ exceptions carry no message to print; just rethrow.
      _ -> pure ()
    throwIO err
  where
    report :: String -> IO ()
    report message = do
      flag <- lookupEnv "HASKTORCH_DEBUG"
      when (flag /= Just "0") $ do
        mst <- showStackTrace
        case mst of
          Just st -> hPutStrLn stderr st
          Nothing -> hPutStrLn stderr "Cannot show stacktrace"
        hPutStrLn stderr message
{-# INLINE prettyException #-}

-- | Run an 'IO' action, retrying it after a garbage collection whenever it
-- fails with a CUDA out-of-memory error. The first argument is the number of
-- retries still allowed.
--
-- A failure counts as out-of-memory when it is a 'CppStdException' whose
-- message starts with @"Exception: CUDA out of memory."@. On such a failure:
--
-- * if retries remain, the GHC heap is collected ('performGC'),
--   'mallocTrim' 0 is called, the thread sleeps for 1 ms, and the action is
--   run again with one fewer retry;
-- * if no retries remain, an 'IOError' ('userError') with the original
--   message is thrown.
--
-- Any other 'CppException' is rethrown unchanged.
--
-- The retry happens outside the exception handler (via 'try'). Inside a
-- 'catch' handler, the retried action would run with asynchronous exceptions
-- masked, and every retry would nest another handler.
retryWithGC' :: forall a. Int -> IO a -> IO a
retryWithGC' count func = do
  result <- try func
  case result of
    Right value -> pure value
    Left err -> recover err
  where
    recover :: CppException -> IO a
    recover (CppStdException message)
      | msgOutOfMemory `isPrefixOf` message =
          if count <= 0
            then throwIO $ userError $ "Too many calls to performGC, " ++ message
            else do
              performGC
              mallocTrim 0
              -- Short (1 ms) pause to give the GC time to take effect before
              -- retrying; presumably this lets finalizers free memory — confirm.
              threadDelay 1000
              retryWithGC' (count - 1) func
    recover err = throwIO err

    msgOutOfMemory :: String
    msgOutOfMemory = "Exception: CUDA out of memory."
{-# INLINE retryWithGC' #-}

-- | Run an 'IO' action with CUDA out-of-memory recovery and error reporting.
--
-- The action is retried up to 10 times, each time after a GC, when it fails
-- with a CUDA out-of-memory error (see 'retryWithGC''). Any C++
-- @std::exception@ that finally escapes is reported on stderr by
-- 'prettyException' before being rethrown.
retryWithGC :: IO a -> IO a
retryWithGC func = prettyException $ retryWithGC' maxRetries func
  where
    maxRetries :: Int
    maxRetries = 10
{-# INLINE retryWithGC #-}

-- | Poll OS memory forever, once every 500 ms.
--
-- On each iteration, if 'sysInfo' succeeds and @freeram / totalram@ is at
-- most 0.5, a major GC is run ('performGC') and then 'mallocTrim' 0 is
-- called. If 'sysInfo' fails, that iteration is skipped. This action never
-- returns normally; 'monitorMemory' runs it alongside a workload.
checkOSMemoryWithGC :: IO ()
checkOSMemoryWithGC = do
  info <- sysInfo
  case info of
    Left _ -> pure ()
    Right stat ->
      when (freeRatio stat <= 0.5) $ do
        performGC
        mallocTrim 0
  threadDelay (500 * 1000) -- wait 500msec
  checkOSMemoryWithGC
  where
    -- Both fields use the same unit, so their ratio is unit-free.
    freeRatio :: SysInfo -> Double
    freeRatio stat = fromIntegral (freeram stat) / fromIntegral (totalram stat)

-- | Run an action while 'checkOSMemoryWithGC' polls OS memory concurrently
-- and triggers a GC when free memory is low.
--
-- Because the monitor never finishes on its own, the race ends when @func@
-- completes or throws. The monitor is then cancelled, and any exception from
-- @func@ propagates to the caller.
monitorMemory :: IO () -> IO ()
monitorMemory func = race_ func checkOSMemoryWithGC