"""
====================================================
Plot multinomial and One-vs-Rest Logistic Regression
====================================================

Plot decision surface of multinomial and One-vs-Rest Logistic Regression.
The hyperplanes corresponding to the three One-vs-Rest (OVR) classifiers
are represented by the dashed lines.

"""

# Authors: Tom Dupre la Tour <tom.dupre-la-tour@m4x.org>
# License: BSD 3 clause

import matplotlib.pyplot as plt
import numpy as np

from sklearn.datasets import make_blobs
from sklearn.inspection import DecisionBoundaryDisplay
from sklearn.linear_model import LogisticRegression
from sklearn.multiclass import OneVsRestClassifier

# make 3-class dataset for classification
centers = [[-5, 0], [0, 1.5], [5, -1]]
X, y = make_blobs(n_samples=1000, centers=centers, random_state=40)
# apply a linear transformation so the clusters are anisotropic
transformation = [[0.4, 0.2], [-0.4, 1.2]]
X = np.dot(X, transformation)

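# Compare two multi-class strategies with the same base estimator:
# "multinomial" fits a single softmax LogisticRegression over all three
# classes, while "ovr" wraps it in OneVsRestClassifier, which fits one
# binary classifier per class.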
for multi_class in ("multinomial", "ovr"):
    clf = LogisticRegression(solver="sag", max_iter=100, random_state=42)
    if multi_class == "ovr":
        clf = OneVsRestClassifier(clf)
    clf.fit(X, y)

    # print the training scores
    print("training score : %.3f (%s)" % (clf.score(X, y), multi_class))

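    # Color the plane by the predicted class to show the decision surface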
    _, ax = plt.subplots()
    DecisionBoundaryDisplay.from_estimator(
        clf, X, response_method="predict", cmap=plt.cm.Paired, ax=ax
    )
    plt.title("Decision surface of LogisticRegression (%s)" % multi_class)
    plt.axis("tight")

    # Plot also the training points
    colors = "bry"
    for i, color in zip(clf.classes_, colors):
        idx = np.where(y == i)
        plt.scatter(X[idx, 0], X[idx, 1], c=color, edgecolor="black", s=20)

    # Plot the three one-against-all classifiers
    xmin, xmax = plt.xlim()
    ymin, ymax = plt.ylim()
    if multi_class == "ovr":
        coef = np.concatenate([est.coef_ for est in clf.estimators_])
        intercept = np.concatenate([est.intercept_ for est in clf.estimators_])
    else:
        coef = clf.coef_
        intercept = clf.intercept_

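    # Each row of coef and entry of intercept defines one class's boundary:
    # coef[c, 0] * x0 + coef[c, 1] * x1 + intercept[c] = 0. Solving for x1
    # gives the dashed line drawn below (assumes coef[c, 1] != 0).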
    def plot_hyperplane(c, color):
        def line(x0):
            return (-(x0 * coef[c, 0]) - intercept[c]) / coef[c, 1]

        plt.plot([xmin, xmax], [line(xmin), line(xmax)], ls="--", color=color)

    for i, color in zip(clf.classes_, colors):
        plot_hyperplane(i, color)

plt.show()