Callback Helpers

Flux.throttle - Function
throttle(f, timeout; leading=true, trailing=false)

Return a function that, when invoked, runs f at most once every timeout seconds.

Normally, the throttled function runs as often as it can without ever running more than once per timeout period. To disable execution on the leading edge, pass leading=false. To enable execution on the trailing edge, pass trailing=true.

Examples

julia> a = Flux.throttle(() -> println("Flux"), 2);

julia> for i = 1:4  # a called in alternate iterations
           a()
           sleep(1)
       end
Flux
Flux
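
To make the leading-edge behaviour concrete, here is a minimal sketch of a time-based throttle in plain Julia. It is not Flux's implementation: simple_throttle is a hypothetical name used only for illustration, and the sketch ignores the leading and trailing options entirely.

```julia
# Simplified leading-edge throttle (illustrative only, not Flux's source).
# `simple_throttle` is a hypothetical helper name.
function simple_throttle(f, timeout)
    last_run = -Inf      # time of the most recent call that actually ran f
    result = nothing     # cached result, returned while calls are suppressed
    return function (args...; kwargs...)
        current = time()
        if current - last_run >= timeout
            last_run = current
            result = f(args...; kwargs...)
        end
        return result
    end
end

calls = Ref(0)
counted = simple_throttle(() -> calls[] += 1, 0.5)

# Five back-to-back calls fall inside one 0.5 s window, so f runs once.
for _ in 1:5
    counted()
end
```

Suppressed calls return the cached result of the last real call, which keeps the wrapper usable where a return value is expected.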

Patience Helpers

Flux provides utilities for controlling your training procedure according to some monitored condition and a maximum patience. For example, you can use early_stopping to stop training when the model is converging or deteriorating, or you can use plateau to check if the model is stagnating.

For example, below we create a pseudo-loss function that decreases, bottoms out, and then increases. The early stopping trigger will break the loop before the loss increases too much.

# create a pseudo-loss that decreases for 4 calls, then starts increasing
# we call this like loss()
loss = let t = 0
  () -> begin
    t += 1
    (t - 4) ^ 2
  end
end

# create an early stopping trigger
# returns true when the loss increases for two consecutive steps
es = Flux.early_stopping(loss, 2; init_score = 9)

# this will stop at the 6th (4 decreasing + 2 increasing calls) epoch
for epoch in 1:10
  es() && break
end
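
The stopping epoch can be checked by hand. The sketch below is a trace in plain Julia, not Flux's code: it recomputes the first six pseudo-loss values and flags the calls where the score is strictly worse than the best score seen so far. Treating "worse" as strictly greater is an assumption that sidesteps ties.

```julia
# Hand trace of the pseudo-loss example (illustrative only, not Flux's source).
scores = [(t - 4)^2 for t in 1:6]                 # 9, 4, 1, 0, 1, 4

# Best score seen before each call, seeded with init_score = 9.
best_before = accumulate(min, [9; scores[1:end-1]])

# A call counts against patience when its score is worse than the best so far.
worse = scores .> best_before
```

Only the 5th and 6th calls are flagged, so a delay of 2 is reached on the 6th call.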

The keyword argument distance of early_stopping is a function of the form distance(best_score, score). By default distance is -, which implies that the monitored metric f is expected to be decreasing and minimized. If you use some increasing metric (e.g. accuracy), you can customize the distance function: (best_score, score) -> score - best_score.

# create a pseudo-accuracy that increases by 0.01 each time from 0 to 1
# we call this like acc()
acc = let v = 0
  () -> v = min(1, v + 0.01)
end

# create an early stopping trigger for accuracy
es = Flux.early_stopping(acc, 3; distance = (best_score, score) -> score - best_score)

# this will iterate until the 10th epoch
for epoch in 1:10
  es() && break
end

early_stopping and plateau are both built on top of patience. You can use patience to build your own triggers that use a patience counter. For example, to trigger when the loss stays below a threshold for several consecutive iterations:

threshold(f, thresh, delay) = Flux.patience(delay) do
  f() < thresh
end

Both predicate in patience and f in early_stopping / plateau can accept extra arguments. You can pass such extra arguments to predicate or f through the returned function:

trigger = Flux.patience((a; b) -> a > b, 3)

# this will iterate until the 10th epoch
for epoch in 1:10
  trigger(1; b = 2) && break
end

# this will stop at the 3rd epoch
for epoch in 1:10
  trigger(3; b = 2) && break
end

Flux.patience - Function
patience(predicate, wait)

Return a function that increments an internal counter by one each time predicate(...) == true and resets the counter to zero otherwise. The function returns true once the counter is greater than or equal to wait, and false otherwise.
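
The counting rule described above can be sketched in a few lines of plain Julia. This is a stand-in written from the description, not Flux's actual source, and simple_patience is a hypothetical name.

```julia
# Illustrative stand-in for the counting rule (not Flux's source).
# `simple_patience` is a hypothetical helper name.
function simple_patience(predicate, wait)
    streak = 0  # consecutive calls for which the predicate held
    return function (args...; kwargs...)
        # Increment on success, reset to zero on failure.
        streak = predicate(args...; kwargs...) ? streak + 1 : 0
        return streak >= wait
    end
end

trigger = simple_patience(x -> x > 0, 2)

# The failure on the second call resets the streak, so the trigger
# only fires on the fourth call.
results = [trigger(x) for x in (1, -1, 1, 1)]
```

Extra positional and keyword arguments are forwarded to the predicate, mirroring how the returned function is called in the examples on this page.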

Examples

julia> loss() = rand();

julia> trigger = Flux.patience(() -> loss() < 1, 3);


julia> for i in 1:10
         @info "Epoch $i"
         trigger() && break
       end
[ Info: Epoch 1
[ Info: Epoch 2
[ Info: Epoch 3

Flux.early_stopping - Function
early_stopping(f, delay; distance = -, init_score = 0, min_dist = 0)

Return a function that increments an internal counter by one when distance(best_score, f(...)) <= min_dist, where best_score is the best value of f(...) seen so far. The function returns true once the counter is greater than or equal to delay, and false otherwise. The counter is reset when distance(best_score, f(...)) > min_dist.

Examples

julia> loss = let l = 0
         () -> l += 1
       end; # pseudo loss function that returns increasing values

julia> es = Flux.early_stopping(loss, 3);


julia> for i in 1:10
         @info "Epoch $i"
         es() && break
       end
[ Info: Epoch 1
[ Info: Epoch 2
[ Info: Epoch 3

Flux.plateau - Function
plateau(f, width; distance = -, init_score = 0, min_dist = 1f-6)

Return a function that increments an internal counter by one when abs(distance(last_score, f(...))) <= min_dist, where last_score holds the previous value of f(...). The function returns true once the counter is greater than or equal to width, and false otherwise. The counter is reset when abs(distance(last_score, f(...))) > min_dist.

Examples

julia> f = let v = 10
         () -> v = v / abs(v) - v
       end; # -9, 8, -7, 6, ...

julia> trigger = Flux.plateau(f, 3; init_score=10, min_dist=18);


julia> for i in 1:10
         @info "Epoch $i"
         trigger() && break
       end
[ Info: Epoch 1
[ Info: Epoch 2
[ Info: Epoch 3
[ Info: Epoch 4
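
To see why the loop above stops at epoch 4, the distances can be recomputed by hand. The sketch below is a trace in plain Julia rather than Flux's code. It rebuilds the first four values of f and compares each one with the value before it, starting from init_score = 10.

```julia
# Hand trace of the plateau example (illustrative only, not Flux's source).
f = let v = 10
    () -> v = v / abs(v) - v
end

scores = [f() for _ in 1:4]               # -9.0, 8.0, -7.0, 6.0

# Previous score for each call, seeded with init_score = 10.
previous = [10; scores[1:end-1]]

# Absolute distance between consecutive scores, compared with min_dist = 18.
deltas = abs.(previous .- scores)         # 19.0, 17.0, 15.0, 13.0
within = deltas .<= 18
```

The first distance exceeds min_dist, and the next three fall within it, so a width of 3 is reached on the 4th call.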