隨機梯度下降的早期停止?

隨機梯度下降是一種優化技術,它以隨機的方式將損失函數降到最小,逐個樣本進行梯度下降。特別是對線性模型進行擬合是一種非常有效的方法。

作為一種隨機方法,損失函數在每一次迭代中都不一定減少,只有在期望的情況下才能保證收斂性。因此,對損失函數的收斂性進行監測是很困難的。

另一種方法是監視驗證分數的收斂性。在這種情況下,輸入數據被分成訓練集和驗證集。然后在訓練集上對模型進行擬合,停止準則基于在驗證集上計算的預測分數。這使我們能夠找到最少的迭代次數,這足以建立一個模型,該模型可以很好的地泛化到未見數據,并減少了過度擬合訓練數據的機會。

如果 early_stopping=True則早期停止策略被激活。否則,停止準則只對整個輸入數據使用訓練損失。為了更好地控制早期停止策略,我們可以指定一個參數validation_fraction,它設置我們保留的用于計算輸入數據集的驗證分數。優化將持續到驗證分數在最后一次迭代中( n_iter_no_change)不再提高(通過toy)為止。實際迭代次數可在屬性n_iter_中找到。

此示例演示了在sklearn.linear_model.SGDClassifier 模型中如何使用早期停止來實現與構建和不需要早期停止幾乎相同的精度的模型。這可以大大縮短訓練時間。請注意,使與早期迭代相比,停止標準之間的分數也有差異,因為一些訓練數據是使用驗證停止標準保存的。

No stopping criterion: .................................................
Training loss: .................................................
Validation score: .................................................
# Authors: Tom Dupre la Tour
#
# License: BSD 3 clause
import time
import sys

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from sklearn import linear_model
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
from sklearn.utils._testing import ignore_warnings
from sklearn.exceptions import ConvergenceWarning
from sklearn.utils import shuffle

print(__doc__)


def load_mnist(n_samples=None, class_0='0', class_1='8'):
    """Load MNIST, select two classes, shuffle and return only n_samples."""
    # Load data from http://openml.org/d/554
    mnist = fetch_openml('mnist_784', version=1)

    # take only two classes for binary classification
    mask = np.logical_or(mnist.target == class_0, mnist.target == class_1)

    X, y = shuffle(mnist.data[mask], mnist.target[mask], random_state=42)
    if n_samples is not None:
        X, y = X[:n_samples], y[:n_samples]
    return X, y


@ignore_warnings(category=ConvergenceWarning)
def fit_and_score(estimator, max_iter, X_train, X_test, y_train, y_test):
    """Fit the estimator on the train set and score it on both sets"""
    estimator.set_params(max_iter=max_iter)
    estimator.set_params(random_state=0)

    start = time.time()
    estimator.fit(X_train, y_train)

    fit_time = time.time() - start
    n_iter = estimator.n_iter_
    train_score = estimator.score(X_train, y_train)
    test_score = estimator.score(X_test, y_test)

    return fit_time, n_iter, train_score, test_score


# Define the estimators to compare
estimator_dict = {
    'No stopping criterion':
    linear_model.SGDClassifier(n_iter_no_change=3),
    'Training loss':
    linear_model.SGDClassifier(early_stopping=False, n_iter_no_change=3,
                               tol=0.1),
    'Validation score':
    linear_model.SGDClassifier(early_stopping=True, n_iter_no_change=3,
                               tol=0.0001, validation_fraction=0.2)
}

# Load the dataset
X, y = load_mnist(n_samples=10000)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5,
                                                    random_state=0)

results = []
for estimator_name, estimator in estimator_dict.items():
    print(estimator_name + ': ', end='')
    for max_iter in range(150):
        print('.', end='')
        sys.stdout.flush()

        fit_time, n_iter, train_score, test_score = fit_and_score(
            estimator, max_iter, X_train, X_test, y_train, y_test)

        results.append((estimator_name, max_iter, fit_time, n_iter,
                        train_score, test_score))
    print('')

# Transform the results in a pandas dataframe for easy plotting
columns = [
    'Stopping criterion''max_iter''Fit time (sec)''n_iter_',
    'Train score''Test score'
]
results_df = pd.DataFrame(results, columns=columns)

# Define what to plot (x_axis, y_axis)
lines = 'Stopping criterion'
plot_list = [
    ('max_iter''Train score'),
    ('max_iter''Test score'),
    ('max_iter''n_iter_'),
    ('max_iter''Fit time (sec)'),
]

nrows = 2
ncols = int(np.ceil(len(plot_list) / 2.))
fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(6 * ncols,
                                                            4 * nrows))
axes[00].get_shared_y_axes().join(axes[00], axes[01])

for ax, (x_axis, y_axis) in zip(axes.ravel(), plot_list):
    for criterion, group_df in results_df.groupby(lines):
        group_df.plot(x=x_axis, y=y_axis, label=criterion, ax=ax)
    ax.set_title(y_axis)
    ax.legend(title=lines)

fig.tight_layout()
plt.show()

腳本的總運行時間:(0分43.797秒)

Download Python source code: plot_sgd_early_stopping.py

Download Jupyter notebook: plot_sgd_early_stopping.ipynb