sklearn/examples/inspection/plot_permutation_importance...

182 lines
7.2 KiB
Python
Raw Normal View History

2024-08-05 09:32:03 +02:00
"""
=================================================================
Permutation Importance with Multicollinear or Correlated Features
=================================================================
In this example, we compute the
:func:`~sklearn.inspection.permutation_importance` of the features to a trained
:class:`~sklearn.ensemble.RandomForestClassifier` using the
:ref:`breast_cancer_dataset`. The model can easily get about 97% accuracy on a
test dataset. Because this dataset contains multicollinear features, the
permutation importance shows that none of the features are important, in
contradiction with the high test accuracy.
We demo a possible approach to handling multicollinearity, which consists of
hierarchical clustering on the features' Spearman rank-order correlations,
picking a threshold, and keeping a single feature from each cluster.

.. note::
    See also
    :ref:`sphx_glr_auto_examples_inspection_plot_permutation_importance.py`

"""
# %%
# Random Forest Feature Importance on Breast Cancer Data
# ------------------------------------------------------
#
# First, we define a function to ease the plotting:
from sklearn.inspection import permutation_importance
def plot_permutation_importance(clf, X, y, ax):
    """Draw the permutation importances of ``clf`` as a horizontal box plot.

    The importances are computed with
    :func:`~sklearn.inspection.permutation_importance` using 10 repeats, a
    fixed ``random_state`` of 42 and 2 parallel jobs.

    Parameters
    ----------
    clf : fitted estimator
        Model whose feature importances are evaluated.
    X : DataFrame
        Feature data; ``X.columns`` provides the label of each box.
    y : array-like
        Targets matching the rows of ``X``.
    ax : matplotlib Axes
        Axes the box plot is drawn on.

    Returns
    -------
    ax : matplotlib Axes
        The same axes that were passed in.
    """
    # Imported locally so the helper does not rely on a module-level
    # ``matplotlib`` name being bound elsewhere in the script.
    import matplotlib

    result = permutation_importance(clf, X, y, n_repeats=10, random_state=42, n_jobs=2)
    # Order the features by increasing mean importance.
    perm_sorted_idx = result.importances_mean.argsort()

    # Matplotlib 3.9 renamed the ``labels`` parameter of ``boxplot`` to
    # ``tick_labels`` and deprecated the old name. Pick the keyword matching
    # the installed version so no deprecation warning is emitted.
    mpl_version = getattr(matplotlib, "__version_info__", (0, 0))
    labels_kwarg = "tick_labels" if tuple(mpl_version[:2]) >= (3, 9) else "labels"

    ax.boxplot(
        result.importances[perm_sorted_idx].T,
        vert=False,
        **{labels_kwarg: X.columns[perm_sorted_idx]},
    )
    # Reference line at zero: boxes around it mean permuting the feature did
    # not change the score.
    ax.axvline(x=0, color="k", linestyle="--")
    return ax
# %%
# We then train a :class:`~sklearn.ensemble.RandomForestClassifier` on the
# :ref:`breast_cancer_dataset` and evaluate its accuracy on a test set:
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
# Load the breast cancer data (``as_frame=True`` so column names are kept) and
# hold out a test split with a fixed seed.
X, y = load_breast_cancer(as_frame=True, return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

# Fit the reference forest on every available feature.
clf = RandomForestClassifier(random_state=42, n_estimators=100)
clf.fit(X_train, y_train)

baseline_accuracy = clf.score(X_test, y_test)
print(f"Baseline accuracy on test data: {baseline_accuracy:.2}")
# %%
# Next, we plot the tree-based feature importance and the permutation
# importance. The permutation importance is calculated on the training set to
# show how much the model relies on each feature during training.
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
# Impurity-based importances of the fitted forest, labelled with the training
# feature names.
mdi_importances = pd.Series(clf.feature_importances_, index=X_train.columns)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 8))

# Left panel: impurity-based importances, sorted so the bars are ordered.
mdi_importances.sort_values().plot.barh(ax=ax1)
ax1.set_xlabel("Gini importance")

# Right panel: permutation importances computed on the training set.
plot_permutation_importance(clf, X_train, y_train, ax2)
ax2.set_xlabel("Decrease in accuracy score")

fig.suptitle(
    "Impurity-based vs. permutation importances on multicollinear features (train set)"
)
# Assign to ``_`` so the return value is not echoed as cell output.
_ = fig.tight_layout()
# %%
# The plot on the left shows the Gini importance of the model. As the
# scikit-learn implementation of
# :class:`~sklearn.ensemble.RandomForestClassifier` uses a random subset of
# :math:`\sqrt{n_\text{features}}` features at each split, it is able to dilute
# the dominance of any single correlated feature. As a result, the individual
# feature importance may be distributed more evenly among the correlated
# features. Since the features have large cardinality and the classifier is
# non-overfitted, we can relatively trust those values.
#
# The permutation importance on the right plot shows that permuting a feature
# drops the accuracy by at most `0.012`, which would suggest that none of the
# features are important. This is in contradiction with the high test accuracy
# computed as baseline: some feature must be important.
#
# Similarly, the change in accuracy score computed on the test set appears to be
# driven by chance:
# Same box plot as for the training data, this time on the held-out test set.
fig, ax = plt.subplots(figsize=(7, 6))
plot_permutation_importance(clf, X_test, y_test, ax)
ax.set(
    title="Permutation Importances on multicollinear features\n(test set)",
    xlabel="Decrease in accuracy score",
)
# ``ax`` was created by ``plt.subplots`` above, so ``fig`` is its figure.
_ = fig.tight_layout()
# %%
# Nevertheless, one can still compute a meaningful permutation importance in the
# presence of correlated features, as demonstrated in the following section.
#
# Handling Multicollinear Features
# --------------------------------
# When features are collinear, permuting one feature has little effect on the
# model's performance because it can get the same information from a correlated
# feature. Note that this is not the case for all predictive models and depends
# on their underlying implementation.
#
# One way to handle multicollinear features is by performing hierarchical
# clustering on the Spearman rank-order correlations, picking a threshold, and
# keeping a single feature from each cluster. First, we plot a heatmap of the
# correlated features:
from scipy.cluster import hierarchy
from scipy.spatial.distance import squareform
from scipy.stats import spearmanr
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 8))

# Spearman rank-order correlation between all pairs of features.
corr = spearmanr(X).correlation
# Average with the transpose and reset the diagonal so the matrix is exactly
# symmetric with ones on the diagonal.
corr = (corr + corr.T) / 2
np.fill_diagonal(corr, 1)

# Turn correlations into distances (strongly correlated features are close),
# condense the square matrix, and cluster it with Ward's linkage.
distance_matrix = 1 - np.abs(corr)
condensed_distances = squareform(distance_matrix)
dist_linkage = hierarchy.ward(condensed_distances)

# Left panel: dendrogram of the feature hierarchy.
feature_names = X.columns.to_list()
dendro = hierarchy.dendrogram(
    dist_linkage, labels=feature_names, ax=ax1, leaf_rotation=90
)

# Right panel: correlation heatmap with rows and columns reordered to follow
# the dendrogram leaves.
leaf_order = dendro["leaves"]
leaf_labels = dendro["ivl"]
dendro_idx = np.arange(0, len(leaf_labels))
ax2.imshow(corr[leaf_order][:, leaf_order])
ax2.set_xticks(dendro_idx)
ax2.set_yticks(dendro_idx)
ax2.set_xticklabels(leaf_labels, rotation="vertical")
ax2.set_yticklabels(leaf_labels)
_ = fig.tight_layout()
# %%
# Next, we manually pick a threshold by visual inspection of the dendrogram to
# group our features into clusters and choose a feature from each cluster to
# keep, select those features from our dataset, and train a new random forest.
# The test accuracy of the new random forest did not change much compared to the
# random forest trained on the complete dataset.
from collections import defaultdict
# Cut the hierarchy at a distance of 1 to obtain one cluster id per feature.
cluster_ids = hierarchy.fcluster(dist_linkage, t=1, criterion="distance")

# Group feature positions by the cluster they were assigned to. The dict keeps
# clusters in order of first appearance.
cluster_id_to_feature_ids = defaultdict(list)
for feature_idx, cluster_id in enumerate(cluster_ids):
    cluster_id_to_feature_ids[cluster_id].append(feature_idx)

# Keep the first feature of every cluster as its representative.
selected_features = [members[0] for members in cluster_id_to_feature_ids.values()]
selected_features_names = X.columns[selected_features]

# Restrict both splits to the representative features and refit a forest with
# the same hyper-parameters as the baseline.
X_train_sel, X_test_sel = (
    frame[selected_features_names] for frame in (X_train, X_test)
)
clf_sel = RandomForestClassifier(random_state=42, n_estimators=100)
clf_sel.fit(X_train_sel, y_train)

reduced_accuracy = clf_sel.score(X_test_sel, y_test)
print(
    "Baseline accuracy on test data with features removed:"
    f" {reduced_accuracy:.2}"
)
# %%
# We can finally explore the permutation importance of the selected subset of
# features:
# Permutation importances of the forest trained on the reduced feature set,
# evaluated on the held-out test data.
fig, ax = plt.subplots(figsize=(7, 6))
plot_permutation_importance(clf_sel, X_test_sel, y_test, ax)
ax.set(
    title="Permutation Importances on selected subset of features\n(test set)",
    xlabel="Decrease in accuracy score",
)
# ``ax`` was created by ``plt.subplots`` above, so ``fig`` is its figure.
fig.tight_layout()
plt.show()