Getting started
All schedule types in ParameterSchedulers.jl behave as callable iterators. For example, we can call the simple exponential decay schedule (Exp) below at a specific iteration:
using ParameterSchedulers

s = Exp(start = 0.1, decay = 0.8)
println("s(1): ", s(1))
println("s(5): ", s(5))
s(1): 0.1
s(5): 0.04096000000000001
The iterations are unitless. So, if you index a schedule every epoch, then s(i) is the parameter value at epoch i.
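As a rough illustration of what calling the schedule computes, the values printed above are consistent with the closed form start * decay^(t - 1). The sketch below assumes that formula and uses a hypothetical helper, exp_decay, which is not part of ParameterSchedulers.jl and is not meant to mirror the library's implementation:

```julia
# Hypothetical helper, not part of ParameterSchedulers.jl.
# Assumes the closed form start * decay^(t - 1), which agrees with
# the values of s(1) and s(5) printed above.
exp_decay(start, decay, t) = start * decay^(t - 1)

println("t = 1: ", exp_decay(0.1, 0.8, 1))
println("t = 5: ", exp_decay(0.1, 0.8, 5))
```

Because t carries no unit, the same helper describes a per-epoch or a per-batch decay depending only on how often it is evaluated.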
We can also use the schedule in an iterable context like a for-loop:
for (i, param) in enumerate(s)
    (i > 10) && break
    println("s($i): ", param)
end
s(1): 0.1
s(2): 0.08000000000000002
s(3): 0.06400000000000002
s(4): 0.051200000000000016
s(5): 0.04096000000000001
s(6): 0.03276800000000001
s(7): 0.026214400000000013
s(8): 0.020971520000000007
s(9): 0.016777216000000008
s(10): 0.013421772800000007
Many schedules, such as Exp, are infinite iterators, so iterating over them without a stopping condition will result in an infinite loop. You can use Base.IteratorSize to check if a schedule has infinite length.
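One way to guard against an accidental infinite loop is to check the iterator size trait and bound the iteration explicitly. The following is a minimal sketch, assuming Exp reports Base.IsInfinite() as described above; Iterators.take comes from Julia's standard library, not from ParameterSchedulers.jl:

```julia
using ParameterSchedulers

s = Exp(start = 0.1, decay = 0.8)

# Check the iterator size trait before looping over the schedule.
if Base.IteratorSize(s) isa Base.IsInfinite
    # Bound the loop explicitly so that it terminates after 3 values.
    for (i, param) in enumerate(Iterators.take(s, 3))
        println("s($i): ", param)
    end
end
```

Using Iterators.take keeps the stopping condition out of the loop body, which can be easier to read than a manual break.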
Notice that the values of s(1) and s(5) are unchanged even though we accessed the schedule once by calling it and again in the for-loop. This is because all schedules in ParameterSchedulers.jl are immutable. If you want a stateful (mutable) schedule, then you can use ParameterSchedulers.Stateful:
using ParameterSchedulers: Stateful, next!
stateful_s = Stateful(s)
println("s: ", next!(stateful_s))
println("s: ", next!(stateful_s))
println(stateful_s)
s: 0.1
s: 0.08000000000000002
ParameterSchedulers.Stateful{Exp{Float64}, Int64, ParameterSchedulers.var"#15#17"}(Exp{Float64}(0.1, 0.8), 3, ParameterSchedulers.var"#15#17"())
We used ParameterSchedulers.next! to advance the stateful iterator. Notice that stateful_s stores a reference to s and the current iteration state (which is 3, since the state starts at 1 and we advanced the iterator twice). We can reset the mutable iteration state too:
using ParameterSchedulers: reset!
reset!(stateful_s)
println("s: ", next!(stateful_s))
s: 0.1
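To show how next! and reset! might fit together in practice, here is a minimal sketch of a loop that pulls one value per epoch. The function train_epoch! is a hypothetical placeholder for real training code; only Stateful, next!, and reset! are taken from the examples above:

```julia
using ParameterSchedulers
using ParameterSchedulers: Stateful, next!, reset!

stateful_s = Stateful(Exp(start = 0.1, decay = 0.8))

# Hypothetical placeholder for one epoch of training.
train_epoch!(lr) = println("training with learning rate ", lr)

for epoch in 1:3
    lr = next!(stateful_s)  # return the current value and advance the state
    train_epoch!(lr)
end

# Rewind the schedule, e.g. before starting a new training run.
reset!(stateful_s)
```

Keeping the iteration state inside stateful_s means the loop does not need to track the iteration index itself.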
Also note that Stateful cannot be called (or iterated with Base.iterate):
try
    stateful_s(1)
catch e
    println(e)
end
MethodError(ParameterSchedulers.Stateful{Exp{Float64}, Int64, ParameterSchedulers.var"#15#17"}(Exp{Float64}(0.1, 0.8), 2, ParameterSchedulers.var"#15#17"()), (1,), 0x0000000000006902)