My favorite airline!
GitHub Repo:
You can find this notebook in my GitHub Repo here:
https://github.com/sam-tritto/airline-passenger-satisfaction
In this tutorial, we'll explore how to leverage Google's ADK (Agent Development Kit) to create a powerful Explainable AI (XAI) agent for your machine learning models.
Understanding why a model makes a certain prediction is crucial, especially when stakeholders and business leaders ask, "Why did the model output this result?" Traditional AI agents often struggle when faced with complex, multi-dataset environments, becoming bogged down in trying to connect disparate data sources.
Our approach sidesteps this common pitfall by focusing on a single, well-structured table containing both the model's inputs and outputs. This makes the agent's task much simpler: it needs only to analyze a single, unified data source to provide clear, concise explanations. This not only makes the agent more effective but also provides an excellent and highly relevant use case for stakeholders who need to understand the reasoning behind a model's decisions.
Import Libraries
The libraries I'll be using here are my go-to for all things classification. They're a flexible power trio that allows for explainability and sharp insights. LightGBM offers both a regressor and a classifier, and both are fast, performant, and flexible, with native support for categorical data.
The Airline Passenger Satisfaction Dataset
Here's an interesting dataset from Kaggle, the Airline Passenger Satisfaction dataset. It contains passenger survey responses on a Likert scale (1-5) along with a label denoting whether the passenger was satisfied with their flight, which makes it great for binary classification. There are also many categorical columns describing the type of passenger or flight, which is fantastic for a cohort analysis. The dataset comes in two parts and is relatively large, with about 100k records.
You can find more information and the dataset here on Kaggle:
https://www.kaggle.com/datasets/teejmahal20/airline-passenger-satisfaction
Pre-Processing
This dataset happens to be very clean, so there isn't much to do in terms of pre-processing. I will, however, assume that the nulls in the Arrival Delay column should really be 0 and fill them accordingly. The target class is a string field, so I'll convert it to binary. I'll explicitly cast the rest of the categorical columns from the object type to the pandas category type, which LightGBM will need later on. There's an opportunity here to create some NPS (Net Promoter Score) features since these are all on the Likert scale, but I'll leave that for another day. The goal of this project is simply to stand up a binary classifier so that I can play with its SHAP values later on.
Finally, I'll separate the features from the target variable, which will simplify things downstream.
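Since the notebook's code isn't shown here, here's a minimal pandas sketch of those pre-processing steps. The column names "Arrival Delay in Minutes" and "satisfaction" and the label value "satisfied" are assumptions based on the Kaggle version of the data, and the helper name preprocess is hypothetical; adjust them if your copy differs.

```python
import pandas as pd

def preprocess(df: pd.DataFrame, target: str = "satisfaction"):
    """Clean the raw survey table and split it into features (X) and a 0/1 target (y).

    Hypothetical helper; assumes Kaggle-style column names.
    """
    df = df.copy()
    # Assumption: a missing arrival delay means the flight wasn't delayed.
    df["Arrival Delay in Minutes"] = df["Arrival Delay in Minutes"].fillna(0)
    # Binary target: 1 = "satisfied", 0 = anything else.
    y = (df[target] == "satisfied").astype(int)
    X = df.drop(columns=[target])
    # Cast every non-numeric column to pandas' category dtype for LightGBM.
    for col in X.columns:
        if not pd.api.types.is_numeric_dtype(X[col]):
            X[col] = X[col].astype("category")
    return X, y

raw = pd.DataFrame({
    "Customer Type": ["Loyal Customer", "disloyal Customer", "Loyal Customer"],
    "Inflight wifi service": [3, 5, 1],
    "Arrival Delay in Minutes": [12.0, None, 0.0],
    "satisfaction": ["satisfied", "neutral or dissatisfied", "satisfied"],
})
X, y = preprocess(raw)
```

Keeping the target separate from the start means every later step (splitting, fitting, SHAP) can take X and y as-is.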
Split the Data - Train, Validation, and Test Sets
Since we're using tree-based models, which are easy to overfit, we'll need a validation set in addition to our test set so that we can use the model's early stopping. First I split the train data 80/20 into train and val/test, then split the val/test portion 50/50. This results in an 80/10/10 split. The original test data will be used as a holdout set. If your data is temporal, then you'd need to ensure there's no leakage when splitting it.
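Here's a minimal sketch of that two-step split using only NumPy index shuffling. The function name three_way_split and the fixed seed are hypothetical; in practice, calling scikit-learn's train_test_split twice (optionally with stratify) achieves the same thing.

```python
import numpy as np

def three_way_split(n_rows: int, seed: int = 42):
    """Index arrays for an 80/10/10 train/validation/test split (hypothetical helper)."""
    rng = np.random.default_rng(seed)
    idx = rng.permutation(n_rows)            # shuffle row positions once
    n_train = int(n_rows * 0.8)              # step 1: 80% train, 20% val/test
    n_val = (n_rows - n_train) // 2          # step 2: split the remaining 20% in half
    train_idx = idx[:n_train]
    val_idx = idx[n_train:n_train + n_val]
    test_idx = idx[n_train + n_val:]
    return train_idx, val_idx, test_idx

train_idx, val_idx, test_idx = three_way_split(1_000)
print(len(train_idx), len(val_idx), len(test_idx))  # 800 100 100
```

The index arrays can then be applied with X.iloc[train_idx] and y.iloc[train_idx], and likewise for the other two sets.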
Manually Tuned Model
Now we can build the classifier. I like to start with a manually tuned model. This lets me get a feel for how the model fits the data, and also lets me build any expert knowledge into my parameters. Once you find parameters that work well for your data, you can tune them further with Optuna, using the ones you found manually as a starting point.
Class Imbalance: This dataset does not suffer from class imbalance. If it did, we could use the scale_pos_weight parameter to set the weight of the positive class above or below 1, giving it more or less weight relative to the negative class. This can help you dial in False Negatives or False Positives; it's just one tool for balancing the model's errors. With a value of 0.5, correctly predicting a satisfied customer is weighted half as much as correctly predicting an unsatisfied one. You may also notice that I'm using AUC to evaluate the model, which is another tool for balancing the errors, since it measures ranking quality across all thresholds rather than at a single cutoff.
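To make scale_pos_weight a bit more concrete: conceptually, it multiplies the weight of every positive-class row in the loss. The hypothetical helper below computes a summed binary log loss with that weighting so you can see how a value of 0.5 halves the penalty for getting a satisfied passenger wrong. It's an illustration of the idea, not LightGBM's internal code.

```python
import math

def weighted_log_loss(y_true, p_pred, scale_pos_weight=1.0):
    """Summed binary log loss where each positive (label 1) row is weighted by scale_pos_weight.

    Illustrative only: mimics the idea behind LightGBM's scale_pos_weight,
    not its internal implementation.
    """
    total = 0.0
    for y, p in zip(y_true, p_pred):
        w = scale_pos_weight if y == 1 else 1.0
        total += -w * (y * math.log(p) + (1 - y) * math.log(1 - p))
    return total

# A satisfied passenger (y=1) predicted at only 0.2 probability:
print(weighted_log_loss([1], [0.2]))                         # full penalty
print(weighted_log_loss([1], [0.2], scale_pos_weight=0.5))   # half the penalty
```

Negative-class rows keep a weight of 1, so the model ends up caring relatively more about mistakes on unsatisfied customers.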
Parameters: I'm starting with some pretty standard parameters, but setting the number of estimators high and the learning rate low. I like to call this "low and slow".
Overfitting: Like I mentioned earlier, tree-based models are prone to overfitting (Random Forests being the usual exception, since adding more trees to a forest doesn't make it overfit). The idea here is to set a large number of estimators and let early stopping halt training once the model hasn't improved for a certain number of iterations. The model scores itself on the validation set after each round to decide when to stop. You can see below that even though I set the number of estimators to 1000, early stopping was triggered and the best iteration was 366. Good thing we didn't keep going to 1000.
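The early stopping rule itself is simple enough to sketch in plain Python. This hypothetical helper assumes a patience of 10 rounds and a metric where higher is better (like AUC); it returns the best iteration and the round where training halted. LightGBM implements this internally through its early stopping settings, so you wouldn't write this yourself.

```python
def early_stopping(val_scores, patience=10, higher_is_better=True):
    """Return (best_iteration, rounds_run) under a patience-based early-stopping rule.

    Rounds are 1-indexed. Hypothetical helper; LightGBM does this internally.
    """
    best_score, best_iter = None, 0
    for i, score in enumerate(val_scores, start=1):
        if best_score is None or (score > best_score if higher_is_better else score < best_score):
            best_score, best_iter = score, i          # new best: reset the patience clock
        elif i - best_iter >= patience:
            return best_iter, i                       # no improvement for `patience` rounds
    return best_iter, len(val_scores)                 # ran out of rounds first

# Validation AUC that improves for 366 rounds, then plateaus (made-up numbers).
scores = [0.9 + i * 1e-4 for i in range(1, 367)]
scores += [scores[-1]] * (1000 - len(scores))
print(early_stopping(scores, patience=10))  # (366, 376)
```

With scores that improve for 366 rounds and then flatten out, the rule reports a best iteration of 366 and halts at round 376, the same 366 + 10 pattern discussed in the Loss Curves section.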
Confusion Matrix & Classification Report
After fitting, we need to understand how our model predicts each class. From a business standpoint, it's worth thinking through whether we prefer False Negatives, False Positives, or a balance of both. Errors are inevitable: if you have 100% accuracy, that's a red flag that you probably have data leakage somewhere that needs fixing.
You could easily assign a cost to each type of error if it makes sense in your business context. For instance, if an FN costs $0.38 and an FP costs $1.20, you could weight the F1 Score calculation accordingly. For this tutorial I'll keep it simple, but we could absolutely optimize for cost here. Something to keep in mind.
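If you did want to optimize for cost, one simple approach (an alternative to weighting F1) is to pick the decision threshold that minimizes total error cost. The sketch below uses the example costs above, $0.38 per FN and $1.20 per FP, and scans thresholds in 0.01 steps; the helper names and the four toy predictions are hypothetical.

```python
def total_error_cost(y_true, p_pred, threshold, fn_cost=0.38, fp_cost=1.20):
    """Dollar cost of all errors when we predict 'satisfied' (1) for p >= threshold."""
    cost = 0.0
    for y, p in zip(y_true, p_pred):
        pred = 1 if p >= threshold else 0
        if y == 1 and pred == 0:
            cost += fn_cost      # False Negative: missed a satisfied customer
        elif y == 0 and pred == 1:
            cost += fp_cost      # False Positive: called an unsatisfied customer satisfied
    return cost

def cheapest_threshold(y_true, p_pred, fn_cost=0.38, fp_cost=1.20):
    """Scan thresholds 0.01..0.99 and return the first one with the lowest total cost."""
    candidates = [i / 100 for i in range(1, 100)]
    return min(candidates, key=lambda t: total_error_cost(y_true, p_pred, t, fn_cost, fp_cost))

y_true = [1, 0, 1, 0]
p_pred = [0.9, 0.6, 0.4, 0.1]
print(total_error_cost(y_true, p_pred, 0.5))   # one FN + one FP, about 1.58
print(cheapest_threshold(y_true, p_pred))      # 0.61: raising the bar removes the pricier FP
```

Because an FP costs more than three FNs here, the cheapest threshold sits above 0.5, trading a few extra False Negatives for fewer False Positives.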
Since a "satisfied" customer is our 1 class, our errors are as follows:
False Negatives: We predicted an "unsatisfied" customer, but they were actually "satisfied".
False Positives: We predicted a "satisfied" customer, but they were actually "unsatisfied".
This model seems well balanced in its errors, with an F1 Score of 96%, leaning slightly toward False Negatives thanks to the scale_pos_weight parameter.
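For reference, here's a minimal sketch of how those numbers come out of the confusion matrix, treating satisfied (1) as the positive class. In practice scikit-learn's confusion_matrix and classification_report do this for you; the helper and the toy labels here are hypothetical. A model that leans toward False Negatives has more FN than FP, which shows up as recall lower than precision.

```python
def classification_summary(y_true, y_pred):
    """Confusion-matrix counts plus precision, recall and F1 for class 1 (satisfied)."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for y, p in pairs if y == 1 and p == 1)
    fp = sum(1 for y, p in pairs if y == 0 and p == 1)
    fn = sum(1 for y, p in pairs if y == 1 and p == 0)
    tn = sum(1 for y, p in pairs if y == 0 and p == 0)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {"tp": tp, "fp": fp, "fn": fn, "tn": tn,
            "precision": precision, "recall": recall, "f1": f1}

# A toy model that leans toward False Negatives: one missed satisfied customer, no FPs.
summary = classification_summary([1, 1, 1, 1, 0, 0, 0], [1, 1, 1, 0, 0, 0, 0])
print(summary)
```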
Loss Curves
As the model trains, it records its evaluation scores on both the training and validation sets. We should monitor these for overfitting and convergence: if the training score keeps improving while the validation score doesn't, the model is overfitting. There are 2 curves, one for each evaluation metric. Notice that the number of boosting rounds is about 376. That's 10 more than the best iteration of 366, i.e. 366 rounds + 10 early stopping rounds. If we hadn't set the early_stopping parameter, training would have continued up to the n_estimators value of 1000, even though the model stopped learning at round 366.
Typically at this point I might tune the hyperparameters using Optuna or a grid search, but since the model fit very well and I'm really interested in the SHAP values, I'm going to move on. Here I'll re-join the validation and test sets and use them as a holdout set after fitting on the training set. I'm doing this now that I've found the best parameters, to get a more accurate picture of the model's accuracy metrics. Notice I'm using the best_n_estimators found during early stopping; all other parameters are the same.
Calibrate Predicted Probabilities
The predicted probabilities returned from the model aren't necessarily true probabilities. I'm imagining here that my stakeholders want to share these probabilities around, so I'm going to calibrate them closer to true probabilities.
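The exact calibration method isn't shown here (scikit-learn, for example, offers sigmoid/Platt scaling and isotonic regression through CalibratedClassifierCV). Whatever method you use, a quick way to check whether it helped is to compare a reliability table and the Brier score before and after calibrating. The helpers below are a minimal, hypothetical sketch of both.

```python
def brier_score(y_true, p_pred):
    """Mean squared gap between predicted probability and the 0/1 outcome (lower is better)."""
    return sum((p - y) ** 2 for y, p in zip(y_true, p_pred)) / len(y_true)

def reliability_table(y_true, p_pred, n_bins=10):
    """Per non-empty bin: (mean predicted probability, observed share of 1s, row count)."""
    bins = [[] for _ in range(n_bins)]
    for y, p in zip(y_true, p_pred):
        bins[min(int(p * n_bins), n_bins - 1)].append((y, p))   # p == 1.0 lands in the last bin
    table = []
    for rows in bins:
        if rows:
            mean_p = sum(p for _, p in rows) / len(rows)
            observed = sum(y for y, _ in rows) / len(rows)
            table.append((mean_p, observed, len(rows)))
    return table

print(brier_score([0, 0, 1, 1], [0.1, 0.2, 0.8, 0.9]))
print(reliability_table([0, 0, 1, 1], [0.1, 0.2, 0.8, 0.9], n_bins=2))
```

For well-calibrated probabilities, the mean predicted probability in each bin should sit close to the observed share of satisfied passengers in that bin.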
Now that our model is tuned and we've made some predictions, we can focus on something other than accuracy metrics. There are so many insights to be had beyond accuracy. The whole purpose of this project has been to generate SHAP values for our predictions so that I can build a Data Science Agent that explains our ML model to stakeholders, and providing it with these SHAP values will help it greatly. To get started, we simply create the explainer from our fitted model and run it on our test data. This gives us SHAP values for each record in our test or prediction set. These SHAP values answer the question, "Why did our model predict this customer would be satisfied?"
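SHAP values are Shapley values from game theory: each feature's contribution is its average marginal effect on the prediction across all orderings of the features, and the contributions add up to the prediction minus a baseline. The brute-force sketch below illustrates that additivity on a hypothetical three-feature toy function. It is exponential in the number of features and uses a single baseline point for "missing" features, which is a simplification; SHAP's TreeExplainer computes these efficiently for tree models, and for a LightGBM classifier the values are typically in log-odds units.

```python
from itertools import combinations
from math import factorial

def exact_shapley(f, x, baseline):
    """Exact Shapley values of f at x, filling 'absent' features from a baseline point.

    Brute force over all feature subsets, so only practical for a handful of features.
    """
    n = len(x)

    def value(present):
        # Features in `present` keep their real values; the rest take baseline values.
        return f([x[i] if i in present else baseline[i] for i in range(n)])

    phi = [0.0] * n
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for size in range(n):
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for subset in combinations(others, size):
                present = set(subset)
                phi[i] += weight * (value(present | {i}) - value(present))
    return phi

def toy_model(z):
    return 2 * z[0] + z[1] * z[2]    # one additive effect, one interaction

x, base = [1, 2, 3], [0, 0, 0]
phi = exact_shapley(toy_model, x, base)
print(phi)   # approximately [2.0, 3.0, 3.0]
```

The additive term is credited entirely to feature 0, the interaction is split evenly between features 1 and 2, and the three values sum to toy_model(x) - toy_model(base) = 8.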
Typically I'd run through some of the visualizations provided by the SHAP library (dependence plots, feature importance, etc.) and even do some tricks of my own to get further insights, but I'm going to skip past all of that. If you're interested, you can navigate to my GitHub repo or a similar project tutorial.
GitHub Repo: https://github.com/sam-tritto/airline-passenger-satisfaction/blob/main/Airline%20Passenger%20Satisfaction.ipynb
Another Classification Project: Banana Quality - Classifying for Explainability & Insight
Next, I'll put the feature values and the SHAP value for each feature into one dataframe, along with the predictions and predicted probabilities. This should be all the Data Science Agent needs to answer any "Why" question a user might ask.
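Here's a minimal pandas sketch of that table. The shap_ column prefix, the prediction and predicted_probability column names, and the helper name are all hypothetical choices; the real requirement is just that each row carries its features, its SHAP values, and its prediction side by side so the agent can read them together.

```python
import numpy as np
import pandas as pd

def build_explanation_table(X, shap_values, y_pred, y_proba, prefix="shap_"):
    """One row per passenger: feature values, per-feature SHAP values, prediction, probability."""
    shap_df = pd.DataFrame(
        np.asarray(shap_values),
        columns=[f"{prefix}{col}" for col in X.columns],   # hypothetical naming convention
        index=X.index,
    )
    table = pd.concat([X, shap_df], axis=1)
    table["prediction"] = list(y_pred)
    table["predicted_probability"] = list(y_proba)
    return table

X = pd.DataFrame({"Seat comfort": [4, 2], "Inflight wifi service": [5, 1]})
shap_values = [[0.3, 0.9], [-0.2, -1.1]]
table = build_explanation_table(X, shap_values, y_pred=[1, 0], y_proba=[0.93, 0.08])
print(table.columns.tolist())
```

A predictable prefix makes it easy for the agent (or a SQL query) to pick out every SHAP column at once.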
Now that we have the general classification structure down, let's imagine this model running each day/week/month/year on new data. These predictions would likely end up stacked on top of each other along with their SHAP values. To mimic this temporal structure, we'll loop through 12 months of the year, making predictions on synthetic data.
The SDV library (Synthetic Data Vault) is an excellent choice for creating synthetic data that mimics an existing dataset. I'll use its GaussianCopulaSynthesizer to learn from the existing data and produce similar Likert data for each cohort group. It even has some functions to evaluate the quality of the new data.
With a start date, we can iterate through any number of months backwards.
Generating between 20k and 30k records each month.
Then evaluating the synthetic data.
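The month loop can be sketched with just the standard library. These hypothetical helpers walk backwards month by month from a start date and draw a record count between 20k and 30k for each month; in the real loop each count would be passed to the fitted synthesizer (in SDV 1.x, its sample(num_rows=...) method) before predicting and computing SHAP values.

```python
import random
from datetime import date

def months_back(start: date, n_months: int):
    """First day of start's month and of the n_months - 1 months before it, newest first."""
    year, month = start.year, start.month
    months = []
    for _ in range(n_months):
        months.append(date(year, month, 1))
        month -= 1
        if month == 0:                     # roll back into December of the prior year
            month, year = 12, year - 1
    return months

def monthly_plan(start: date, n_months=12, low=20_000, high=30_000, seed=0):
    """Pair each month with a random synthetic record count in [low, high] (hypothetical)."""
    rng = random.Random(seed)
    return [(m, rng.randint(low, high)) for m in months_back(start, n_months)]

for month, n_rows in monthly_plan(date(2025, 3, 15), n_months=3):
    # In the real loop: sample n_rows from the fitted synthesizer, predict, compute SHAP.
    print(month, n_rows)
```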
Now that we have monthly data, we can make predictions and assign SHAP values just as we did above. I'll skip showing the code here since the loop is long, but if you're interested, please see the notebook in my GitHub Repo linked above.
Since our DS Agent will query this data and read the results as text in its prompts, we have an opportunity to save on token costs by rounding all of our numerical data. You can see below that our Agent's thoughts contain numerical data with 14 decimals.
Rounding to 3 decimal places won't harm our insights but will save greatly on costs.
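A quick way to see the savings is to compare the serialized size of a record before and after rounding. Character count is only a rough proxy for tokens, and the record below is made up. The helper name round_floats is hypothetical; with a pandas DataFrame, df.round(3) does the same job before upload.

```python
import json

def round_floats(records, ndigits=3):
    """Round every float in a list of dict records; other values pass through unchanged."""
    return [
        {key: round(val, ndigits) if isinstance(val, float) else val for key, val in row.items()}
        for row in records
    ]

# A made-up record with the kind of 14-decimal values the agent was seeing.
records = [{"Class": "Eco", "shap_Inflight wifi service": -0.41235987123456,
            "predicted_probability": 0.08765432109876}]
before = json.dumps(records)
after = json.dumps(round_floats(records))
print(len(before), len(after))   # the rounded version is noticeably shorter
```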
The last thing we'll need to do is push the data to BigQuery; don't forget to authenticate.
Now that our data mimics a true business problem with historical data, we can get to the fun stuff. Google's ADK documentation is excellent, so rather than walk through every step here, I'll make a few call-outs and point to their docs. Before we get started, though, we'll need a project ID with billing enabled. Then navigate to APIs & Services -> Credentials, create an API Key, and note it down for later. Once that's all set up, we can navigate to the sample DS Agent on their GitHub below.
The sample DS Agent can be found here: https://github.com/google/adk-samples/tree/main/python/agents/data-science
Official Documentation: https://google.github.io/adk-docs/
Once there, you can copy their command to clone the repository. It clones the entire adk-samples repo, so I simply deleted the local folders I didn't need and kept only the data-science folder. This also assumes you have git installed; if not, you'll need to set that up locally.
Once cloned, you'll need to use Poetry to install the environment and then create and edit a .env file with local environment variables. Add the API Key you created earlier to this file. Don't forget to add this .env file to your .gitignore so you don't push your secrets to GitHub. And once your Code Interpreter is enabled automatically, copy its address into this file as well; otherwise, the agent will keep creating new ones.
Running the agent is easy, or at least it's easy to demo rather than deploy. There are two options: the terminal or the web. The web option is really nice because it lets us avoid spinning up a Streamlit app just to demo our agent. To officially deploy the agent, Cloud Run functions are recommended.
Before we begin, it's important to understand what we've set up. This is a DS Agent with access to one BQ table for the sole purpose of explaining our model's predictions. The table has all of the columns needed for a basic analysis, and the Agent has all of the tools needed to deliver one, including BQ access, Python data visualization capabilities, and memory. Let's see how it delivers.
First, let's see if we can connect to our data in BQ.
Now right to the good stuff... let's see if we can visualize a simple line chart for each month broken down by one category and our target variable.
Looks good. I wouldn't use this in a slide deck, but it's informative. It shows the percentage of satisfied passengers, which is the correct way to compare unequal groups. I like how it shows not only its explanation but also a summary of the chart. This is great for notes or bullet points to accompany the chart.
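Percentages are the right comparison because cohorts have different sizes, so raw counts of satisfied passengers mostly reflect how big each group is. A minimal pandas sketch of the rate calculation, with hypothetical column names month, Class, and satisfied (a 0/1 target), looks like this:

```python
import pandas as pd

def satisfaction_rate(df, by="Class", month_col="month", target="satisfied"):
    """Share of satisfied passengers (mean of a 0/1 column) per month and category."""
    return df.groupby([month_col, by])[target].mean().unstack(by)

df = pd.DataFrame({
    "month": ["2025-01"] * 4 + ["2025-02"] * 2,
    "Class": ["Eco", "Eco", "Business", "Business", "Eco", "Business"],
    "satisfied": [1, 0, 1, 1, 0, 1],
})
print(satisfaction_rate(df))
```

Taking the mean of a 0/1 column gives the satisfied share directly, so groups of any size land on the same 0 to 1 scale.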
Now for something more challenging: a follow-up question. Let's see if it can interpret the SHAP values already present in our data. Wow, great insights! These would be almost perfect for stakeholders if they didn't reference SHAP values. That's something we can take care of later in Context Engineering, where we can explicitly instruct it not to mention SHAP values directly. The other thing I'd point out is that it considered all of the historical data. I was hoping it would infer that the most recent month would be of more interest. Again, that's something we can handle in Context Engineering, where we can specify that most users are interested in what has changed in the current month.
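One way to steer the agent toward "what changed this month" would be to hand it a helper like the hypothetical one below, which compares mean absolute SHAP values per feature between the latest month and the month before. It assumes a month column with sortable values such as "2025-01" and SHAP columns prefixed with shap_; both are assumptions about the table layout.

```python
import pandas as pd

def shap_shift(df, month_col="month", prefix="shap_"):
    """Mean |SHAP| per feature in the latest vs. previous month, sorted by the change."""
    shap_cols = [c for c in df.columns if c.startswith(prefix)]
    months = sorted(df[month_col].unique())          # assumes sortable labels like "2025-01"
    latest, previous = months[-1], months[-2]
    importance = df[shap_cols].abs().groupby(df[month_col]).mean()
    out = pd.DataFrame({"latest": importance.loc[latest], "previous": importance.loc[previous]})
    out["change"] = out["latest"] - out["previous"]
    return out.sort_values("change", ascending=False)

df = pd.DataFrame({
    "month": ["2025-01", "2025-01", "2025-02", "2025-02"],
    "shap_wifi": [0.2, -0.4, 0.8, -0.6],
    "shap_seat": [0.5, -0.5, 0.1, -0.1],
})
print(shap_shift(df))
```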
Now another follow-up question, this time being specific about the month and asking about actual metric values rather than SHAP values. Overall a great job here. A bit brief, but we can build suggestions into our Agent later in Context Engineering. We can feed it a simple Python function, or even examples, to compute average values along with the standard deviation and maybe month-over-month statistics.
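Here's a minimal sketch of the kind of simple Python function we could feed the agent: the per-month mean and standard deviation of a metric plus the month-over-month change in the mean. The function name, the month column name, and the toy data are hypothetical.

```python
import pandas as pd

def monthly_stats(df, metric, month_col="month"):
    """Per-month mean and standard deviation of a metric, plus month-over-month change in the mean."""
    stats = df.groupby(month_col)[metric].agg(["mean", "std"]).sort_index()
    stats["mom_change"] = stats["mean"].diff()   # NaN for the first month
    return stats

df = pd.DataFrame({
    "month": ["2025-01", "2025-01", "2025-02", "2025-02"],
    "Seat comfort": [2, 4, 3, 5],
})
print(monthly_stats(df, "Seat comfort"))
```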
A really tough one. Now let's see how it groups the data by its many categorical columns while focusing on the current month. It focused on only one categorical value, but I would have preferred it look at many. Again, something to consider for future Context Engineering: we can tell the Agent which groupings users generally like to see.
Maybe if we're more specific.
I feel like Age is too continuous, let's use the pre-calculated Age Cohort. I wish the DS Agent had suggested something like this. Any pre known bins can be specified in Context Engineering as well.
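Pre-known bins are easy to encode with pd.cut. The bin edges and labels below are purely illustrative and aren't the Age Cohort definition used in my table; swap in whatever cohorts your users expect.

```python
import pandas as pd

def add_age_cohort(df, age_col="Age"):
    """Add an 'Age Cohort' column using fixed, illustrative bin edges."""
    bins = [0, 17, 34, 49, 64, 120]                      # hypothetical cohort edges
    labels = ["<18", "18-34", "35-49", "50-64", "65+"]
    out = df.copy()
    out["Age Cohort"] = pd.cut(out[age_col], bins=bins, labels=labels, include_lowest=True)
    return out

print(add_age_cohort(pd.DataFrame({"Age": [7, 18, 34, 35, 64, 65]})))
```

With pd.cut's default right-closed intervals, an age of 34 falls in "18-34" and 35 starts "35-49".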
A more vague question here. It doesn't reference any time so it's looking at the entire historical dataframe.
Again it doesn't reference time or the current month by default.
Please forgive my spelling; it was very late. A nice take on some line charts and an unhelpful categorical horizontal bar chart. As a future improvement, we can customize our data visualization sub agent with examples of charts that we love, even custom colors. The textual summaries here are very nice and would make excellent call-outs or notes for a presentation.
A step in the right direction, but far from a presentable chart; let's be more specific.
More specific instructions seem to help the Agent.
I've seen worse in a slide deck.
Wow! Very impressed here. I'm sure this is a vanilla forecast, but it's something we can edit in our sub agent. It's entirely possible to give it some powerful libraries (only a handful more are allowed on the Code Executor, but statsmodels and scikit-learn are available) and instruct it to use them for ML tasks.
Very impressed with the output and the legend.
Coming Soon!
Currently I'm exploring how to give the model more context so that I can improve on the results I've seen so far. Specifically, I think giving the Agent more knowledge of the data, knowledge of how users prefer to categorize and visualize things, more tools like web search, more advanced ML libraries, and fine-tuned prompts for each agent would be a great start. Out of the box, I think the Agent did a great vanilla analysis. If it could understand more about user preferences, it seems like it could easily deliver some advanced analyses and insights. I'm going to follow this tutorial up with another focused on Context Engineering.
Coming Soon-ish!
Currently this only runs on the ADK web interface. The next steps for an agent like this would include deploying on Streamlit or using Cloud Functions, something I'd love to explore more when I have time.