Visualization of MLP weights on MNIST

Sometimes looking at the learned coefficients of a neural network can provide insight into the learning behavior. For example, if weights look unstructured, maybe some were not used at all; if very large coefficients exist, maybe regularization was too low or the learning rate too high.
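These heuristics can also be checked numerically before plotting. The sketch below uses a hypothetical helper, `summarize_first_layer` (not part of scikit-learn). It takes a first-layer weight matrix of shape (n_features, n_hidden), such as `mlp.coefs_[0]`, and reports three things: the L2 norm of each hidden unit's incoming weights, the largest absolute weight, and the units whose norm is far below the median. The 10% cutoff is an arbitrary illustrative choice. A low norm is only a rough proxy for "not used", and the toy matrix stands in for real MNIST weights.

```python
import numpy as np


def summarize_first_layer(W, rel_threshold=0.1):
    """Hypothetical helper, not a scikit-learn API.

    W has shape (n_features, n_hidden), e.g. mlp.coefs_[0].
    Returns per-unit L2 norms, the largest |weight|, and the indices of
    units whose norm is below rel_threshold times the median norm.
    """
    unit_norms = np.linalg.norm(W, axis=0)        # one norm per column (unit)
    max_abs = np.abs(W).max()                     # very large -> check alpha / learning rate
    cutoff = rel_threshold * np.median(unit_norms)
    weak_units = np.flatnonzero(unit_norms < cutoff)
    return unit_norms, max_abs, weak_units


# Toy 4-feature, 3-unit matrix: unit 2 has all-zero incoming weights.
W = np.array([[1.0, 0.5, 0.0],
              [1.0, -0.5, 0.0],
              [1.0, 0.5, 0.0],
              [1.0, -0.5, 0.0]])
norms, max_abs, weak = summarize_first_layer(W)
print("per-unit norms:", norms)
print("largest |weight|:", max_abs)
print("possibly unused units:", weak)
```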

This example shows how to plot some of the first-layer weights in an MLPClassifier trained on the MNIST dataset.

The input data consists of 28x28 pixel handwritten digits, leading to 784 features in the dataset. The first-layer weight matrix therefore has the shape (784, hidden_layer_sizes[0]), and each column of it (the incoming weights of one hidden unit) can be visualized as a 28x28 pixel image.
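The shape claim can be checked on any fitted MLPClassifier with 784 input features. The sketch below is a minimal check under one assumption: it trains on random data standing in for flattened 28x28 images, not real MNIST. It confirms that `coefs_[0]` has shape (784, 50) for 50 hidden units and that one column reshapes into a 28x28 image. It also shows that iterating over `coefs_[0].T`, as the plotting loop does, yields those columns.

```python
import warnings

import numpy as np
from sklearn.exceptions import ConvergenceWarning
from sklearn.neural_network import MLPClassifier

# Random stand-in for flattened 28x28 images (assumption; not real MNIST).
rng = np.random.RandomState(0)
X = rng.rand(100, 28 * 28)
y = rng.randint(0, 10, size=100)

mlp = MLPClassifier(hidden_layer_sizes=(50,), max_iter=5, random_state=1)
with warnings.catch_warnings():
    # Only 5 iterations, so a ConvergenceWarning is expected; silence it.
    warnings.simplefilter("ignore", category=ConvergenceWarning)
    mlp.fit(X, y)

W = mlp.coefs_[0]                     # first-layer weights
print(W.shape)                        # one row per pixel, one column per hidden unit
first_unit = W[:, 0].reshape(28, 28)  # column 0 viewed as an image
print(first_unit.shape)
```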

To make the example run faster, we use very few hidden units and train only for a very short time. Training longer would result in weights with a much smoother spatial appearance.
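Stopping after a fixed, small number of epochs means the stochastic optimizer usually hits `max_iter` before meeting its tolerance. When that happens, scikit-learn emits a `ConvergenceWarning`. The sketch below reproduces this on synthetic data, an assumption standing in for MNIST. It uses a smaller learning rate than the example and records the warning. It then checks `n_iter_` and `loss_curve_`, which for the `sgd` solver hold the number of epochs run and the loss after each one.

```python
import warnings

import numpy as np
from sklearn.exceptions import ConvergenceWarning
from sklearn.neural_network import MLPClassifier

# Synthetic stand-in for the MNIST training set (assumption, not real data).
rng = np.random.RandomState(0)
X = rng.rand(200, 784)
y = rng.randint(0, 10, size=200)

# A budget like the example's: 50 hidden units, stop after 10 epochs.
mlp = MLPClassifier(hidden_layer_sizes=(50,), max_iter=10, solver='sgd',
                    learning_rate_init=0.01, random_state=1)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")   # record every warning, even repeats
    mlp.fit(X, y)

hit_limit = any(issubclass(w.category, ConvergenceWarning) for w in caught)
print("epochs run:", mlp.n_iter_)     # equals max_iter when training is cut off
print("loss per epoch:", np.round(mlp.loss_curve_, 4))
print("ConvergenceWarning raised:", hit_limit)
```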

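import warnings

import matplotlib.pyplot as plt
from sklearn.datasets import fetch_openml
from sklearn.exceptions import ConvergenceWarning
from sklearn.neural_network import MLPClassifier

print(__doc__)

# Load data from https://www.openml.org/d/554
X, y = fetch_openml('mnist_784', version=1, return_X_y=True)

# Rescale pixel intensities from [0, 255] to [0, 1]
X = X / 255.

# Use the traditional train/test split: the first 60,000 samples for
# training, the remaining 10,000 for testing
X_train, X_test = X[:60000], X[60000:]
y_train, y_test = y[:60000], y[60000:]

# A larger configuration that trains for much longer:
# mlp = MLPClassifier(hidden_layer_sizes=(100, 100), max_iter=400, alpha=1e-4,
#                     solver='sgd', verbose=10, tol=1e-4, random_state=1)

# Small, fast configuration: 50 hidden units, only 10 epochs
mlp = MLPClassifier(hidden_layer_sizes=(50,), max_iter=10, alpha=1e-4,
                    solver='sgd', verbose=10, tol=1e-4, random_state=1,
                    learning_rate_init=.1)

# Stopping after 10 epochs triggers a ConvergenceWarning, which is
# expected here, so it is silenced
with warnings.catch_warnings():
    warnings.filterwarnings("ignore", category=ConvergenceWarning)
    mlp.fit(X_train, y_train)

print("Training set score: %f" % mlp.score(X_train, y_train))
print("Test set score: %f" % mlp.score(X_test, y_test))

fig, axes = plt.subplots(4, 4)
# Use the global min / max so all weights are shown on the same scale;
# halving that range increases contrast (the largest magnitudes saturate)
vmin, vmax = mlp.coefs_[0].min(), mlp.coefs_[0].max()
# Each column of coefs_[0] holds one hidden unit's 784 input weights;
# only the first 16 units fit in the 4x4 grid
for coef, ax in zip(mlp.coefs_[0].T, axes.ravel()):
    ax.matshow(coef.reshape(28, 28), cmap=plt.cm.gray, vmin=.5 * vmin,
               vmax=.5 * vmax)
    ax.set_xticks(())
    ax.set_yticks(())

plt.show()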
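
The plotting loop passes vmin=.5 * vmin and vmax=.5 * vmax to matshow, so it uses only the middle half of the global weight range. matshow maps values linearly onto the colormap between those limits. With the default gray colormap, values beyond the limits are drawn at the darkest or lightest gray.

The sketch below emulates that mapping with matplotlib.colors.Normalize(clip=True) to show which weights saturate. The weight values are made up, an assumption standing in for mlp.coefs_[0].

```python
import numpy as np
from matplotlib.colors import Normalize

# Made-up weight values standing in for mlp.coefs_[0] (assumption).
weights = np.array([-0.4, -0.2, 0.0, 0.1, 0.2, 0.4])
vmin, vmax = weights.min(), weights.max()     # global range

# Full global range: every weight gets a distinct gray level.
full = Normalize(vmin=vmin, vmax=vmax)
# Halved range, as in the example: values outside [0.5*vmin, 0.5*vmax]
# are clipped to 0 or 1, i.e. drawn at the extreme colors.
half = Normalize(vmin=0.5 * vmin, vmax=0.5 * vmax, clip=True)

full_levels = np.asarray(full(weights))
half_levels = np.asarray(half(weights))
print("full range:", np.round(full_levels, 3))
print("half range:", np.round(half_levels, 3))
```

With the halved range, small weights near zero are spread over more of the gray scale, at the cost of merging the largest positive (or negative) weights into a single color.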
