Add GCM example notebook: auditing CNN predictions for spurious correlations using chest X-ray data #1524
sanzits wants to merge 3 commits into
…lations using chest X-ray data Signed-off-by: Sanchit <sanzit.s@gmail.com>
🤖 This is an automated response from Repo Assist. Welcome! A few notes for review:
External data dependency: The notebook requires downloading the NIH Chest X-ray Dataset from Kaggle, so it cannot be executed in automated CI. Please add a clear note at the top of the notebook about the data download requirement, and make sure the PR marks the notebook so that CI skips it. Check whether it should carry the …
Docs indexing: Please ensure the notebook is referenced in the appropriate RST index file under …
CNN training time: If the notebook includes full model training from scratch, runtime may be very long for anyone following along. Consider noting approximate wall-clock time and whether a pre-trained weights option is feasible for the demo.
Despite these practical considerations, the analytical story (train model → construct causal graph → apply GCM intrinsic influence / arrow strength / interventional samples → interpret) is clean and instructive. The …
Hello, thanks for the detailed feedback! I've addressed all three points: Added a data download warning at the top of the notebook, with instructions for the NIH Chest X-ray dataset. Happy to make any further changes. Thanks!
…ing time note Signed-off-by: Sanchit <sanzit.s@gmail.com>
d2ee462 to dba5fe8
What the notebook demonstrates
A CNN trained on chest X-rays can achieve decent AUC and still be learning the wrong things. The notebook shows that roughly 28% of what drives this model's predictions comes from image brightness — a property of the scanner hardware, not the patient's lungs. We use DoWhy's GCM framework to open the black box and separate the legitimate clinical signal from the spurious scanner artifact. The core message is simple: good predictive accuracy does not mean the model learned the right causal relationships.
Which DoWhy features it uses
Three GCM tools, each answering a different question:
gcm.intrinsic_causal_influence: "Of everything driving model predictions, how much does each variable causally own?" This is the headline analysis. It shows that image brightness owns 4.8x more of the model's predictive variance than the actual clinical signal (opacity score).
gcm.arrow_strength: "How load-bearing is each individual edge in the causal graph?" Strength is measured by KL divergence: if you removed that edge, how much would the prediction distribution change? The brightness→prediction edge (0.014) is roughly 6x to 7x stronger than the opacity→prediction edge (0.002); the rounded values shown give a ratio of 7.
gcm.interventional_samples: "What actually happens if we swap the scanner?" This applies a do-intervention (the do-operator): force every image to have the brightness of scanner group 1, then propagate that change through the causal graph. Mean predictions shift from 0.44 to 0.56, a 12 percentage point change with no change to any patient's clinical condition.
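To make the intervention idea concrete, here is a minimal standard-library sketch on a toy linear causal model. It is not the notebook's code and does not call DoWhy: the graph (scanner → brightness → prediction ← opacity), the function name `sample_predictions`, and every coefficient are assumptions invented for illustration, so the numbers it prints are unrelated to the 0.44 → 0.56 result above. It only shows what forcing brightness to a fixed value and propagating the change means.

```python
import random
import statistics


# Hypothetical toy model: all coefficients below are invented for illustration.
def sample_predictions(n, seed=0, do_brightness=None):
    """Draw n predictions from a toy linear causal model.

    Graph: scanner -> brightness -> prediction <- opacity.
    If do_brightness is given, brightness is forced to that value, which
    cuts the scanner -> brightness mechanism; everything else is unchanged.
    """
    rng = random.Random(seed)
    predictions = []
    for _ in range(n):
        scanner = rng.randint(0, 1)    # scanner group 0 or 1
        opacity = rng.gauss(0.0, 1.0)  # stand-in for the clinical signal
        # The natural brightness is always drawn so that both runs consume
        # the same random numbers and differ only through the intervention.
        natural = 0.4 + 0.2 * scanner + rng.gauss(0.0, 0.05)
        brightness = natural if do_brightness is None else do_brightness
        noise = rng.gauss(0.0, 0.01)
        predictions.append(0.1 + 0.6 * brightness + 0.02 * opacity + noise)
    return predictions


baseline = sample_predictions(5000)
# 0.6 is the toy model's mean brightness for scanner group 1 (0.4 + 0.2).
intervened = sample_predictions(5000, do_brightness=0.6)
shift = statistics.fmean(intervened) - statistics.fmean(baseline)
print(f"mean prediction shift under do(brightness=0.6): {shift:.3f}")
```

Because the opacity values are identical in both runs, any shift in the mean prediction comes only from the forced brightness, which is the same reading the notebook gives to its own interventional result.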
The dataset and how to get it
NIH Chest X-ray Dataset, published by Wang et al. at CVPR 2017. It contains 112,120 frontal-view chest X-rays from 30,805 unique patients, with 14 disease labels. The notebook uses a 15,000-image stratified subset (first 2 folders) targeting the Infiltration label (about 15% prevalence).
How to get it: it's publicly available on Kaggle at https://www.kaggle.com/datasets/nih-chest-xrays/data. You need a free Kaggle account. The notebook's first cell has the exact download and setup instructions including how to create the stratified subset and generate the feature CSV that DoWhy consumes.
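For readers who want a feel for the subsetting step before opening the notebook, the following is a rough sketch of proportional stratified sampling using only the standard library. The notebook's first cell remains the authoritative procedure; the helper `stratified_subset`, the record layout, and the field name `infiltration` are hypothetical, and the synthetic records merely stand in for the real label file.

```python
import random
from collections import defaultdict


def stratified_subset(records, label_key, n_total, seed=0):
    """Sample about n_total records while keeping each label's share.

    Hypothetical helper for illustration; not taken from the notebook.
    """
    rng = random.Random(seed)
    by_label = defaultdict(list)
    for rec in records:
        by_label[rec[label_key]].append(rec)

    subset = []
    for label in sorted(by_label):
        group = by_label[label]
        # Allocate slots in proportion to the label's frequency.
        k = round(n_total * len(group) / len(records))
        subset.extend(rng.sample(group, min(k, len(group))))
    rng.shuffle(subset)
    return subset


# Synthetic stand-in data: 1,000 records with 15% positives (3 of every 20).
records = [
    {"image": f"img_{i}.png", "infiltration": int(i % 20 < 3)}
    for i in range(1000)
]
subset = stratified_subset(records, "infiltration", n_total=200)
positives = sum(rec["infiltration"] for rec in subset)
print(len(subset), positives / len(subset))
```

Sampling within each label group, rather than from the pooled records, keeps the positive rate of the subset fixed at the source rate instead of leaving it to chance.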
Where it should be listed
Real world-inspired examples section, alongside the microservice latency notebook and the counterfactual medical case. Level: Advanced. Task: Intrinsic causal influence, intervention, and root cause analysis via GCM.
It fits here rather than "Examples on benchmark datasets" because the point isn't to benchmark DoWhy on a standard dataset; it's to show a real-world use case where causal auditing of a deployed model reveals something that standard ML evaluation would miss.