Dynamic-Graph CNN Classification
Classification of PointCloud structure using PointNet model
For visualization purpose we will require to install Makie and compatible backend (GLMakie or WGLMakie). To install it simply run ] add Makie
in the julia prompt.
using Flux3D, Flux, Makie, CUDA
using Flux: onehotbatch, onecold, onehot, crossentropy
using Statistics: mean
using Base.Iterators: partition
Makie.inline!(false)
Makie.set_theme!(show_axis = false)
Defining arguments for use in training.
- batch_size - batch size of of training data to be passed while training.
- lr - learing rate for the optimization.
- epochs - number of episodes for training the classification model.
- K - k nearest neighbors used in DGCNN model.
- num_classes - number of classes in labels of dataset.
- npoints - number of points in each PointCloud to be returned by dataset.
batch_size = 32
lr = 3e-4
epochs = 5
K = 10
num_classes = 10
npoints = 1024
ModelNet10 Dataset
This package has dataset wrapper for ModelNet10/40 which makes it easy to load and preprocess ModelNet dataset. In this example we will using ModelNet10 but we can also use ModelNet40 with minor tweak in num_classes args.
We can construct ModelNet10 dataset by passing:
mode=:pointcloud
- for returning PointCloud variant of datasetnpoints
- no. of points in each PointCloud.transforms
- Transforms to be applied before return specified PointCloud.train
- Bool to indicate training or testing split.
Detailed list of available arguments can be found in ModelNet section.
dset = ModelNet10.dataset(;
mode = :pointcloud,
npoints = npoints,
transform = NormalizePointCloud(),
)
val_dset = ModelNet10.dataset(;
mode = :pointcloud,
train = false,
npoints = npoints,
transform = NormalizePointCloud(),
)
Visualizing the dataset
we can access the dataset by correspond index like dset[1]
which will return a ModelNet10 DataPoint
and following information idx: 1, data: PointCloud{Float32}, ground_truth: 1 (bathtub)
.
For the visulizing the corresponding datapoint we can use visulize
visualize(dset[11], markersize = 0.1)
Preparing Dataloader for training.
data = [dset[i].data.points for i = 1:length(dset)]
labels =
onehotbatch([dset[i].ground_truth for i = 1:length(dset)], 1:num_classes)
valX = cat([val_dset[i].data.points for i = 1:length(val_dset)]..., dims = 3)
valY = onehotbatch(
[val_dset[i].ground_truth for i = 1:length(val_dset)],
1:num_classes,
)
TRAIN = [
(cat(data[i]..., dims = 3), labels[:, i])
for i in partition(1:length(data), batch_size)
]
VAL = (valX, valY)
Defining 3D model
Flux3D has predefined DGCNN classification model which can be used to train PointCloud dataset
m = DGCNN(num_classes, K, npoints)
Defining loss and validating objectives
loss(x, y) = crossentropy(m(x), y)
accuracy(x, y) =
mean(onecold(cpu(m(x)), 1:num_classes) .== onecold(cpu(y), 1:num_classes))
Defining learning rate and optimizer
opt = Flux.ADAM(lr)
Training the 3D model
ps = params(m)
for epoch = 1:epochs
running_loss = 0
for d in TRAIN
gs = gradient(ps) do
training_loss = loss(d...)
running_loss += training_loss
return training_loss
end
Flux.update!(opt, ps, gs)
end
print("Epoch: $(epoch), epoch_loss: $(running_loss), accuracy: $(accuracy(VAL...))\n")
end
@show accuracy(VAL...)
This page was generated using Literate.jl.