"""
===============================================================
Factor Analysis (with rotation) to visualize patterns
===============================================================

Investigating the Iris dataset, we see that sepal length, petal
length and petal width are highly correlated. Sepal width is
less redundant. Matrix decomposition techniques can uncover
these latent patterns. Applying rotations to the resulting
components does not inherently improve the predictive value
of the derived latent space, but can help visualize their
structure; here, for example, the varimax rotation, which
is found by maximizing the variance of the squared weights,
finds a structure where the second component only loads
positively on sepal width.

"""

# Authors: Jona Sassenhagen
# License: BSD 3 clause

import matplotlib.pyplot as plt
|
||
|
import numpy as np
|
||
|
|
||
|
from sklearn.datasets import load_iris
|
||
|
from sklearn.decomposition import PCA, FactorAnalysis
|
||
|
from sklearn.preprocessing import StandardScaler
|
||
|
|
||
|
# %%
# Load the Iris data and standardize every feature
data = load_iris()
feature_names = data["feature_names"]

# Fit the scaler and transform in one step; ``X`` is what all later cells use.
scaler = StandardScaler()
X = scaler.fit_transform(data["data"])

# %%
# Plot the correlation matrix of the Iris features
ax = plt.axes()

# ``np.corrcoef`` treats each row as a variable, so ``X`` is transposed to
# correlate features with each other; the colour scale spans the full [-1, 1]
# range of a correlation coefficient.
im = ax.imshow(np.corrcoef(X.T), cmap="RdBu_r", vmin=-1, vmax=1)

# Derive tick positions from the feature list rather than hard-coding four.
feature_ticks = np.arange(len(feature_names))
ax.set_xticks(feature_ticks)
ax.set_xticklabels(list(feature_names), rotation=90)
ax.set_yticks(feature_ticks)
ax.set_yticklabels(list(feature_names))

plt.colorbar(im).ax.set_ylabel("$r$", rotation=0)
ax.set_title("Iris feature correlation matrix")
plt.tight_layout()

# %%
# Run factor analysis with Varimax rotation
n_comps = 2

methods = [
    ("PCA", PCA()),
    ("Unrotated FA", FactorAnalysis()),
    ("Varimax FA", FactorAnalysis(rotation="varimax")),
]
fig, axes = plt.subplots(ncols=len(methods), figsize=(10, 8), sharey=True)

# Component ticks and labels follow ``n_comps`` so the x axis stays correct
# if the number of components is changed above.
comp_ticks = np.arange(n_comps)
comp_labels = [f"Comp. {i}" for i in range(1, n_comps + 1)]

for ax, (method, fa) in zip(axes, methods):
    fa.set_params(n_components=n_comps)
    fa.fit(X)

    # Transposed so that features run along the y axis and components along
    # the x axis of the image, matching the tick labels set below.
    components = fa.components_.T
    print(f"\n\n {method} :\n")
    print(components)

    # Symmetric colour limits keep zero loading at the centre of the colormap.
    vmax = np.abs(components).max()
    ax.imshow(components, cmap="RdBu_r", vmax=vmax, vmin=-vmax)
    ax.set_yticks(np.arange(len(feature_names)))
    ax.set_yticklabels(feature_names)
    ax.set_title(method)
    ax.set_xticks(comp_ticks)
    ax.set_xticklabels(comp_labels)

fig.suptitle("Factors")
plt.tight_layout()
plt.show()