Saturday, 14 July 2018

Cross Validation - Model checking, not model building

1st Update on 11th August 2020
2nd Update on 4th August 2023
3rd Update on 25th September 2023


Cross-Validation 

Cross-validation is a popular technique for model evaluation in machine learning; skim through almost any Kaggle kernel and you are bound to find it in the script. It extends the idea of a train-test split and lets us estimate how our model will perform on unseen data. It also helps us detect overfitting by showing how well the model generalizes.

CV has two main objectives:
1. Give a good estimate of how our model generalizes to unseen data
2. Compare the performance of different models or hyperparameter combinations
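
As a minimal sketch of the first objective, here is 5-fold CV with scikit-learn's cross_val_score. The dataset and model below are my own illustrative choices, not from the original discussion:

from sklearn.datasets import load_breast_cancer
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score

X, y = load_breast_cancer(return_X_y=True)
model = LogisticRegression(max_iter=5000)

# cross_val_score handles the splitting, fitting and scoring for us
scores = cross_val_score(model, X, y, cv=5, scoring="accuracy")
print("Fold accuracies:", scores)
print("Mean CV accuracy: %.3f (+/- %.3f)" % (scores.mean(), scores.std()))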

I had a question about cross-validation. Suppose we do 5-fold CV: we split our dataset into 5 portions and then train and evaluate a model 5 times, each time holding out a different portion. How are we then supposed to choose which of these models is the best?

A similar question was asked on StackExchange, which cleared up my confusion.

Also, this link from Machine Learning Mastery was helpful.

Here is an example of choosing the best type of model. Consider a classification problem where we want to try Logistic Regression and Random Forest. In this case, we compare the CV error of both methods and choose the one that minimizes it.
Suppose Logistic Regression is the winner here. We then discard all the models that were built during CV and train Logistic Regression on the entire dataset. This final model is the one used for predictions.
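
A sketch of that workflow in scikit-learn might look like this (the breast-cancer dataset and the specific model settings are assumptions for illustration):

from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score

X, y = load_breast_cancer(return_X_y=True)

candidates = {
    "logistic_regression": LogisticRegression(max_iter=5000),
    "random_forest": RandomForestClassifier(n_estimators=100, random_state=0),
}

# Compare the candidates by mean cross-validated accuracy
cv_scores = {name: cross_val_score(est, X, y, cv=5).mean()
             for name, est in candidates.items()}
best_name = max(cv_scores, key=cv_scores.get)
print(cv_scores, "-> winner:", best_name)

# The fold models are thrown away; the winner is refitted on the
# entire dataset before being used for predictions
final_model = candidates[best_name].fit(X, y)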

An extension to the above problem: we considered Logistic Regression and Random Forest in that scenario, but the question of hyperparameter tuning now arises.
In logistic regression we can choose between L1 and L2 regularization, and in a random forest we can vary the number of trees, the maximum number of features per split, and so on. Hence, we are not just trying 2 models but many variations of each. Scikit-learn's GridSearchCV can be used here, as it iterates through all the combinations and finds the optimal model along with its hyperparameters.
Now that we know the best model along with its hyperparameters, we can train it on the entire dataset.
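
A hedged sketch of this with GridSearchCV, assuming a logistic regression tuned over the penalty type and the regularization strength C (the grid values are illustrative):

from sklearn.datasets import load_breast_cancer
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV

X, y = load_breast_cancer(return_X_y=True)

param_grid = {
    "penalty": ["l1", "l2"],   # L1 vs L2 regularization
    "C": [0.01, 0.1, 1, 10],   # inverse regularization strength
}

# The 'liblinear' solver is chosen because it supports both penalties
search = GridSearchCV(LogisticRegression(solver="liblinear"),
                      param_grid, cv=5, scoring="accuracy")
search.fit(X, y)
print("Best params:", search.best_params_)
print("Best mean CV accuracy: %.3f" % search.best_score_)

# With refit=True (the default), the best model is retrained on all of X, y
final_model = search.best_estimator_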

Cross-validation Pseudocode:
Here we divide the training dataset into k folds; in each round, k-1 folds are used for training and the remaining fold for validation.
1. For each of the N hyperparameter combinations:
    i. For each of the k folds: # (this step repeats the holdout approach k times)
        Train the model on the other k-1 folds and compute the error on the validation fold.
    ii. Calculate the average validation error across the k folds.
2. Choose the hyperparameter combination with the lowest average validation error and evaluate the final model on the test set.
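
The pseudocode translates almost line for line into scikit-learn's KFold. The C grid below stands in for the "hyperparameter combinations" and is an assumed example:

import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import KFold, train_test_split

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2,
                                                    random_state=0)

kf = KFold(n_splits=5, shuffle=True, random_state=0)
avg_errors = {}
for C in [0.01, 0.1, 1, 10]:                      # step 1: each combination
    fold_errors = []
    for tr_idx, val_idx in kf.split(X_train):     # step i: holdout, k times
        model = LogisticRegression(C=C, max_iter=5000)
        model.fit(X_train[tr_idx], y_train[tr_idx])
        fold_errors.append(1 - model.score(X_train[val_idx], y_train[val_idx]))
    avg_errors[C] = np.mean(fold_errors)          # step ii: average error

best_C = min(avg_errors, key=avg_errors.get)      # step 2: lowest average
final = LogisticRegression(C=best_C, max_iter=5000).fit(X_train, y_train)
print("Best C:", best_C,
      "| test error: %.3f" % (1 - final.score(X_test, y_test)))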

Here is the link to an example jupyter notebook.

Holdout validation:
We have understood the concept of cross-validation, but there is yet another strategy that can be adopted for choosing the best model or the best hyperparameters: "holdout validation".

Here, the dataset is divided into 2 sets, training and testing, and the training set is further divided into training and validation sets. We train our model on the training data and then evaluate its performance on the validation set. We can then compare validation metrics between different models, or between different hyperparameter settings for the same model.

Once the best model is selected, we can train it on the entire training dataset (training plus validation) and compute the test metrics to understand how well our model generalizes. The con of this method is that the validation set is only a small portion, say 10%, of the training data, so our decision rests on a single, possibly unrepresentative, split.

Holdout Validation Pseudocode:
1. For each of the N hyperparameter combinations:
    Train the model on the training set and compute the error on the validation set.
2. Choose the hyperparameters with the lowest validation error and evaluate the final model on the test set.
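
A sketch of holdout validation with train_test_split; the 60/20/20 proportions and the C grid are illustrative assumptions:

from sklearn.datasets import load_breast_cancer
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

X, y = load_breast_cancer(return_X_y=True)

# Split off a test set, then carve a validation set out of the training data
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2,
                                                    random_state=0)
X_tr, X_val, y_tr, y_val = train_test_split(X_train, y_train, test_size=0.25,
                                            random_state=0)  # 0.25 * 0.8 = 0.2

# Step 1: score each hyperparameter setting on the single validation set
best_C, best_score = None, -1.0
for C in [0.01, 0.1, 1, 10]:
    score = (LogisticRegression(C=C, max_iter=5000)
             .fit(X_tr, y_tr).score(X_val, y_val))
    if score > best_score:
        best_C, best_score = C, score

# Step 2: retrain the chosen model on train + validation, report test score
final = LogisticRegression(C=best_C, max_iter=5000).fit(X_train, y_train)
print("Best C:", best_C, "| test accuracy: %.3f" % final.score(X_test, y_test))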

Here is a great link that explains the difference between holdout validation and k-fold cross-validation.