Variational autoencoders

So far, we've covered many examples of how to train models in a supervised fashion. However, there are many applications of neural networks outside the supervised regime. In this tutorial, we'll implement and train a variational autoencoder (VAE) to embed and generate images from the MNIST dataset.

You'll learn how to

  • implement custom Flux.jl models

  • write a custom training loop

  • generate new images and visualize them


			
			
			
    using FastAI, FastVision, FastMakie, FastAI.FluxTraining, FastAI.Flux
    using FastVision: SVector, Gray
    import CairoMakie  # run if in an interactive environment like a notebook

Setting up the data

First we load the MNIST dataset. Since we're not using it for supervised learning, we only need the input images, and don't load the labels. We get a data container where every observation is an image:


			
			
			
    path = load(datasets()["mnist_tiny"])
    data = Datasets.loadfolderdata(
        path,
        filterfn = FastVision.isimagefile,
        loadfn = loadfile,
    )
    showblock(Image{2}(), getobs(data, 1))

Next, we need to define a learning task that will handle data encoding and decoding as well as visualization for us. So far we've mostly used SupervisedTask, which assumes there is an input that is fed to the model and a corresponding target output. Since we want to do unsupervised learning, we'll instead create a custom learning task using BlockTask. It defines what kind of data we'll have at each step in the data pipeline, for example that x is a model input and ŷ a model output. See AbstractBlockTask for more info.


			
			
			
			
    using FastAI: encodedblockfilled, decodedblockfilled

    function EmbeddingTask(block, encodings)
        sample = block
        encodedsample = x = y = ŷ = encodedblockfilled(encodings, sample)
        blocks = (; sample, x, y, ŷ, encodedsample)
        BlockTask(blocks, encodings)
    end

			EmbeddingTask (generic function with 1 method)

With this helper defined, we can create a learning task for our problem: Image{2}() is the kind of data we want to learn with, and ImagePreprocessing makes sure to encode and decode these images so they can be used to train a model.


			
			
			
    task = EmbeddingTask(
        Image{2}(),
        (ImagePreprocessing(means = SVector(0.0), stds = SVector(1.0), C = Gray{Float32}, buffered = false),),
    )

			BlockTask(blocks=(:sample, :x, :y, :ŷ, :encodedsample))

With the learning task set up, we can use encodesample to get samples ready to be input to a model, and the show* functions to visualize data at various points in the pipeline:


			
			
			
    x = encodesample(task, Training(), getobs(data, 1))
    showencodedsample(task, x)

For later training, the last thing we need to do with the data is to create a data iterator over batches of encoded samples. Since the dataset comfortably fits into memory, we preload it all at once by calling collect on the DataLoader. This saves us from reloading each image every epoch.


			
			
			
    td = taskdataset(data, task, Training())
    getobs(td, 6)

			28×28×1 Array{Float32, 3}:
[:, :, 1] =
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 ⋮                        ⋮              ⋱                      ⋮         
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0

			
			
			
    BATCHSIZE = 512

    dl, _ = taskdataloaders(data, task, BATCHSIZE, pctgval = 0.1);
    dataiter = collect(dl)

			3-element Vector{Array{Float32, 4}}:
 [0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0;;;; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0;;;; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0;;;; … ;;;; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0;;;; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0;;;; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]
 [0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0;;;; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0;;;; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0;;;; … ;;;; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0;;;; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0;;;; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]
 [0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0;;;; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0;;;; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0;;;; … ;;;; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0;;;; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0;;;; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]

With that, we have a data iterator of batches that we can use in a training loop just by iterating over it:


			
			
			
    for xs in dataiter
        print(size(xs))
        break
    end

			(28, 28, 1, 1428)

Next, we need to construct a model and define a loss function so it can be optimized.

Modelling

The variational autoencoder consists of two parts: an encoder and a decoder. The encoder takes in data and outputs the parameters of a probability distribution over the latent space. These are used to sample latent vectors, which are fed into the decoder to produce new samples. A loss function (the ELBO) rewards the model if the outputs are similar to the inputs. The challenge is that the latent space has much lower dimensionality than the data space, so the model needs to learn to compress the information contained in the data. If you're interested in more mathematical background on VAEs and the loss function, Lilian Weng has written a great write-up on autoencoders. The architecture looks like this:

Diagram of VAE architecture
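For reference, here is the per-sample objective that the βELBO function below implements (our own rendering of the standard β-VAE loss; σ² = exp(logσ²) is the posterior variance predicted by the encoder, and x̄ is the decoder's reconstruction):

$$\mathcal{L}_\beta(x, \bar{x}) = \underbrace{\sum_i (\bar{x}_i - x_i)^2}_{\text{reconstruction error}} \;+\; \beta \cdot \underbrace{\tfrac{1}{2} \sum_j \left(\mu_j^2 + \sigma_j^2 - 1 - \log \sigma_j^2\right)}_{D_{\mathrm{KL}}\left(\mathcal{N}(\mu,\,\sigma^2 I)\,\|\,\mathcal{N}(0,\,I)\right)}$$

Both terms are averaged over the batch.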

We define the variational autoencoder as a new type that wraps an encoder and a decoder, and we define the forward pass and loss function as regular Julia functions.


			
			
			
			
    struct VAE{E, D}
        encoder::E
        decoder::D
    end

    Flux.@functor VAE

    function (vae::VAE)(xs)
        μ, logσ² = vae.encoder(xs)
        zs = sample_latent(μ, logσ²)
        x̄s = vae.decoder(zs)
        return x̄s, (; μ, logσ²)
    end

    using Random: randn!
    using Statistics: mean

    sample_latent(μ::AbstractArray{T}, logσ²::AbstractArray{T}) where {T} =
        μ .+ exp.(logσ² ./ 2) .* randn!(similar(logσ²))

    function βELBO(x, x̄, μ, logσ²; β = 1)
        reconstruction_error = mean(sum(@.((x̄ - x)^2); dims = 1))
        # D(N(μ, Σ)||N(0, I)) = 1/2 * (μᵀμ + tr(Σ) - length(μ) - log(|Σ|))
        kl_divergence = mean(sum(@.((μ^2 + exp(logσ²) - 1 - logσ²) / 2); dims = 1))
        return reconstruction_error + β * kl_divergence
    end

			βELBO (generic function with 1 method)
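A quick way to convince yourself the loss behaves as expected (a sketch, not part of the original tutorial): with μ = 0 and logσ² = 0 the KL term vanishes, so βELBO reduces to the squared reconstruction error summed over features and averaged over the batch.

    x = zeros(Float32, 10, 2)      # two flattened "inputs" with 10 features each
    x̄ = ones(Float32, 10, 2)       # reconstructions that are off by 1 everywhere
    μ = zeros(Float32, 4, 2)       # posterior means equal to the prior mean
    logσ² = zeros(Float32, 4, 2)   # unit variance, so the KL divergence is zero
    βELBO(x, x̄, μ, logσ²)          # == 10.0f0, the reconstruction term alone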

Next we define the encoder and decoder models by composing basic Flux.jl layers. Dlatent is the size of the latent space and controls how much the model has to compress the information. Feel free to try out smaller or larger numbers and see how the quality of the generated images changes.


			
			
			
    SIZE = (28, 28, 1)
    Din = prod(SIZE)
    Dhidden = 512
    Dlatent = 4

    encoder = Chain(
        Flux.flatten,
        Dense(Din, Dhidden, relu),    # backbone
        Parallel(
            tuple,
            Dense(Dhidden, Dlatent),  # μ
            Dense(Dhidden, Dlatent),  # logσ²
        ),
    ) |> gpu

    decoder = Chain(
        Dense(Dlatent, Dhidden, relu),
        Dense(Dhidden, Din, sigmoid),
        xs -> reshape(xs, SIZE..., :),
    ) |> gpu

    model = VAE(encoder, decoder);
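Before wiring the model into a training loop, a quick shape check can catch mistakes early (a sketch under the definitions above; the gpu calls are no-ops on a CPU-only machine):

    xs = rand(Float32, 28, 28, 1, 16) |> gpu  # a dummy batch of 16 "images"
    x̄s, ps = model(xs)                        # reconstructions and (; μ, logσ²)
    size(x̄s)    # (28, 28, 1, 16): reconstructions in data space
    size(ps.μ)  # (4, 16): one Dlatent-dimensional mean per sample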

Custom training loop

When dealing with an unconventional learning scheme, we usually need to write a custom training loop. FastAI.jl is built on top of FluxTraining.jl, which lets you write custom training loops with very little boilerplate while retaining compatibility with its extensive callback system. In fact, the built-in training loops for supervised learning are defined in just the same way as the one we'll write here.


			
			
			
			
    struct VAETrainingPhase <: FluxTraining.AbstractTrainingPhase end

We just defined our own training phase. All that is required to take advantage of the FastAI.jl framework is to define the FluxTraining.step! function for it. We'll use a utility, FluxTraining.runstep, to reduce the boilerplate involved in handling callback events and state. runstep's first argument is a function with inputs (handle, state): handle can be used to dispatch events that callbacks can react to, and state holds data generated on each call to step!, like the batch, gradients, and loss. These are also accessible to callbacks, for example to calculate metrics. We also give runstep the initial step state, which here just contains our batch.


			
			
			
			
    function FluxTraining.step!(learner, phase::VAETrainingPhase, batch)
        FluxTraining.runstep(learner, phase, (xs = batch,)) do handle, state
            gs = gradient(learner.params) do
                μ, logσ² = learner.model.encoder(state.xs)
                state.zs = sample_latent(μ, logσ²)
                state.x̄s = learner.model.decoder(state.zs)

                handle(FluxTraining.LossBegin())
                state.loss = learner.lossfn(
                    Flux.flatten(state.xs), Flux.flatten(state.x̄s), μ, logσ²)

                handle(FluxTraining.BackwardBegin())
                return state.loss
            end
            handle(FluxTraining.BackwardEnd())
            Flux.Optimise.update!(learner.optimizer, learner.params, gs)
        end
    end
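For completeness, a validation phase could be sketched the same way (our own hedged addition, not part of the training run below; it assumes FluxTraining.AbstractValidationPhase, the supertype of the built-in ValidationPhase): the same forward pass and loss, but without gradients or parameter updates.

    struct VAEValidationPhase <: FluxTraining.AbstractValidationPhase end

    function FluxTraining.step!(learner, phase::VAEValidationPhase, batch)
        FluxTraining.runstep(learner, phase, (xs = batch,)) do handle, state
            # forward pass only; no gradient computation and no optimizer step
            μ, logσ² = learner.model.encoder(state.xs)
            state.zs = sample_latent(μ, logσ²)
            state.x̄s = learner.model.decoder(state.zs)
            handle(FluxTraining.LossBegin())
            state.loss = learner.lossfn(
                Flux.flatten(state.xs), Flux.flatten(state.x̄s), μ, logσ²)
        end
    end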

Since the step state is a little different from the supervised case (there are no targets ys), we also overwrite the default behavior of the ToDevice callback for our training phase:


			
			
			
    function FluxTraining.on(
        ::FluxTraining.StepBegin,
        ::VAETrainingPhase,
        cb::ToDevice,
        learner,
    )
        learner.step.xs = cb.movedatafn(learner.step.xs)
    end

Training

Now we can put the pieces together, creating a Learner.


			
			
			
    learner = Learner(model, βELBO, callbacks = [ToGPU()])

    FluxTraining.removecallback!(learner, ProgressPrinter);

To override the default supervised training loops, we pass our custom training phase and the data iterator we want to run it on to fitonecycle!:


			
			
			

	
    fitonecycle!(
        learner,
        50,
        0.01;
        phases = (VAETrainingPhase() => dataiter,),
    )

┌──────────────────┬───────┬─────────┐
│            Phase │ Epoch │    Loss │
├──────────────────┼───────┼─────────┤
│ VAETrainingPhase │   1.0 │ 32.3354 │
│ VAETrainingPhase │   2.0 │ 32.2528 │
│ VAETrainingPhase │   3.0 │  32.208 │
│ VAETrainingPhase │   4.0 │ 31.7685 │
│ VAETrainingPhase │   5.0 │ 31.6154 │
│ VAETrainingPhase │   6.0 │ 31.4664 │
│ VAETrainingPhase │   7.0 │ 31.1981 │
│ VAETrainingPhase │   8.0 │ 31.1152 │
│ VAETrainingPhase │   9.0 │ 30.8114 │
│ VAETrainingPhase │  10.0 │ 30.6412 │
│ VAETrainingPhase │  11.0 │ 30.5843 │
│ VAETrainingPhase │  12.0 │ 30.2381 │
│ VAETrainingPhase │  13.0 │ 29.9814 │
│ VAETrainingPhase │  14.0 │ 30.0446 │
│ VAETrainingPhase │  15.0 │ 29.8863 │
│ VAETrainingPhase │  16.0 │ 29.6784 │
│ VAETrainingPhase │  17.0 │ 29.6538 │
│ VAETrainingPhase │  18.0 │ 29.5483 │
│ VAETrainingPhase │  19.0 │ 29.4535 │
│ VAETrainingPhase │  20.0 │ 29.4254 │
│ VAETrainingPhase │  21.0 │ 29.3104 │
│ VAETrainingPhase │  22.0 │ 29.3043 │
│ VAETrainingPhase │  23.0 │ 29.2092 │
│ VAETrainingPhase │  24.0 │ 29.1536 │
│ VAETrainingPhase │  25.0 │ 28.9276 │
│ VAETrainingPhase │  26.0 │ 28.8557 │
│ VAETrainingPhase │  27.0 │  28.818 │
│ VAETrainingPhase │  28.0 │ 28.6366 │
│ VAETrainingPhase │  29.0 │ 28.6398 │
│ VAETrainingPhase │  30.0 │ 28.5331 │
│ VAETrainingPhase │  31.0 │ 28.5456 │
│ VAETrainingPhase │  32.0 │ 28.3021 │
│ VAETrainingPhase │  33.0 │ 28.4016 │
│ VAETrainingPhase │  34.0 │ 28.4184 │
│ VAETrainingPhase │  35.0 │  28.248 │
│ VAETrainingPhase │  36.0 │ 28.1684 │
│ VAETrainingPhase │  37.0 │ 28.1215 │
│ VAETrainingPhase │  38.0 │ 28.0628 │
│ VAETrainingPhase │  39.0 │ 28.1443 │
│ VAETrainingPhase │  40.0 │ 28.0746 │
│ VAETrainingPhase │  41.0 │ 27.9412 │
│ VAETrainingPhase │  42.0 │ 27.9764 │
│ VAETrainingPhase │  43.0 │  27.828 │
│ VAETrainingPhase │  44.0 │ 27.7843 │
│ VAETrainingPhase │  45.0 │ 27.8039 │
│ VAETrainingPhase │  46.0 │  27.711 │
│ VAETrainingPhase │  47.0 │ 27.6714 │
│ VAETrainingPhase │  48.0 │ 27.6664 │
│ VAETrainingPhase │  49.0 │ 27.6659 │
│ VAETrainingPhase │  50.0 │ 27.6021 │
└──────────────────┴───────┴─────────┘

Finally, we can visualize how well inputs are reconstructed:


			
			
			
    xs = makebatch(task, data, rand(1:numobs(data), 4)) |> gpu
    ypreds, _ = model(xs)
    showoutputbatch(task, cpu(xs), cpu(ypreds))
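Since the decoder maps any latent vector to an image, we can also generate entirely new samples by decoding draws from the prior. This is a sketch, not part of the code above; it reuses showoutputbatch purely for display by passing the generated batch as both arguments:

    zs = randn(Float32, Dlatent, 4) |> gpu   # latent vectors drawn from the prior N(0, I)
    samples = model.decoder(zs) |> cpu       # decode them into data space
    showoutputbatch(task, samples, samples)  # visualize the generated images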