GitHub Repo:
You can find this notebook in my GitHub Repo here:
https://github.com/sam-tritto/banana-quality/blob/main/Banana_Quality.ipynb
OUTLINE
Peeling Back the Insights: Classifying When Explainability is More Important than Accuracy
Import Libraries
The Banana Quality Dataset
EDA
Feature Engineering
Split the Data - Train, Validation, and Test Sets
Manually Tuned Model
Optuna Bayesian Hyperparameter Tuned Model
Final Model
SHAP Explainability
In the world of machine learning, the relentless pursuit of higher accuracy often takes center stage. While a highly performant model is obviously valuable, there are scenarios where understanding why a model makes its predictions, and gaining actionable insights, can be even more crucial. This tutorial will guide you through a journey of not just building a predictive model, but also dissecting its behavior to extract meaningful knowledge aligned with business actions, using the Banana Quality dataset.
Imagine a system designed to classify the quality of bananas. While achieving 99% accuracy might seem like a resounding success, what if a subtle change in growing conditions consistently leads to misclassification, causing significant losses? Without understanding which factors the model deems important and how they influence the prediction, we remain in a "black box," unable to diagnose issues or identify opportunities for improvement. Also, some or many of the factors might be unactionable from a business standpoint. It's important to know what your controllable factors are and which ones matter most.
This is where the power of Explainable AI (XAI) comes into play. By shedding light on the model's decision-making process, we can:
Build Trust and Transparency: Especially in sensitive applications, understanding why a model makes a certain prediction fosters trust among stakeholders.
Debug and Improve Models: Identifying influential features and their relationships with the target can reveal biases, errors in data, or areas where the model can be further refined.
Gain Actionable Insights: Understanding the drivers of a prediction can lead to valuable insights about the underlying domain. In our banana quality example, this could reveal critical factors in cultivation, handling, or storage that impact the final product.
Identify Potential Issues: Explainability can help detect spurious correlations or unintended biases that might lead to seemingly accurate but ultimately flawed models.
To embark on this journey of insight discovery, we will leverage two powerful tools:
Optuna: This efficient Bayesian hyperparameter optimization framework will allow us to systematically search for the best model configurations, not just for accuracy, but also in a way that facilitates further analysis. Optuna's flexible and automated approach streamlines the often tedious process of tuning model parameters.
SHAP (SHapley Additive exPlanations): SHAP provides a unified framework for explaining the output of any machine learning model. By calculating the contribution of each feature to individual predictions, SHAP allows us to understand feature importance, feature effects, and interactions in a consistent and interpretable manner. These values are "additive," meaning they aggregate cleanly, which is especially useful in dashboards where we might need to drill deeper into or roll up the levels of the data.
This project will revolve around the Banana Quality dataset. This dataset, while seemingly simple, provides a tangible and relatable context for exploring the nuances of model explainability. By analyzing the features of bananas and their relationship to quality classifications, we can demonstrate how Optuna and SHAP can go beyond simply predicting 'Good' or 'Bad' to revealing the underlying factors that determine quality. We'll use these SHAP values to reveal hidden thresholds in our metrics to better identify when exactly a Banana goes from being Good to Bad.
Through this tutorial, you will learn how to combine the power of automated hyperparameter tuning with the clarity of explainable AI, demonstrating that sometimes, the insights gained from understanding our models are far more valuable than chasing that extra fraction of a percentage point in accuracy. Let's peel back the layers of our model and uncover the knowledge hidden within the data.
Import Libraries
The libraries I'll be using here are my go-to for all things classification. They are a flexible power trio that allows for explainability and sharp insights. LightGBM has both a Regressor and a Classifier; even though this is a classification project, I'll use the Regressor later on to identify hidden thresholds in our data. They're fast, performant, and flexible, allowing for categorical data (which we don't have here). I've already mentioned Optuna and SHAP above.
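A minimal import cell for this workflow might look like the following; the exact mix of plotting and metrics utilities is my assumption based on what's used later in the post.

```python
# Core data handling
import numpy as np
import pandas as pd

# Modeling, tuning, and explainability
import lightgbm as lgb
import optuna
import shap

# Splitting and evaluation utilities
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report, f1_score

# Plotting
import matplotlib.pyplot as plt
import seaborn as sns
```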
The Banana Quality Dataset
Typically I try to avoid toy data sets, but every once in a while I do a quick search for ones I might find interesting, and I love bananas, so I couldn't pass this one by. I also wanted to showcase some more of the technical aspects of my typical classification workflow... I'm trying to pass on my tricks here! I found this data set on Kaggle, and it's a very clean and predictive data set with only a handful of features. The data appears to be standardized, or at least centered, already (which I would have preferred to do myself), but it's clean. There's no missing data and only one categorical column, which is the target.
You can find more information and the data set here on Kaggle:
https://www.kaggle.com/datasets/l3llff/banana?resource=download
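Loading the data is a one-liner; the file name below is an assumption based on the Kaggle download, so adjust the path to wherever you saved it.

```python
# Read the Kaggle CSV (adjust the path/file name to your download)
df = pd.read_csv("banana_quality.csv")

# Quick sanity checks: shape, dtypes, and the first few rows
print(df.shape)
print(df.dtypes)
df.head()
```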
EDA
One of the most common things I see in any kind of ML project is overengineering in the EDA phase. I'm guilty of that too. The plots are pretty and easy to code. But more often than not they just sit at the top of a notebook and that's that; they don't inform the analysis or provide any real value.
But what about class imbalance? While that's a valid concern, I don't really find methods like undersampling, oversampling, or SMOTE helpful unless the data is severely imbalanced, such as in a fraud data set. They can introduce bias very easily, and many of the ML packages have parameters we can use to balance things out without altering the data or creating synthetic records. I'll show how to do this with the scale_pos_weight parameter later on.
Pie charts? No, just value_counts(). I actually don't mind pie charts, but if you're not going to print one out for leadership then don't burn the calories; let's be pragmatic. You can see that this is a well-balanced data set, since it's a toy data set.
Null values and imputations? Yup, they're important. But again, let's just see if it's an issue before we start coding, and since this is a toy data set there's no issue here. I'm a big fan of grouped median imputations, KNN imputations, and MICE if you do have to worry about missing values. Some tree-based libraries can actually handle nulls now, but a good trick I used to use was to impute with something extreme like -999.
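Both pragmatic checks fit in a couple of lines; the Quality column name matches the dataset's target label.

```python
# Class balance: no pie chart needed
print(df["Quality"].value_counts(normalize=True))

# Missing values per column (all zeros for this dataset)
print(df.isnull().sum())
```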
But what about outliers? We don't need to worry too much about that since we're using tree-based models which are robust to outliers. I will look at the tails however to see how wide they go, but that's about it.
So what to do then? There is one type of EDA I never skip, and that's looking at the distributions of the metrics; they can tell me pretty much everything I need to know. And when there are multiple classes you can overlay them to see the differences between your classes very easily. There's a great package for this, klib, but sometimes I'll just use seaborn's kdeplot() because it's quick and easily customizable.
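Here's the kind of quick overlay I mean, sketched with seaborn; looping over the numeric columns and splitting the hue by Quality is all it takes.

```python
# Overlay the distribution of each numeric feature, split by class
numeric_cols = df.select_dtypes(include="number").columns
fig, axes = plt.subplots(nrows=len(numeric_cols), figsize=(6, 3 * len(numeric_cols)))
for ax, col in zip(axes, numeric_cols):
    sns.kdeplot(data=df, x=col, hue="Quality", fill=True, common_norm=False, ax=ax)
    ax.set_title(f"{col} by Quality")
plt.tight_layout()
plt.show()
```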
What to Look For
Mostly I look for normality and skewness. When I have two classes or groups, I look to see which features are farthest apart between the groups; that separation means the classifier will learn a lot from the feature. You can see this in the plot above for Size. Looking at the data for Softness below, though, we can see that there might be something going on with the "Bad" class: it looks multimodal. We should start to think about how we can perform some feature engineering to help our model learn this. Basically, this suggests there's more to the story. I have some tricks later on in the Explainability section that can help us figure out what might be going on. For now we can just note it down and circle back later.
Feature Engineering
Not too much to see here. Just going to do some binary encoding of the target variable. A "Good" banana will be the positive class.
I'm also creating an interaction term, based on the distribution plot of Softness above and some insights that will come later in the analysis. I'll explain how I decided that Acidity would help explain the bimodal nature of Softness later on. A quick note: we can't use the target, Quality, to create any features, as this would cause leakage, although it's tempting because we can see that the bimodal shape of Softness only appears for the "Bad" bananas.
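Something like the following; the exact form of the interaction (a simple product of the two columns) is my assumption, and a ratio would work just as well.

```python
# Binary encode the target: "Good" is the positive class
df["target"] = (df["Quality"] == "Good").astype(int)

# Interaction term motivated by the bimodal Softness distribution
df["Softness_x_Acidity"] = df["Softness"] * df["Acidity"]

# Feature matrix and target (drop the raw Quality label to avoid leakage)
X = df.drop(columns=["Quality", "target"])
y = df["target"]
```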
Split the Data - Train, Validation, and Test Sets
Since we are using tree-based models, which are easy to overfit, we are going to need a validation set in addition to our test set so that we can utilize the model's early stopping abilities. First I split the data 80/20 into train and val/test, then further split the val/test 50/50. This results in an 80/10/10 split. If your data is temporal, then you'd need to ensure there's no leakage when splitting your data.
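A minimal sketch of that split using scikit-learn's train_test_split; the random_state and the stratification on the target are my choices, not necessarily the notebook's.

```python
# 80% train, 20% temporary holdout
X_train, X_tmp, y_train, y_tmp = train_test_split(
    X, y, test_size=0.20, stratify=y, random_state=42
)

# Split the holdout 50/50 into validation and test -> 80/10/10 overall
X_val, X_test, y_val, y_test = train_test_split(
    X_tmp, y_tmp, test_size=0.50, stratify=y_tmp, random_state=42
)
```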
Manually Tuned Model
Now we can build the classifier model. I like to start with a manually tuned model. This lets me get a feel for how the model fits the data, and also allows me to build any expert knowledge into my parameters. I'll show next how, once you find parameters that work well for your data, you can further tune them with Optuna, using the ones you've found manually as a starting place.
Overfitting: Like I mentioned earlier, tree-based models are prone to overfitting, Random Forests being the notable exception. The idea here is to set a large number of estimators and then let the model's early stopping capability halt training once it hasn't learned anything for a certain number of iterations, using the validation set to evaluate itself. You can see below that even though I've set the number of estimators to 1000, early stopping was triggered and the best iteration was 169. Good thing we didn't keep going to 1000.
Class Imbalance: This data set does not suffer from class imbalance; if it did, we could use the scale_pos_weight parameter to weight the positive class above or below 1, giving it more or less weight relative to the negative class. This can help you dial in False Negatives or False Positives; it's just one tool to help balance the model's errors. With a value of 2.0, errors on "Good" bananas would count twice as much. But this also just balances the errors, as you can see in the confusion matrix below. You'll also notice that I'm using AUC to evaluate the model; that's another tool for balancing the errors.
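A manually tuned starting point might look like this; the specific values are illustrative, but the pattern (a large n_estimators plus an early-stopping callback watching the validation set) is the one described above.

```python
manual_params = {
    "n_estimators": 1000,        # intentionally large; early stopping picks the real count
    "learning_rate": 0.05,
    "max_depth": 6,
    "num_leaves": 31,
    "scale_pos_weight": 1.0,     # push above/below 1.0 to trade off FNs vs FPs
    "random_state": 42,
}

model = lgb.LGBMClassifier(**manual_params)
model.fit(
    X_train, y_train,
    eval_set=[(X_train, y_train), (X_val, y_val)],
    eval_metric="auc",
    callbacks=[lgb.early_stopping(stopping_rounds=10), lgb.log_evaluation(period=0)],
)
print("Best iteration:", model.best_iteration_)
```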
Confusion Matrix & Classification Report
After fitting, we need to understand how our model is predicting each class. It's best to think through, from a business standpoint, whether we prefer False Negatives, False Positives, or a balance of both. Errors are inevitable; if you have 100% accuracy, that's a bad sign and you probably have data leakage somewhere that needs fixing.
You could easily assign a cost to each of these if it makes sense in your business context. For instance, if an FN costs $0.38 and an FP costs $1.20, you could weight the F1 Score calculation accordingly. For this tutorial I'll keep it simple, but we could absolutely optimize for cost here. Something to keep in mind.
Since a "Good" banana is our 1 class, then our errors are as follows:
False Negatives: We predicted a "Bad" banana and it was a "Good" banana.
False Positives: We predicted a "Good" banana and it was a "Bad" banana.
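The evaluation itself is standard scikit-learn; here I'm scoring the manual model on the validation set, and the cost line at the end just plugs in the illustrative $0.38 / $1.20 figures from above.

```python
y_pred = model.predict(X_val)

# Rows = actual, columns = predicted; positive class (1) = "Good"
tn, fp, fn, tp = confusion_matrix(y_val, y_pred).ravel()
print(confusion_matrix(y_val, y_pred))
print(classification_report(y_val, y_pred, target_names=["Bad", "Good"]))

# Optional: translate the errors into dollars with the example costs
print("Estimated error cost: $", round(fn * 0.38 + fp * 1.20, 2))
```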
This model seems well balanced in its errors with a 98% F1 Score, thanks to the scale_pos_weight parameter.
Loss Curves
As the model trains, it keeps a record of its evaluation scores for both the training and validation sets. We should monitor these to look for overfitting and convergence: if the train set keeps improving and the validation set does not, then it's overfitting. There are two curves, one for each evaluation metric. Notice the number of boosting rounds is about 179. That is 10 more than the best iteration of 169, or 169 rounds + 10 early stopping rounds. If we hadn't set the early_stopping parameter, this would have continued up to the n_estimators parameter of 1000, even though it had stopped learning at round 169.
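LightGBM keeps those curves in the model's evals_result_ after fitting with an eval_set, and plot_metric will draw them; a quick sketch that loops over whatever metrics were recorded:

```python
# evals_result_ maps each eval set name to {metric: [score per boosting round]}
recorded_metrics = list(next(iter(model.evals_result_.values())).keys())
for metric in recorded_metrics:
    lgb.plot_metric(model, metric=metric)   # one figure per metric, train vs. validation
    plt.show()
```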
Optuna Bayesian Hyperparameter Tuned Model
Now that we've got our model set and we're pretty happy with the hyperparameters, we can tune them further, using the ones we've found as a starting place. We'll start by defining the objective function, where we specify the type and range of each hyperparameter to search. We don't have to tune every hyperparameter; we could, but it just takes longer. We can keep some of the manually chosen ones to save time, and we can also limit the search space for each hyperparameter to save time.
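A sketch of what that objective might look like; the search ranges are illustrative and deliberately kept narrow around the manually chosen values, and the F1 score on the validation set is my assumed optimization target.

```python
def objective(trial):
    params = {
        "n_estimators": 1000,   # kept fixed; early stopping controls the real count
        "learning_rate": trial.suggest_float("learning_rate", 0.01, 0.1, log=True),
        "max_depth": trial.suggest_int("max_depth", 3, 10),
        "num_leaves": trial.suggest_int("num_leaves", 15, 63),
        "min_child_samples": trial.suggest_int("min_child_samples", 10, 100),
        "subsample": trial.suggest_float("subsample", 0.6, 1.0),
        "colsample_bytree": trial.suggest_float("colsample_bytree", 0.6, 1.0),
        "random_state": 42,
    }
    clf = lgb.LGBMClassifier(**params)
    clf.fit(
        X_train, y_train,
        eval_set=[(X_val, y_val)],
        eval_metric="auc",
        callbacks=[lgb.early_stopping(stopping_rounds=10), lgb.log_evaluation(period=0)],
    )
    preds = clf.predict(X_val)
    return f1_score(y_val, preds)
```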
This isn't a widely used function, but it's one worth using for sure. If I've seen 10 Optuna tutorials, none of them have covered the enqueue_trial() function... tell your friends!
Prior Knowledge
Since Optuna is Bayesian under the hood, it works with priors, and that lets us feed in prior or expert knowledge about each hyperparameter. There is a method in Optuna called enqueue_trial() that I don't often see being used, but it can cut down your search time drastically. It simply lets us pass in a hyperparameter dictionary as a starting place, which saves many iterations of search time.
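Usage is a one-liner before optimize(); the dictionary keys just have to match the names used in the suggest_* calls of the objective sketch above, and the trial count of 200 is illustrative.

```python
study = optuna.create_study(direction="maximize")

# Seed the search with the manually tuned values
study.enqueue_trial({
    "learning_rate": 0.05,
    "max_depth": 6,
    "num_leaves": 31,
    "min_child_samples": 20,
    "subsample": 1.0,
    "colsample_bytree": 1.0,
})

study.optimize(objective, n_trials=200)

# Best hyperparameters as a dictionary for reuse downstream
best_params = study.best_params
print(best_params)
```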
You can see in the Optimization Plot History that our expertly picked hyperparameters were actually very close to the best Optuna was able to find. We probably could have stopped at 100 iterations or so.
The max_depth hyperparameter was the one that needed the most tuning and affected our model the most.
Once we have run Optuna, we need to extract the best hyperparameters into a dictionary for further use.
Final Model
Now that our hyperparameters are tuned, we can test our final model on the holdout set. First we join the train and validation sets back together, since we won't be doing any more evaluations. Then we make any manual adjustments to the hyperparameters. Finally, fit and predict.
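Roughly, that step looks like the sketch below. With no validation set left for early stopping, I'm capping n_estimators near the earlier best iteration; that specific choice is my assumption, not necessarily what the notebook does.

```python
# Recombine train and validation now that tuning is done
X_full = pd.concat([X_train, X_val])
y_full = pd.concat([y_train, y_val])

# Manual values overridden by Optuna's best, plus a manual adjustment
final_params = {**manual_params, **best_params}
final_params["n_estimators"] = 200   # no eval set left, so cap rounds near the best iteration

final_model = lgb.LGBMClassifier(**final_params)
final_model.fit(X_full, y_full)
y_pred_test = final_model.predict(X_test)

print(confusion_matrix(y_test, y_pred_test))
print(classification_report(y_test, y_pred_test, target_names=["Bad", "Good"]))
```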
Confusion Matrix & Classification Report
Looking at the results, we can see that we've tuned our way to +1 percentage point in Accuracy and F1 Score. We can also see the errors are slightly more unbalanced: 7 FNs and 3 FPs. I'm OK with this. I'd much rather predict a banana is "Bad" and have it turn out "Good" than predict it's "Good" and have it end up "Bad". Some waste would be expected, but you probably wouldn't want to ship defects to the store. You should consult your business partners on these questions. Like I mentioned earlier, it's also entirely possible to weight these metrics based on cost and then optimize for that here as well.
SHAP Explainability
Now that our model is tuned and we've made some predictions, we can focus on something other than accuracy metrics. There are so many insights to be had beyond accuracy. In this section I'll go over how to utilize SHAP values to identify hidden thresholds in the data, and I'll show you how to use them for prescriptive insights. To get started we simply need to fit the explainer to our test data. This will give us SHAP values for each record in our test or prediction set. These SHAP values answer the question, "Why did our model predict this banana would be Good/Bad?".
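Fitting the explainer is two lines; note that depending on your shap version, the values for a binary LightGBM classifier may come back as a list of two arrays (one per class) or as a single array, so it's worth checking the shape.

```python
explainer = shap.TreeExplainer(final_model)
shap_values = explainer.shap_values(X_test)

# Some shap versions return [class_0, class_1] for binary classifiers; keep the positive class
if isinstance(shap_values, list):
    shap_values = shap_values[1]

print(shap_values.shape)   # (n_rows, n_features)
```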
The first insights I look for are the SHAP Feature Importances. Based on this plot I might want to rework some features, or worse find data leakage (if one feature is overly important). Notice here that the Interaction term is now the 3rd most important metric, when its components were the least important.
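Both plots come from shap.summary_plot; the bar version gives the global importances discussed here, and the default beeswarm is the dot plot discussed next. The exact plots in the original notebook may differ slightly.

```python
# Global feature importance (mean |SHAP value| per feature)
shap.summary_plot(shap_values, X_test, plot_type="bar")

# Beeswarm: one dot per prediction per feature
shap.summary_plot(shap_values, X_test)
```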
Now, if you're looking at that chart and not wondering what those little dots represent, you should be. These dots are the SHAP values of each individual prediction. The dependence plots expand these into scatterplots, color-coded by the feature that interacts most with the current feature. We can look for patterns in these color codings to think about possible new features we could code up.
For instance, you can see here in the SHAP values for Softness that there might be some interplay from Acidity. If we can use this knowledge to create a new feature, it might help the model learn the intricacies better. Remember, it was the Softness feature that exhibited the multimodal shape in its "Bad" class. I've already coded this up and rerun the analysis, but the idea for the feature came from here. You can see a red band going from top left to bottom right, and a blue band going in the opposite direction.
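The dependence plot that surfaced the Softness/Acidity interplay can be reproduced with something like this (column names assumed to match the dataset).

```python
# Scatter of Softness vs. its SHAP values, colored by the interacting feature
shap.dependence_plot("Softness", shap_values, X_test, interaction_index="Acidity")
```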
For me this is the most exciting part of the tutorial. I haven't seen anyone deconstruct the SHAP values for this purpose yet, but it's really cool. If we build our own scatterplots of the SHAP values for each feature, we can then fit a small LGBMRegressor() on the SHAP values to predict where they change sign, or cross the x-axis. That tells us the approximate feature value at which the model's predictions start to change from "Good" to "Bad". Read that again. We can now tell at which point each feature starts to turn from favorable to unfavorable.
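Here's a minimal sketch of the idea: for one feature, fit a small LGBMRegressor that maps the feature's values to its SHAP values, predict over a fine grid, and report where the fitted curve crosses zero. The helper name, the regressor settings, and the "HarvestTime" column name are my assumptions, not taken from the original notebook.

```python
def shap_sign_change_thresholds(feature, shap_values, X, n_grid=500):
    """Approximate the feature values where its SHAP contribution flips sign."""
    col = X.columns.get_loc(feature)
    x = X[feature].to_numpy().reshape(-1, 1)
    y_shap = shap_values[:, col]

    # Small, shallow regressor so the fitted curve stays smooth
    reg = lgb.LGBMRegressor(n_estimators=100, max_depth=3, learning_rate=0.1)
    reg.fit(x, y_shap)

    grid = np.linspace(x.min(), x.max(), n_grid).reshape(-1, 1)
    fitted = reg.predict(grid)

    # A threshold is anywhere the fitted SHAP curve crosses the x-axis
    crossings = grid[:-1][np.diff(np.sign(fitted)) != 0].ravel()
    return fitted, grid.ravel(), crossings

# Example for the harvest-time feature (adjust the column name to your data)
fitted, grid, thresholds = shap_sign_change_thresholds("HarvestTime", shap_values, X_test)
print("Approximate favorable/unfavorable thresholds:", thresholds)
```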
Not all of the features will be controllable by the business, and therefore knowing a threshold might prove invaluable. Here is one of the controllable features that the business could take direct action on. On the x-axis you have the feature's values; the average harvest time is -0.8. On the y-axis we have the feature's SHAP values, where below 0 indicates a prediction of "Bad" and above 0 indicates a prediction of "Good". The dashed red line indicates where our small regression model predicts the fit will equal 0, so it is at this point that the model's predictions start to turn from unfavorable to favorable. You'll notice it's actually not the average harvest time of -0.8; it's a little to the right of that, around -0.4. So this could inform the business that we should actually harvest a bit later than we currently do. The business would need to consider all of the implications of that, but it's an insightful and data-driven hypothesis.
When we compare all of the thresholds to the known averages, you'll see that most are quite different. You'll also notice that the two least important features have many thresholds, indicating a poor model fit, but their interaction has only one. This gives us confidence in our feature engineering, and also warns us not to trust the thresholds unless the model fit was good. It is possible however for a metric to turn from favorable to unfavorable many times, you'll need to investigate each feature.
Again, this is something I haven't seen many people doing with their classification models, but something I would encourage more of. It provides thoughtful, data-driven guideposts and hypotheses. Please tell your friends!