{-# LANGUAGE ScopedTypeVariables #-}
module Torch.GraduallyTyped.LearningRateSchedules where
-- | Single-cycle learning rate schedule.
--
-- The learning rate is a piecewise-linear function of the epoch number:
--
-- * for @epoch <= 0@ it is @0@;
-- * during warm-up (@0 < epoch <= numWarmupEpochs@) it rises linearly from
--   @0@, reaching the maximum learning rate at @epoch == numWarmupEpochs@;
-- * afterwards, while @epoch < numEpochs - numCooldownEpochs@, it falls
--   linearly from the maximum learning rate towards the final learning rate;
-- * from @epoch == numEpochs - numCooldownEpochs@ onwards (the cool-down
--   phase, and any epoch beyond @numEpochs@) it stays at the final learning
--   rate.
--
-- >>> singleCycleLearningRateSchedule 1.0 0.0 10 2 2 1
-- 0.5
-- >>> singleCycleLearningRateSchedule 1.0 0.0 10 2 2 5
-- 0.5
singleCycleLearningRateSchedule ::
  -- | maximum learning rate, reached at the end of warm-up
  Double ->
  -- | final learning rate, held during cool-down
  Double ->
  -- | total number of epochs
  Int ->
  -- | number of warm-up epochs
  Int ->
  -- | number of cool-down epochs
  Int ->
  -- | current epoch
  Int ->
  -- | learning rate for the current epoch
  Double
singleCycleLearningRateSchedule maxLearningRate finalLearningRate numEpochs numWarmupEpochs numCooldownEpochs epoch
  | epoch <= 0 = 0.0
  | epoch <= numWarmupEpochs =
      -- Warm-up: here 0 < epoch <= numWarmupEpochs, so the denominator is
      -- strictly positive.
      ratio epoch numWarmupEpochs * maxLearningRate
  | epoch < cooldownStart =
      -- Decay: here numWarmupEpochs < epoch < cooldownStart, so
      -- cooldownStart - numWarmupEpochs > 0 and the weight lies in (0, 1).
      let a = ratio (cooldownStart - epoch) (cooldownStart - numWarmupEpochs)
       in a * maxLearningRate + (1 - a) * finalLearningRate
  | otherwise = finalLearningRate
  where
    -- First epoch of the cool-down phase.
    cooldownStart = numEpochs - numCooldownEpochs
    -- Ratio of two epoch counts as a 'Double' interpolation weight.
    ratio :: Int -> Int -> Double
    ratio n d = fromIntegral n / fromIntegral d