127 lines
3.5 KiB
Python
127 lines
3.5 KiB
Python
"""
|
|
==================================================
|
|
Balance model complexity and cross-validated score
|
|
==================================================
|
|
|
|
This example balances model complexity and cross-validated score by
|
|
finding a decent accuracy within 1 standard deviation of the best accuracy
|
|
score while minimising the number of PCA components [1].
|
|
|
|
The figure shows the trade-off between cross-validated score and the number
|
|
of PCA components. The balanced case is when n_components=10 and accuracy=0.88,
|
|
which falls into the range within 1 standard deviation of the best accuracy
|
|
score.
|
|
|
|
[1] Hastie, T., Tibshirani, R., Friedman, J. (2001). Model Assessment and
|
|
Selection. The Elements of Statistical Learning (pp. 219-260). New York,
|
|
NY, USA: Springer New York Inc.
|
|
|
|
"""
|
|
|
|
# Author: Wenhao Zhang <wenhaoz@ucla.edu>
|
|
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
|
|
from sklearn.datasets import load_digits
|
|
from sklearn.decomposition import PCA
|
|
from sklearn.model_selection import GridSearchCV
|
|
from sklearn.pipeline import Pipeline
|
|
from sklearn.svm import LinearSVC
|
|
|
|
|
|
def lower_bound(cv_results):
    """
    Calculate the lower bound within 1 standard deviation
    of the best `mean_test_score`.

    Candidates with a NaN `mean_test_score` (what `GridSearchCV` records for
    failed fits under its default ``error_score=np.nan``) are ignored when
    looking for the best score.

    Parameters
    ----------
    cv_results : dict of numpy(masked) ndarrays
        See attribute cv_results_ of `GridSearchCV`.

    Returns
    -------
    float
        Lower bound within 1 standard deviation of the
        best `mean_test_score`.

    Raises
    ------
    ValueError
        If every `mean_test_score` is NaN.
    """
    mean_scores = cv_results["mean_test_score"]
    # nanargmax rather than argmax: argmax treats NaN as the maximum and would
    # return a failed candidate's index, turning the whole bound into NaN.
    best_score_idx = np.nanargmax(mean_scores)

    return mean_scores[best_score_idx] - cv_results["std_test_score"][best_score_idx]
|
|
|
|
|
|
def best_low_complexity(cv_results):
    """
    Pick the simplest candidate whose score is close enough to the best one.

    Parameters
    ----------
    cv_results : dict of numpy(masked) ndarrays
        See attribute cv_results_ of `GridSearchCV`.

    Returns
    -------
    int
        Index of the candidate with the fewest PCA components among those
        whose `mean_test_score` is at least the value returned by
        `lower_bound` (best `mean_test_score` minus its standard deviation).
    """
    scores = cv_results["mean_test_score"]
    n_components = cv_results["param_reduce_dim__n_components"]

    # Positions of every candidate scoring at or above the 1-std lower bound.
    eligible = np.flatnonzero(scores >= lower_bound(cv_results))

    # Among the eligible candidates, locate the one with the fewest components
    # and map that position back to an index into cv_results.
    position = n_components[eligible].argmin()
    return eligible[position]
|
|
|
|
|
|
# PCA dimensionality reduction followed by a linear SVM classifier.
pipe = Pipeline(
    [
        ("reduce_dim", PCA(random_state=42)),
        ("classify", LinearSVC(random_state=42, C=0.01)),
    ]
)

# Candidate model complexities. The "reduce_dim" prefix targets the PCA step
# above and yields the "param_reduce_dim__n_components" entry of cv_results_
# that best_low_complexity reads.
param_grid = {"reduce_dim__n_components": [6, 8, 10, 12, 14]}

# refit is a callable: the index it returns becomes best_index_, so the
# refitted model is the simplest "good enough" one, not the top scorer.
grid = GridSearchCV(
    pipe,
    cv=10,
    n_jobs=1,
    param_grid=param_grid,
    scoring="accuracy",
    refit=best_low_complexity,
)
X, y = load_digits(return_X_y=True)
grid.fit(X, y)

n_components = grid.cv_results_["param_reduce_dim__n_components"]
test_scores = grid.cv_results_["mean_test_score"]

plt.figure()
plt.bar(n_components, test_scores, width=1.3, color="b")

lower = lower_bound(grid.cv_results_)
# nanmax so that a failed candidate (NaN score) does not make the
# best-score reference line NaN and thus invisible.
plt.axhline(np.nanmax(test_scores), linestyle="--", color="y", label="Best score")
plt.axhline(lower, linestyle="--", color=".5", label="Best score - 1 std")

plt.title("Balance model complexity and cross-validated score")
plt.xlabel("Number of PCA components used")
plt.ylabel("Digit classification accuracy")
plt.xticks(n_components.tolist())
plt.ylim((0, 1.0))
plt.legend(loc="upper left")

best_index_ = grid.best_index_

print(f"The best_index_ is {best_index_:d}")
print(f"The n_components selected is {n_components[best_index_]:d}")
print(
    "The corresponding accuracy score is "
    f"{grid.cv_results_['mean_test_score'][best_index_]:.2f}"
)
plt.show()
|