Getting started

All schedule types in ParameterSchedulers.jl behave as callable iterators. For example, we can call the simple exponential decay schedule (Exp) below at a specific iteration:

```julia
using ParameterSchedulers: Exp

s = Exp(start = 0.1, decay = 0.8)
println("s(1): ", s(1))
println("s(5): ", s(5))
```

```
s(1): 0.1
s(5): 0.04096000000000001
```
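Being a "callable iterator" simply means that the type defines a call method and `Base.iterate`. The sketch below shows one way to write such a type in plain Julia. `MyExp` and `my_s` are hypothetical names used only for illustration (this is not necessarily how ParameterSchedulers.jl implements `Exp`), and the decay rule `start * decay^(t - 1)` is an assumption that reproduces the values printed above.

```julia
# Hypothetical stand-in for an exponential decay schedule (not the library's Exp).
struct MyExp{T<:Real}
    start::T
    decay::T
end

# Callable: return the value at iteration t (iterations start at 1).
# Assumed rule: start * decay^(t - 1).
(s::MyExp)(t::Integer) = s.start * s.decay^(t - 1)

# Iterable: the iteration state is simply the next iteration index.
Base.iterate(s::MyExp, t::Int = 1) = (s(t), t + 1)

# The sequence never ends, so advertise it as an infinite iterator.
Base.IteratorSize(::Type{<:MyExp}) = Base.IsInfinite()
Base.eltype(::Type{MyExp{T}}) where {T} = T

my_s = MyExp(0.1, 0.8)
println("my_s(1): ", my_s(1))
println("my_s(5): ", my_s(5))
```

Because the iteration state is passed between `Base.iterate` calls rather than stored in `MyExp`, the struct itself can stay immutable.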
Info

The iterations are unitless. So, if you index a schedule once per epoch, then s(i) is the parameter value at epoch i; if you index it once per mini-batch, then s(i) is the value at training step i.
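If you want a schedule to advance once per mini-batch while your training loop is organized by epoch, you have to convert the (epoch, batch) pair into a single iteration count yourself. The helper below is a hypothetical sketch, not part of ParameterSchedulers.jl; `nbatches` (the number of mini-batches per epoch) is an assumed parameter.

```julia
# Hypothetical helper: map a 1-based (epoch, batch) pair to a 1-based global step,
# assuming every epoch has exactly `nbatches` mini-batches.
global_step(epoch::Integer, batch::Integer, nbatches::Integer) =
    (epoch - 1) * nbatches + batch

nbatches = 100
global_step(1, 1, nbatches)    # first step of training: returns 1
global_step(2, 1, nbatches)    # first step of epoch 2: returns 101
```

With a per-epoch schedule you would call `s(epoch)`; with a per-step schedule you would call `s(global_step(epoch, batch, nbatches))`.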

We can also use the schedule in an iterable context like a for-loop:

```julia
for (i, param) in enumerate(s)
    (i > 10) && break
    println("s($i): ", param)
end
```

```
s(1): 0.1
s(2): 0.08000000000000002
s(3): 0.06400000000000002
s(4): 0.051200000000000016
s(5): 0.04096000000000001
s(6): 0.03276800000000001
s(7): 0.026214400000000013
s(8): 0.020971520000000007
s(9): 0.016777216000000008
s(10): 0.013421772800000007
```
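The manual `break` can also be replaced by Base's `Iterators.take`, which stops after a fixed number of values. This is a sketch using only standard Julia iterator utilities on the `Exp` schedule shown above; `first10` is just an illustrative variable name.

```julia
using ParameterSchedulers: Exp

s = Exp(start = 0.1, decay = 0.8)

# Iterators.take yields at most 10 values, so the loop ends without a manual break.
for (i, param) in enumerate(Iterators.take(s, 10))
    println("s($i): ", param)
end

# The same first 10 values collected into a vector.
first10 = collect(Iterators.take(s, 10))
```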
Warning

Many schedules, such as Exp, are infinite iterators, so iterating over them without a break condition results in an infinite loop. You can use Base.IteratorSize to check whether a schedule has infinite length.
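A quick way to guard against this is to test the result of `Base.IteratorSize` before looping. The helper `is_infinite` below is a hypothetical convenience function, not part of ParameterSchedulers.jl; based on the warning above, we expect it to return `true` for `Exp`.

```julia
using ParameterSchedulers: Exp

# Hypothetical helper: true when an iterator declares itself infinite.
is_infinite(itr) = Base.IteratorSize(itr) isa Base.IsInfinite

s = Exp(start = 0.1, decay = 0.8)
is_infinite(s)       # expected true for Exp
is_infinite(1:10)    # false: ranges have a finite shape
```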

Notice that s(1) and s(5) have the same values in both cases, even though we first called the schedule directly and then iterated over it in the for-loop. This is because all schedules in ParameterSchedulers.jl are immutable. If you want a stateful (mutable) schedule, use ParameterSchedulers.Stateful:
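The immutability can be checked directly: consuming values through iteration leaves the schedule untouched, because the iteration state lives in the value passed between `Base.iterate` calls rather than in `s`. This sketch uses only Base functions (`Iterators.take`, `ismutable`); `before` and `after` are illustrative names, and the `ismutable` result is what we expect given that the schedules are documented as immutable.

```julia
using ParameterSchedulers: Exp

s = Exp(start = 0.1, decay = 0.8)
before = s(5)

# Consume 100 values; the iteration state is kept outside of `s`.
for _ in Iterators.take(s, 100)
end

after = s(5)
before == after    # true
ismutable(s)       # expected false: the schedule is an immutable struct
```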

```julia
using ParameterSchedulers: Stateful, next!

stateful_s = Stateful(s)
println("s: ", next!(stateful_s))
println("s: ", next!(stateful_s))
println(stateful_s)
```

```
s: 0.1
s: 0.08000000000000002
ParameterSchedulers.Stateful{Exp{Float64}, Int64, ParameterSchedulers.var"#15#17"}(Exp{Float64}(0.1, 0.8), 3, ParameterSchedulers.var"#15#17"())
```

We used ParameterSchedulers.next! to advance the stateful iterator. Notice that stateful_s stores the schedule s together with the current iteration state, which is 3 because the state starts at 1 and we advanced the iterator twice. We can also reset the mutable iteration state:

```julia
using ParameterSchedulers: reset!

reset!(stateful_s)
println("s: ", next!(stateful_s))
```

```
s: 0.1
```
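To see how an immutable schedule and a mutable iteration state can be kept apart, here is a minimal sketch of a stateful wrapper. `MyStateful`, `my_next!`, `my_reset!` and `exp_decay` are hypothetical names for illustration only; this is not the actual implementation of `ParameterSchedulers.Stateful`, and `exp_decay` is a plain function standing in for a schedule.

```julia
# Hypothetical stateful wrapper (not ParameterSchedulers.Stateful itself).
mutable struct MyStateful{S}
    schedule::S   # the wrapped (immutable) schedule; any callable works
    state::Int    # the next iteration to evaluate
end
MyStateful(schedule) = MyStateful(schedule, 1)

# Return the value at the current iteration, then advance the state.
function my_next!(x::MyStateful)
    val = x.schedule(x.state)
    x.state += 1
    return val
end

# Rewind to the first iteration and return the wrapper.
my_reset!(x::MyStateful) = (x.state = 1; x)

exp_decay(t) = 0.1 * 0.8^(t - 1)   # plain function standing in for a schedule
st = MyStateful(exp_decay)
my_next!(st); my_next!(st)
st.state        # 3, like the state shown by Stateful after two next! calls
my_reset!(st)
my_next!(st)    # returns 0.1
```

Only the wrapper is mutable; the wrapped schedule is never modified, so it can still be called at any iteration.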

Also note that a Stateful schedule cannot be called (or iterated with Base.iterate); calling it throws a MethodError:

```julia
try
    stateful_s(1)
catch e
    println(e)
end
```

```
MethodError(ParameterSchedulers.Stateful{Exp{Float64}, Int64, ParameterSchedulers.var"#15#17"}(Exp{Float64}(0.1, 0.8), 2, ParameterSchedulers.var"#15#17"()), (1,), 0x0000000000006902)
```
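Instead of relying on `try`/`catch`, Base's `applicable` can check whether a call is supported. When you need the value at an arbitrary iteration, call the immutable schedule that the `Stateful` wraps. This sketch uses only the functions shown on this page plus `Base.applicable`.

```julia
using ParameterSchedulers: Exp, Stateful, next!

s = Exp(start = 0.1, decay = 0.8)
stateful_s = Stateful(s)
next!(stateful_s)             # advances the internal state; returns 0.1

# `applicable` reports whether a matching method exists, without throwing.
applicable(stateful_s, 1)     # false: Stateful has no call method
applicable(s, 1)              # true: the immutable schedule is callable

# For random access, call the immutable schedule directly.
s(3)
```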