One-Hot Encoding with OneHotArrays.jl
It's common to encode categorical variables (like true, false or cat, dog) in "one-of-k" or "one-hot" form. OneHotArrays.jl provides the onehot function to make this easy.
julia> using OneHotArrays
julia> onehot(:b, [:a, :b, :c])
3-element OneHotVector(::UInt32) with eltype Bool:
 ⋅
 1
 ⋅
julia> onehot(:c, [:a, :b, :c])
3-element OneHotVector(::UInt32) with eltype Bool:
 ⋅
 ⋅
 1
The inverse is onecold (which can take a general probability distribution, as well as just booleans).
julia> onecold(ans, [:a, :b, :c])
:c
julia> onecold([true, false, false], [:a, :b, :c])
:a
julia> onecold([0.3, 0.2, 0.5], [:a, :b, :c])
:c
For multiple samples at once, onehotbatch creates a batch (matrix) of one-hot vectors, and onecold treats matrices as batches.
julia> using OneHotArrays
julia> onehotbatch([:b, :a, :b], [:a, :b, :c])
3×3 OneHotMatrix(::Vector{UInt32}) with eltype Bool:
 ⋅  1  ⋅
 1  ⋅  1
 ⋅  ⋅  ⋅
julia> onecold(ans, [:a, :b, :c])
3-element Vector{Symbol}:
:b
:a
:b
Note that these operations returned OneHotVector and OneHotMatrix rather than Arrays. OneHotVectors behave like normal vectors but avoid any unnecessary cost compared to using an integer index directly. For example, multiplying a matrix with a one-hot vector simply slices out the relevant column of the matrix under the hood.
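The slicing behaviour can be checked directly. The following is a minimal sketch, assuming OneHotArrays is installed; the matrix W is an arbitrary illustration and not part of the package.

```julia
using OneHotArrays

W = [1 2 3; 4 5 6]                 # 2×3 matrix, one column per label
v = onehot(:b, [:a, :b, :c])       # one-hot vector selecting the second label

# Multiplying by the one-hot vector gives the same values as indexing the column.
product = W * v
same_as_column = product == W[:, 2]
```

Here same_as_column should be true, since a matrix times a standard basis vector picks out a single column.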
Function listing
OneHotArrays.onehot (Function)
onehot(x, labels, [default])
Returns a OneHotVector which is roughly a sparse representation of x .== labels.
Instead of storing say Vector{Bool}, it stores the index of the first occurrence of x in labels. If x is not found in labels, then it either returns onehot(default, labels), or gives an error if no default is given.
See also onehotbatch to apply this to many xs, and onecold to reverse either of these, as well as to generalise argmax.
Examples
julia> β = onehot(:b, (:a, :b, :c))
3-element OneHotVector(::UInt32) with eltype Bool:
 ⋅
 1
 ⋅
julia> αβγ = (onehot(0, 0:2), β, onehot(:z, [:a, :b, :c], :c)) # uses default
(Bool[1, 0, 0], Bool[0, 1, 0], Bool[0, 0, 1])
julia> hcat(αβγ...) # preserves sparsity
3×3 OneHotMatrix(::Vector{UInt32}) with eltype Bool:
 1  ⋅  ⋅
 ⋅  1  ⋅
 ⋅  ⋅  1
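The relationship between onehot and the dense comparison x .== labels, the "first occurrence" rule, and the role of default can be sketched as follows. This is an illustrative example, assuming OneHotArrays is installed; the variable names are made up for the sketch.

```julia
using OneHotArrays

labels = [:a, :b, :c]

# Same elements as the dense Boolean comparison, but stored as a single index.
matches_dense = onehot(:b, labels) == (:b .== labels)

# With repeated labels, the first occurrence is the one that is stored.
first_index = onecold(onehot(:a, [:a, :b, :a]))

# Without a default, a value missing from labels is an error.
missing_is_error = try
    onehot(:z, labels)
    false
catch
    true
end

# With a default, the missing value is encoded as that default instead.
uses_default = onehot(:z, labels, :c) == onehot(:c, labels)
```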
OneHotArrays.onecold (Function)
onecold(y::AbstractArray, labels = 1:size(y,1))
Roughly the inverse operation of onehot or onehotbatch: this finds the index of the largest element of y, or of each column of y, and looks them up in labels.
If labels are not specified, the default is the integers 1:size(y,1), which is the same operation as argmax(y, dims=1) but sometimes with a different return type.
Examples
julia> onecold([false, true, false])
2
julia> onecold([0.3, 0.2, 0.5], (:a, :b, :c))
:c
julia> onecold([ 1 0 0 1 0 1 0 1 0 0 1
0 1 0 0 0 0 0 0 1 0 0
0 0 0 0 1 0 0 0 0 0 0
0 0 0 0 0 0 1 0 0 0 0
0 0 1 0 0 0 0 0 0 1 0 ], 'a':'e') |> String
"abeacadabea"
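The difference in return type between onecold and argmax(y, dims=1) is easiest to see on a small matrix. The sketch below assumes OneHotArrays is installed; the matrix y is an arbitrary illustration.

```julia
using OneHotArrays

y = [0.3 0.1;
     0.2 0.8;
     0.5 0.1]                      # two columns, i.e. a batch of two samples

idx = onecold(y)                   # one integer row index per column
ci  = argmax(y, dims=1)            # 1×2 matrix of CartesianIndex values

# Taking the row component of each CartesianIndex recovers the onecold result.
rows = [c[1] for c in vec(ci)]
same_rows = rows == idx

# Supplying labels looks each index up in them.
named = onecold(y, (:a, :b, :c))
```

So onecold returns the labels (or plain integers) directly, while argmax returns Cartesian indices that still need unpacking.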
OneHotArrays.onehotbatch (Function)
onehotbatch(xs, labels, [default])
Returns a OneHotMatrix where the kth column of the matrix is onehot(xs[k], labels). This is a sparse matrix, which stores just a Vector{UInt32} containing the indices of the nonzero elements.
If one of the inputs in xs is not found in labels, that column is onehot(default, labels) if default is given, else an error.
If xs has more dimensions, N = ndims(xs) > 1, then the result is an AbstractArray{Bool, N+1} which is one-hot along the first dimension, i.e. result[:, k...] == onehot(xs[k...], labels).
Note that xs can be any iterable, such as a string. Using a tuple for labels will often speed up construction, certainly for fewer than 32 classes.
Examples
julia> oh = onehotbatch("abracadabra", 'a':'e', 'e')
5×11 OneHotMatrix(::Vector{UInt32}) with eltype Bool:
 1  ⋅  ⋅  1  ⋅  1  ⋅  1  ⋅  ⋅  1
 ⋅  1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  1  ⋅  ⋅
 ⋅  ⋅  ⋅  ⋅  1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅
 ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  1  ⋅  ⋅  ⋅  ⋅
 ⋅  ⋅  1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  1  ⋅
julia> reshape(1:15, 3, 5) * oh # this matrix multiplication is done efficiently
3×11 Matrix{Int64}:
1 4 13 1 7 1 10 1 4 13 1
2 5 14 2 8 2 11 2 5 14 2
3 6 15 3 9 3 12 3 6 15 3
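The higher-dimensional case and the choice between tuple and vector labels can be sketched as below, assuming OneHotArrays is installed. The array xs is an arbitrary 2×2 illustration.

```julia
using OneHotArrays

xs = [1 2; 3 1]                    # a 2×2 array of class indices, so N = 2
oh = onehotbatch(xs, 1:3)          # one-hot along a new first dimension

dims = size(oh)                    # N + 1 = 3 dimensions

# Each slice along the first dimension is the one-hot encoding of one entry.
slice_matches = oh[:, 1, 2] == onehot(xs[1, 2], 1:3)

# Tuple and vector labels give the same encoding; only construction cost may differ.
same_encoding = onehotbatch("abc", ('a', 'b', 'c')) == onehotbatch("abc", ['a', 'b', 'c'])
```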
OneHotArrays.OneHotVector (Type)
OneHotVector{T} = OneHotArray{T, 0, 1, T}
OneHotVector(indices, L)
A one-hot vector with L labels (i.e. length(A) == L and count(A) == 1), typically constructed by onehot. Stored efficiently as a single index of type T, usually UInt32.
OneHotArrays.OneHotMatrix (Type)
OneHotMatrix{T, I} = OneHotArray{T, 1, 2, I}
OneHotMatrix(indices, L)
A one-hot matrix (with L labels), typically constructed using onehotbatch. Stored efficiently as a vector of indices with type I and eltype T.
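Both types can also be built directly from indices using the constructors listed above. The sketch below assumes OneHotArrays is installed and that the constructors accept ordinary Int indices (the stored index type is then Int rather than UInt32).

```julia
using OneHotArrays

v = OneHotVector(2, 3)             # index 2 out of L = 3 labels
vector_invariants = length(v) == 3 && count(v) == 1
same_vector = v == onehot(:b, [:a, :b, :c])

m = OneHotMatrix([2, 1, 2], 3)     # one stored index per column
same_matrix = m == onehotbatch([:b, :a, :b], [:a, :b, :c])
```

In everyday use onehot and onehotbatch remain the usual entry points; the direct constructors are mainly useful when the integer indices are already at hand.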