Tabular Classification

In tabular classification, the target is a categorical column. Here, we'll use the adult sample dataset from fastai and try to predict whether the salary is above 50K or not, making this a binary classification task.


			
			
			
```julia
using FastAI, FastTabular, FastTabular.Tables, FluxTraining, Flux
import FastTabular.DataAugmentation
```

We can quickly download any dataset from fastai and get its path by using datasets. Once we have the path, we'll load the data into a TableDataset. By default, if we pass just the path to TableDataset, the data is loaded into a DataFrame, but we can use any package to access our data and pass any object satisfying the Tables.jl interface instead.


			
			
			
```julia
path = joinpath(load(datasets()["adult_sample"]), "adult.csv")
data = TableDataset(loadfile(path))
```

TableDataset{DataFrames.DataFrame}(32561×15 DataFrame
   Row │ age    workclass          fnlwgt  education      education-num  marit ⋯
       │ Int64  String31           Int64   String15       Float64?       Strin ⋯
───────┼────────────────────────────────────────────────────────────────────────
     1 │    49   Private           101320   Assoc-acdm             12.0   Marr ⋯
     2 │    44   Private           236746   Masters                14.0   Divo
     3 │    38   Private            96185   HS-grad           missing     Divo
     4 │    38   Self-emp-inc      112847   Prof-school            15.0   Marr
     5 │    42   Self-emp-not-inc   82297   7th-8th           missing     Marr ⋯
     6 │    20   Private            63210   HS-grad                 9.0   Neve
     7 │    49   Private            44434   Some-college           10.0   Divo
     8 │    37   Private           138940   11th                    7.0   Marr
     9 │    46   Private           328216   HS-grad                 9.0   Marr ⋯
    10 │    36   Self-emp-inc      216711   HS-grad           missing     Marr
    11 │    23   Private           529223   Bachelors              13.0   Neve
   ⋮   │   ⋮            ⋮            ⋮           ⋮              ⋮              ⋱
 32552 │    60   Private           230545   7th-8th                 4.0   Divo
 32553 │    39   Private           139743   HS-grad                 9.0   Sepa ⋯
 32554 │    35   Self-emp-inc      135436   Prof-school            15.0   Marr
 32555 │    53   Private            35102   Some-college           10.0   Divo
 32556 │    48   Private           355320   Bachelors              13.0   Marr
 32557 │    36   Private           297449   Bachelors              13.0   Divo ⋯
 32558 │    23   ?                 123983   Bachelors              13.0   Neve
 32559 │    53   Private           157069   Assoc-acdm             12.0   Marr
 32560 │    32   Local-gov         217296   HS-grad                 9.0   Marr
 32561 │    26   Private           182308   Some-college           10.0   Marr ⋯
                                               10 columns and 32540 rows omitted)

If our data were in a different format, e.g. Parquet, it could be loaded into a data container as follows:


			
			
			
```julia
using Parquet

TableDataset(read_parquet(parquet_path));
```
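Since anything satisfying the Tables.jl interface works, we could also back the dataset with a CSV.File and skip the DataFrame entirely. A minimal sketch, assuming CSV.jl is installed:

```julia
using CSV

# CSV.File satisfies the Tables.jl interface, so it can back a
# TableDataset directly without materializing a DataFrame first.
TableDataset(CSV.File(path));
```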

mapobs is used here to lazily split the target column from the rest of each row, so that each observation consists of a row of inputs and a target variable.


			
			
			
			
```julia
splitdata = mapobs(row -> (row, row[:salary]), data);
```
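Because mapobs is lazy, no data is copied and the number of observations stays the same. As a quick check (numobs and getobs come from MLUtils.jl, on which the FastAI data containers are built):

```julia
using MLUtils: numobs, getobs

numobs(splitdata)    # 32561, the same as the underlying table
getobs(splitdata, 1) # a (row, target) tuple for the first observation
```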

To create a learning task for tabular classification, we need an input block, an output block, and the encodings to be performed on the data.

The input block here is a TableRow, which contains information about the nature of the columns (i.e. categorical or continuous) along with an indexable collection mapping categorical column names to a collection of the distinct classes in that column. We can get this mapping by using the gettransformdict function with DataAugmentation.Categorify.

The output block used is Label for single-column classification, and the unique classes have to be passed to it.

These are followed by the encodings that need to be applied to our input and output blocks. For the input block, we use setup(TabularPreprocessing, inputblock, data) to get a standard set of transformations to apply, but this can be easily customized by passing any tabular transformation from DataAugmentation.jl, or a composition of those, to TabularPreprocessing directly. In addition to this, we have just one-hot encoded the output block.


			
			
			
			
```julia
cat, cont = FastTabular.getcoltypes(data)
target = :salary
cat = filter(!isequal(target), cat)
catdict = FastTabular.gettransformdict(data, DataAugmentation.Categorify, cat);
```
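Since catdict is an ordinary dictionary, we can peek at the distinct classes recorded for any categorical column, for example:

```julia
# catdict maps each categorical column name to its distinct classes,
# which Categorify later uses for label encoding.
catdict[:workclass]
```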

			
			
			
```julia
inputblock = TableRow(cat, cont, catdict)
targetblock = Label(unique(data.table[:, target]))

task = BlockTask(
    (inputblock, targetblock),
    (
        setup(TabularPreprocessing, inputblock, data),
        FastAI.OneHot()
    )
)
```

			┌ Warning: There is a missing value present for category 'occupation' which will be removed from Categorify dict
└ @ DataAugmentation /home/lorenz/.julia/packages/DataAugmentation/D6ODQ/src/rowtransforms.jl:108

			SupervisedTask(TableRow{8, 6, Dict{Any, Any}} -> Label{InlineStrings.String7})

In case our initial problem wasn't a classification task and we had a continuous target column instead, we would need to perform tabular regression. To create a learning task suitable for regression, we use a Continuous block to represent our target column. This works even with multiple continuous target columns: just pass the number of columns to Continuous. For example, the task here could be used for 3 targets:


			
			
			
```julia
task2 = BlockTask(
    (
        TableRow(cat, cont, catdict),
        Continuous(3)
    ),
    ((FastAI.TabularPreprocessing(data),)),
    outputblock = Continuous(3)
)
```

To get an overview of the learning task created, and as a sanity test, we can use describetask. This shows us what encodings will be applied to which blocks, and how the predicted ŷ values are decoded.


			
			
			

	
```julia
describetask(task)
```

getobs gets us a row of data from the TableDataset, which we encode here. This gives us a tuple with the input and target. The input here is again a tuple, containing the categorical values (which have been label encoded or "categorified") and the continuous values (which have been normalized, with any missing values filled in).


			
			
			

	
```julia
getobs(splitdata, 1000)
```

(DataFrameRow
  Row │ age    workclass   fnlwgt  education  education-num  marital-status    ⋯
      │ Int64  String31    Int64   String15   Float64?       String31          ⋯
──────┼─────────────────────────────────────────────────────────────────────────
 1000 │    61   State-gov  162678   5th-6th             3.0   Married-civ-spou ⋯
                                                              10 columns omitted, "<50k")

			
			
			
```julia
x = encodesample(task, Training(), getobs(splitdata, 1000))
```

			(([5, 16, 2, 10, 5, 2, 3, 2], [1.6435221651965317, -0.2567538819371021, -2.751580937680526, -0.14591824281680102, -0.21665620002803673, -0.035428902921319616]), Bool[0, 1])
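The Bool[0, 1] at the end is the one-hot encoded target; the hot index points into the classes we passed to Label. As a quick check (assuming Label keeps those classes in its classes field, as constructed above):

```julia
# Recover the class behind the one-hot target:
targetblock.classes[argmax(Bool[0, 1])]  # "<50k" for this sample
```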

To get a model suitable for our learning task, we can use taskmodel, which constructs a suitable model based on the target block.


			
			
			
```julia
model = taskmodel(task)
```

			Chain(
  Parallel(
    vcat,
    Chain(
      FastTabular.var"#47#49"(),
      Parallel(
        vcat,
        Embedding(10 => 6),             # 60 parameters
        Embedding(17 => 8),             # 136 parameters
        Embedding(8 => 5),              # 40 parameters
        Embedding(17 => 8),             # 136 parameters
        Embedding(7 => 5),              # 35 parameters
        Embedding(6 => 4),              # 24 parameters
        Embedding(3 => 3),              # 9 parameters
        Embedding(43 => 13),            # 559 parameters
      ),
      identity,
    ),
    BatchNorm(6),                       # 12 parameters, plus 12
  ),
  Chain(
    Dense(58 => 200, relu; bias=false),  # 11_600 parameters
    BatchNorm(200),                     # 400 parameters, plus 400
    identity,
  ),
  Chain(
    Dense(200 => 100, relu; bias=false),  # 20_000 parameters
    BatchNorm(100),                     # 200 parameters, plus 200
    identity,
  ),
  Dense(100 => 2),                      # 202 parameters
)         # Total: 18 trainable arrays, 33_413 parameters,
          # plus 6 non-trainable, 612 parameters, summarysize 134.871 KiB.
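Alongside the model, the loss function appropriate for the output block can be retrieved in the same spirit; a brief sketch assuming tasklossfn from FastAI:

```julia
# For a one-hot encoded Label target this yields a
# logit cross-entropy style loss suitable for classification.
lossfn = tasklossfn(task)
```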

Of course, you can also create a custom backbone using the functions present in FastAI.Models.


			
			
			
```julia
cardinalities = collect(map(col -> length(catdict[col]), cat))

ovdict = Dict(:workclass => 10, :education => 12, Symbol("native-country") => 16)
overrides = collect(map(col -> col in keys(ovdict) ? ovdict[col] : nothing, cat))

embedszs = FastTabular._get_emb_sz(cardinalities, overrides)
catback = FastTabular.tabular_embedding_backbone(embedszs, 0.2);
```

We can then pass a named tuple (categorical = ..., continuous = ...) to taskmodel to replace the default backbone.


			
			
			
```julia
backbone = (categorical = catback, continuous = BatchNorm(length(cont)))
model = taskmodel(task, backbone)
```

			Chain(
  Parallel(
    vcat,
    Chain(
      FastTabular.var"#47#49"(),
      Parallel(
        vcat,
        Embedding(10 => 10),            # 100 parameters
        Embedding(17 => 12),            # 204 parameters
        Embedding(8 => 5),              # 40 parameters
        Embedding(17 => 8),             # 136 parameters
        Embedding(7 => 5),              # 35 parameters
        Embedding(6 => 4),              # 24 parameters
        Embedding(3 => 3),              # 9 parameters
        Embedding(43 => 16),            # 688 parameters
      ),
      Dropout(0.2),
    ),
    BatchNorm(6),                       # 12 parameters, plus 12
  ),
  Chain(
    Dense(69 => 200, relu; bias=false),  # 13_800 parameters
    BatchNorm(200),                     # 400 parameters, plus 400
    identity,
  ),
  Chain(
    Dense(200 => 100, relu; bias=false),  # 20_000 parameters
    BatchNorm(100),                     # 200 parameters, plus 200
    identity,
  ),
  Dense(100 => 2),                      # 202 parameters
)         # Total: 18 trainable arrays, 35_850 parameters,
          # plus 6 non-trainable, 612 parameters, summarysize 144.453 KiB.

To directly get a Learner suitable for our task and data, we can use the tasklearner function. This creates both batched data loaders and a model for us.


			
			
			
```julia
learner = tasklearner(task, splitdata;
    backbone=backbone, callbacks=[Metrics(accuracy)],
    batchsize=128)
```

			Learner()
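If we are unsure what maximum learning rate to pass to fitonecycle! below, the learning-rate finder can suggest one; a sketch assuming FastAI's lrfind:

```julia
# Runs a short mock training with exponentially increasing learning
# rates and records the loss to suggest a good maximum value.
lrfind(learner)
```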

Once we have a Learner, we can call fitonecycle! on it to train it for the desired number of epochs:


			
			
			

	
```julia
fitonecycle!(learner, 5, 0.1)
```

			Epoch 1 TrainingPhase(): 100%|██████████████████████████| Time: 0:00:55
┌───────────────┬───────┬─────────┬──────────┐
│         Phase │ Epoch │    Loss │ Accuracy │
├───────────────┼───────┼─────────┼──────────┤
│ TrainingPhase │   1.0 │ 0.38405 │  0.82442 │
└───────────────┴───────┴─────────┴──────────┘
Epoch 1 ValidationPhase(): 100%|████████████████████████| Time: 0:00:02
┌─────────────────┬───────┬─────────┬──────────┐
│           Phase │ Epoch │    Loss │ Accuracy │
├─────────────────┼───────┼─────────┼──────────┤
│ ValidationPhase │   1.0 │ 0.32253 │  0.85324 │
└─────────────────┴───────┴─────────┴──────────┘
Epoch 2 TrainingPhase(): 100%|██████████████████████████| Time: 0:00:01
┌───────────────┬───────┬─────────┬──────────┐
│         Phase │ Epoch │    Loss │ Accuracy │
├───────────────┼───────┼─────────┼──────────┤
│ TrainingPhase │   2.0 │ 0.34583 │  0.83736 │
└───────────────┴───────┴─────────┴──────────┘
Epoch 2 ValidationPhase(): 100%|████████████████████████| Time: 0:00:00
┌─────────────────┬───────┬─────────┬──────────┐
│           Phase │ Epoch │    Loss │ Accuracy │
├─────────────────┼───────┼─────────┼──────────┤
│ ValidationPhase │   2.0 │ 0.31372 │  0.85753 │
└─────────────────┴───────┴─────────┴──────────┘
Epoch 3 TrainingPhase(): 100%|██████████████████████████| Time: 0:00:01
┌───────────────┬───────┬─────────┬──────────┐
│         Phase │ Epoch │    Loss │ Accuracy │
├───────────────┼───────┼─────────┼──────────┤
│ TrainingPhase │   3.0 │ 0.33348 │  0.84085 │
└───────────────┴───────┴─────────┴──────────┘
Epoch 3 ValidationPhase(): 100%|████████████████████████| Time: 0:00:00
┌─────────────────┬───────┬─────────┬──────────┐
│           Phase │ Epoch │    Loss │ Accuracy │
├─────────────────┼───────┼─────────┼──────────┤
│ ValidationPhase │   3.0 │ 0.30714 │  0.85545 │
└─────────────────┴───────┴─────────┴──────────┘
Epoch 4 TrainingPhase(): 100%|██████████████████████████| Time: 0:00:01
┌───────────────┬───────┬─────────┬──────────┐
│         Phase │ Epoch │    Loss │ Accuracy │
├───────────────┼───────┼─────────┼──────────┤
│ TrainingPhase │   4.0 │ 0.32585 │  0.84717 │
└───────────────┴───────┴─────────┴──────────┘
Epoch 4 ValidationPhase(): 100%|████████████████████████| Time: 0:00:00
┌─────────────────┬───────┬─────────┬──────────┐
│           Phase │ Epoch │    Loss │ Accuracy │
├─────────────────┼───────┼─────────┼──────────┤
│ ValidationPhase │   4.0 │ 0.31087 │  0.85935 │
└─────────────────┴───────┴─────────┴──────────┘
Epoch 5 TrainingPhase(): 100%|██████████████████████████| Time: 0:00:01
┌───────────────┬───────┬─────────┬──────────┐
│         Phase │ Epoch │    Loss │ Accuracy │
├───────────────┼───────┼─────────┼──────────┤
│ TrainingPhase │   5.0 │ 0.32124 │    0.851 │
└───────────────┴───────┴─────────┴──────────┘
Epoch 5 ValidationPhase(): 100%|████████████████████████| Time: 0:00:00
┌─────────────────┬───────┬─────────┬──────────┐
│           Phase │ Epoch │    Loss │ Accuracy │
├─────────────────┼───────┼─────────┼──────────┤
│ ValidationPhase │   5.0 │ 0.30543 │   0.8586 │
└─────────────────┴───────┴─────────┴──────────┘
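After training, the task and model can be persisted together so new rows can be encoded and classified later; a minimal sketch assuming FastAI's savetaskmodel and loadtaskmodel (the file name is illustrative):

```julia
# Serialize the trained weights together with the learning task.
savetaskmodel("tabular_classifier.jld2", task, learner.model)

# Later: restore both, e.g. in a fresh session, for inference.
savedtask, savedmodel = loadtaskmodel("tabular_classifier.jld2")
```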