One-Hot Encoding
It's common to encode categorical variables (like true, false or cat, dog) in "one-of-k" or "one-hot" form. Flux provides the onehot function to make this easy.
julia> using Flux: onehot, onecold
julia> onehot(:b, [:a, :b, :c])
3-element Flux.OneHotVector{3,UInt32}:
0
1
0
julia> onehot(:c, [:a, :b, :c])
3-element Flux.OneHotVector{3,UInt32}:
0
0
1
The inverse is onecold, which returns the label at the position of the largest element, so it can take a general probability distribution as well as just booleans.
julia> onecold(ans, [:a, :b, :c])
:c
julia> onecold([true, false, false], [:a, :b, :c])
:a
julia> onecold([0.3, 0.2, 0.5], [:a, :b, :c])
:c
Flux.onehot — Function
onehot(l, labels[, unk])
Return a OneHotVector where only the first occurrence of l in labels is 1 and all other elements are 0.
If l is not found in labels and unk is present, the function returns onehot(unk, labels); otherwise the function raises an error.
Examples
julia> Flux.onehot(:b, [:a, :b, :c])
3-element Flux.OneHotVector{3,UInt32}:
0
1
0
julia> Flux.onehot(:c, [:a, :b, :c])
3-element Flux.OneHotVector{3,UInt32}:
0
0
1
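The lookup and fallback rules described above can be sketched in plain Julia. This is only an illustration: sketch_onehot is a hypothetical helper, not Flux's implementation, and it returns a dense Bool vector in place of a OneHotVector.

```julia
# Hypothetical helper, not Flux's implementation: a dense Bool vector
# stands in for Flux's OneHotVector.
function sketch_onehot(l, labels, unk...)
    i = findfirst(isequal(l), labels)   # index of the first occurrence, or nothing
    if i === nothing
        # No match: fall back to `unk` if one was supplied, otherwise fail.
        isempty(unk) && error("value $l is not in labels")
        return sketch_onehot(unk[1], labels)
    end
    return [j == i for j in eachindex(labels)]
end

sketch_onehot(:b, [:a, :b, :c])        # Bool[0, 1, 0]
sketch_onehot(:d, [:a, :b, :c], :a)    # :d is unknown, so this encodes :a instead
```

Note that the fallback call passes no further unk, so an unk that is itself missing from labels still raises an error in this sketch.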
Flux.onecold — Function
onecold(y[, labels = 1:length(y)])
Inverse operation of onehot.
Examples
julia> Flux.onecold([true, false, false], [:a, :b, :c])
:a
julia> Flux.onecold([0.3, 0.2, 0.5], [:a, :b, :c])
:c
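One way to read these examples is that onecold picks the label at the position of the largest element. The sketch below assumes exactly that; sketch_onecold is a hypothetical name and is not taken from Flux's source.

```julia
# Hypothetical helper, not Flux's implementation.
# `labels` defaults to 1:length(y), mirroring the signature shown above.
sketch_onecold(y::AbstractVector, labels = 1:length(y)) = labels[argmax(y)]

sketch_onecold([true, false, false], [:a, :b, :c])   # :a
sketch_onecold([0.3, 0.2, 0.5], [:a, :b, :c])        # :c
sketch_onecold([0.3, 0.2, 0.5])                      # 3, the index itself
```

With the default labels the result is simply the index of the largest entry, which is often what a classifier's output needs.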
Batches
onehotbatch creates a batch (matrix) of one-hot vectors, and onecold treats matrices as batches.
julia> using Flux: onehotbatch
julia> onehotbatch([:b, :a, :b], [:a, :b, :c])
3×3 Flux.OneHotArray{3,2,Vector{UInt32}}:
0 1 0
1 0 1
0 0 0
julia> onecold(ans, [:a, :b, :c])
3-element Vector{Symbol}:
:b
:a
:b
Note that these operations returned OneHotVector and OneHotMatrix rather than Arrays. OneHotVectors behave like normal vectors but avoid any unnecessary cost compared to using an integer index directly. For example, multiplying a matrix with a one-hot vector simply slices out the relevant column of the matrix under the hood.
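The slicing claim can be checked with ordinary arrays. In this sketch a dense Bool vector stands in for the one-hot vector, and W is an arbitrary example matrix chosen for illustration; with the one-hot vector on the right of the product, the result equals the matching column of W.

```julia
W = [1 2 3;
     4 5 6]                  # 2×3 example matrix
x = [false, true, false]     # dense stand-in for onehot(:b, [:a, :b, :c])

y = W * x                    # ordinary matrix-vector product
y == W[:, 2]                 # true: the product is column 2 of W
```

An index-based representation can therefore skip the multiplications entirely and return the column directly.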
Flux.onehotbatch — Function
onehotbatch(ls, labels[, unk...])
Return a OneHotMatrix where the kth column of the matrix is onehot(ls[k], labels).
If one of the input labels in ls is not found in labels and unk is given, the corresponding column is onehot(unk, labels); otherwise the function will raise an error.
Examples
julia> Flux.onehotbatch([:b, :a, :b], [:a, :b, :c])
3×3 Flux.OneHotArray{3,2,Vector{UInt32}}:
0 1 0
1 0 1
0 0 0
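The column-per-input layout and its inverse can be sketched with dense arrays. Both helpers below are hypothetical and are not Flux's implementation; the sketch assumes the entries of labels are distinct and leaves out the unk fallback.

```julia
# Hypothetical helpers, not Flux's implementation.
# Column k is the one-hot encoding of ls[k]; a dense Bool matrix stands in
# for OneHotMatrix. Assumes the entries of `labels` are distinct.
sketch_onehotbatch(ls, labels) = [labels[i] == l for i in eachindex(labels), l in ls]

# Decode a batch column by column, taking the label at each column's largest entry.
sketch_onecold_batch(Y::AbstractMatrix, labels) = [labels[argmax(col)] for col in eachcol(Y)]

B = sketch_onehotbatch([:b, :a, :b], [:a, :b, :c])
sketch_onecold_batch(B, [:a, :b, :c])    # [:b, :a, :b]
```

Here B has the same entries as the 3×3 matrix printed above, and decoding it recovers the original inputs.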