Callback Helpers
Flux.throttle
Function `throttle(f, timeout; leading=true, trailing=false)`
Return a function that, when invoked, triggers `f` at most once every `timeout` seconds.

Normally, the throttled function runs as often as it can without ever running more than once per `timeout` duration. To disable execution on the leading edge, pass `leading=false`. To enable execution on the trailing edge, pass `trailing=true`.
Examples
```julia
julia> a = Flux.throttle(() -> println("Flux"), 2);

julia> for i = 1:4  # a called in alternate iterations
           a()
           sleep(1)
       end
Flux
Flux
```
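The following is a minimal sketch of time-based throttling that makes the leading-edge behaviour concrete. It is not Flux's implementation. `simple_throttle` is a hypothetical name, and the sketch models only the default `leading=true, trailing=false` case, so no trailing-edge call is ever scheduled.

```julia
# Hypothetical `simple_throttle`: leading-edge throttling only (no `trailing` support).
function simple_throttle(f, timeout)
    last_call = -Inf  # time of the last call that actually ran `f`
    return function (args...; kwargs...)
        now = time()
        if now - last_call >= timeout
            last_call = now
            return f(args...; kwargs...)
        end
        return nothing  # suppressed: still inside the `timeout` window
    end
end

# Count how often the wrapped function really runs during a burst of calls.
calls = Ref(0)
tick = simple_throttle(() -> calls[] += 1, 60)
for _ in 1:100
    tick()
end
println(calls[])  # only the first call of the burst ran
```

Suppressed calls return `nothing`, so callers should not rely on the return value of a throttled function.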
Patience Helpers
Flux provides utilities for controlling your training procedure according to some monitored condition and a maximum `patience`. For example, you can use `early_stopping` to stop training when the model is converging or deteriorating, or you can use `plateau` to check whether the model is stagnating.
For example, below we create a pseudo-loss function that decreases, bottoms out, and then increases. The early stopping trigger will break the loop before the loss increases too much.
```julia
using Flux: early_stopping

# create a pseudo-loss that decreases for 4 calls, then starts increasing
# we call this like loss()
loss = let t = 0
    () -> begin
        t += 1
        (t - 4) ^ 2
    end
end

# create an early stopping trigger
# returns true once the loss has failed to improve for two consecutive calls
es = early_stopping(loss, 2; init_score = 9)

# this will stop at the 6th (4 decreasing + 2 increasing calls) epoch
for epoch in 1:10
    es() && break
end
```
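The stopping epoch can be checked by replaying the loop by hand. This sketch does not use Flux. It follows the counting rule stated in the `early_stopping` docstring, with the default `distance = -` and `min_dist = 0`. `replay` is a hypothetical helper written for this illustration.

```julia
# Hand replay of the loop above: default distance `-`, init_score = 9, delay = 2.
# Follows the documented rule: count while distance(best, score) <= min_dist (here 0).
function replay(scores; init_score, delay)
    best, count = init_score, 0
    for (epoch, score) in enumerate(scores)
        Δ = best - score                # default distance: positive when the loss drops
        Δ > 0 && (best = score)         # track the best (lowest) loss seen so far
        count = Δ <= 0 ? count + 1 : 0  # no improvement: grow the counter, else reset
        println("epoch $epoch: loss = $score, best = $best, count = $count")
        count >= delay && return epoch
    end
    return nothing
end

stop_epoch = replay([(t - 4)^2 for t in 1:10]; init_score = 9, delay = 2)
```

The losses 1 and 4 at epochs 5 and 6 are both worse than the best value 0, so the counter reaches 2 at epoch 6.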
The keyword argument `distance` of `early_stopping` is a function of the form `distance(best_score, score)`. By default `distance` is `-`, which implies that the monitored metric `f` is expected to be decreasing and minimized. If you use some increasing metric (e.g. accuracy), you can customize the `distance` function: `(best_score, score) -> score - best_score`.
```julia
using Flux: early_stopping

# create a pseudo-accuracy that increases by 0.01 each call, capped at 1
# we call this like acc()
acc = let v = 0
    () -> v = min(1, v + 0.01)
end

# create an early stopping trigger for accuracy
es = early_stopping(acc, 3; distance = (best_score, score) -> score - best_score)

# this will iterate until the 10th epoch
for epoch in 1:10
    es() && break
end
```
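The sign convention can be checked in plain Julia without Flux. With the default `-`, `distance(best_score, score)` equals `best_score - score`, so a positive value means the metric went down. The flipped function makes a rise positive instead. The variable names below are illustrative only.

```julia
loss_distance = (-)                                       # default: best_score - score
acc_distance = (best_score, score) -> score - best_score  # flipped, for increasing metrics

loss_improved = loss_distance(0.5, 0.3) > 0   # loss fell 0.5 -> 0.3: an improvement
acc_improved = acc_distance(0.80, 0.85) > 0   # accuracy rose 0.80 -> 0.85: an improvement
loss_worse = loss_distance(0.80, 0.85) <= 0   # the same rise in a loss: no improvement
println((loss_improved, acc_improved, loss_worse))
```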
`early_stopping` and `plateau` are both built on top of `patience`. You can use `patience` to build your own triggers that use a patience counter. For example, to trigger when the loss stays below a threshold for several consecutive iterations:
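```julia
using Flux: patience

# returns true once f() < thresh has held for `delay` consecutive calls
threshold(f, thresh, delay) = patience(delay) do
    f() < thresh
end
```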
Both `predicate` in `patience` and `f` in `early_stopping` / `plateau` can accept extra arguments. You can pass such extra arguments to `predicate` or `f` through the returned function:
```julia
using Flux: patience

trigger = patience((a; b) -> a > b, 3)

# this will iterate until the 10th epoch
for epoch in 1:10
    trigger(1; b = 2) && break
end

# this will stop at the 3rd epoch
for epoch in 1:10
    trigger(3; b = 2) && break
end
```
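The patience counter is small enough to re-implement as a model. The sketch below uses hypothetical names (`my_patience`, `my_threshold`, `first_stop`) rather than Flux API. It follows the behaviour documented for `Flux.patience`: count consecutive `true` results of the predicate, reset on `false`, and fire once the count reaches `wait`. The returned function accepts `args...; kwargs...` and splats them into the predicate, which is how extra arguments reach it. The sketch replays the `threshold` and extra-argument examples above.

```julia
# Hypothetical `my_patience`, mirroring the documented behaviour of Flux.patience.
function my_patience(predicate, wait)
    count = 0
    return function (args...; kwargs...)
        # forward all arguments to `predicate`; count consecutive `true` results
        count = predicate(args...; kwargs...) ? count + 1 : 0
        return count >= wait
    end
end

# First of at most 10 epochs at which `trigger(args...; kwargs...)` returns true, else `nothing`.
first_stop(trigger, args...; kwargs...) = findfirst(_ -> trigger(args...; kwargs...), 1:10)

# The `threshold` trigger from above, rebuilt on the sketch.
my_threshold(f, thresh, delay) = my_patience(delay) do
    f() < thresh
end

losses = [0.9, 0.4, 0.6, 0.3, 0.2, 0.1, 0.05]
next_loss = let i = 0
    () -> losses[i += 1]
end
fired_at = first_stop(my_threshold(next_loss, 0.5, 3))  # 0.3, 0.2, 0.1 are the first 3 in a row below 0.5
println(fired_at)

# Extra arguments reach the predicate through the returned function.
trigger = my_patience((a; b) -> a > b, 3)
never = first_stop(trigger, 1; b = 2)  # 1 > 2 never holds
stops = first_stop(trigger, 3; b = 2)  # 3 > 2 holds on every call
println(never, " ", stops)
```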
Flux.patience
Function `patience(predicate, wait)`
Return a function that internally counts by one when `predicate(...) == true`, otherwise the count is reset to zero. If the count is greater than or equal to `wait`, the function returns `true`, otherwise it returns `false`.
Examples
```julia
julia> loss() = rand();

julia> trigger = Flux.patience(() -> loss() < 1, 3);

julia> for i in 1:10
           @info "Epoch $i"
           trigger() && break
       end
[ Info: Epoch 1
[ Info: Epoch 2
[ Info: Epoch 3
```
Flux.early_stopping
Function `early_stopping(f, delay; distance = -, init_score = 0, min_dist = 0)`
Return a function that internally counts by one when `distance(best_score, f(...)) <= min_dist`, where `best_score` is the last seen best value of `f(...)`. If the count is greater than or equal to `delay`, the function returns `true`, otherwise it returns `false`. The count is reset when `distance(best_score, f(...)) > min_dist`.
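The rule above, including `min_dist`, fits in a few lines. The sketch below is a hypothetical model (`my_early_stopping` and `make_loss` are illustrative names, not Flux API) written from the docstring's wording, not from Flux's source. It shows that a positive `min_dist` makes very small improvements count as "no improvement". `init_score = Inf` is used so that the first loss counts as an improvement.

```julia
# Hypothetical `my_early_stopping`; follows the documented rule, not Flux's source.
function my_early_stopping(f, delay; distance = -, init_score = 0, min_dist = 0)
    best_score, count = init_score, 0
    return function (args...; kwargs...)
        score = f(args...; kwargs...)
        Δ = distance(best_score, score)
        Δ > 0 && (best_score = score)          # remember strictly better scores
        count = Δ <= min_dist ? count + 1 : 0  # too small an improvement still counts
        return count >= delay
    end
end

losses = [1.0, 0.5, 0.4999, 0.4998, 0.4997]  # improvements shrink to about 1e-4
make_loss() = let i = 0
    () -> losses[i += 1]
end

strict = my_early_stopping(make_loss(), 3; init_score = Inf)
tolerant = my_early_stopping(make_loss(), 3; init_score = Inf, min_dist = 0.01)
strict_stop = findfirst(_ -> strict(), eachindex(losses))      # every step improves a little
tolerant_stop = findfirst(_ -> tolerant(), eachindex(losses))  # steps of ~1e-4 count as stalls
println(strict_stop, " ", tolerant_stop)
```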
Examples
```julia
julia> loss = let l = 0
           () -> l += 1
       end; # pseudo loss function that returns increasing values

julia> es = Flux.early_stopping(loss, 3);

julia> for i in 1:10
           @info "Epoch $i"
           es() && break
       end
[ Info: Epoch 1
[ Info: Epoch 2
[ Info: Epoch 3
```
Flux.plateau
Function `plateau(f, width; distance = -, init_score = 0, min_dist = 1f-6)`
Return a function that internally counts by one when `abs(distance(last_score, f(...))) <= min_dist`, where `last_score` holds the last value of `f(...)`. If the count is greater than or equal to `width`, the function returns `true`, otherwise it returns `false`. The count is reset when `abs(distance(last_score, f(...))) > min_dist`.
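Unlike `early_stopping`, `plateau` compares each value with the previous one rather than with the best so far. The hypothetical `my_plateau` below models the documented rule and is not Flux's source. It reproduces the example that follows: the step sizes are 19, 17, 15, 13, so with `min_dist = 18` the counter starts at epoch 2 and reaches `width = 3` at epoch 4.

```julia
# Hypothetical `my_plateau`; models the documented rule, not Flux's source.
function my_plateau(f, width; distance = -, init_score = 0, min_dist = 1f-6)
    last_score, count = init_score, 0
    return function (args...; kwargs...)
        score = f(args...; kwargs...)
        Δ = abs(distance(last_score, score))
        last_score = score                     # always compare with the previous value
        count = Δ <= min_dist ? count + 1 : 0  # small step: grow the counter, else reset
        return count >= width
    end
end

f = let v = 10
    () -> v = v / abs(v) - v  # -9.0, 8.0, -7.0, 6.0, ...
end
trigger = my_plateau(f, 3; init_score = 10, min_dist = 18)
plateau_epoch = findfirst(_ -> trigger(), 1:10)
println(plateau_epoch)
```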
Examples
```julia
julia> f = let v = 10
           () -> v = v / abs(v) - v
       end; # -9, 8, -7, 6, ...

julia> trigger = Flux.plateau(f, 3; init_score=10, min_dist=18);

julia> for i in 1:10
           @info "Epoch $i"
           trigger() && break
       end
[ Info: Epoch 1
[ Info: Epoch 2
[ Info: Epoch 3
[ Info: Epoch 4
```