{-# LANGUAGE ScopedTypeVariables #-}

module Torch.GraduallyTyped.LearningRateSchedules where

-- | Single-cycle learning rate schedule.
-- See, for instance, https://arxiv.org/abs/1803.09820.
--
-- This is a simple schedule that is a stepwise linear interpolation
-- between the initial, maximum, and final learning rates.
-- The initial learning rate is zero.
--
-- The phase is selected from the current epoch:
--
-- * @epoch <= 0@: the learning rate is @0@.
-- * @0 < epoch <= numWarmupEpochs@: linear ramp from @0@ up to the peak
--   learning rate, which is reached exactly at @epoch == numWarmupEpochs@.
-- * @numWarmupEpochs < epoch < numEpochs - numCooldownEpochs@: linear
--   interpolation from the peak learning rate down to the final learning rate.
-- * anything later (the cool-down epochs and beyond): the final learning rate.
--
-- Degenerate phase lengths (zero warm-up epochs, or no epochs left for the
-- decay) simply make the corresponding guard unsatisfiable, so neither
-- division below can have a zero denominator.
singleCycleLearningRateSchedule ::
  -- | peak learning rate after warmup
  Double ->
  -- | learning rate at the end of the schedule
  Double ->
  -- | total number of epochs
  Int ->
  -- | number of warm-up epochs
  Int ->
  -- | number of cool-down epochs
  Int ->
  -- | current epoch
  Int ->
  -- | current learning rate
  Double
singleCycleLearningRateSchedule
  maxLearningRate
  finalLearningRate
  numEpochs
  numWarmupEpochs
  numCooldownEpochs
  epoch
    | epoch <= 0 = 0.0
    -- From here on epoch >= 1, so this guard implies numWarmupEpochs >= 1
    -- and the denominator is non-zero.
    | epoch <= numWarmupEpochs =
      let a = fromIntegral epoch / fromIntegral numWarmupEpochs
       in a * maxLearningRate
    -- From here on epoch > numWarmupEpochs; with epoch < decayEnd the
    -- denominator (decayEnd - numWarmupEpochs) is at least 2.
    | epoch < decayEnd =
      let a =
            fromIntegral (decayEnd - epoch)
              / fromIntegral (decayEnd - numWarmupEpochs)
       in a * maxLearningRate + (1 - a) * finalLearningRate
    | otherwise = finalLearningRate
    where
      -- First epoch of the cool-down phase; the decay ends just before it.
      decayEnd = numEpochs - numCooldownEpochs