在scikit-learn中可視化交叉驗證行為?

選擇正確的交叉驗證對象是正確擬合模型的關鍵部分。有很多方法可以將數據分為訓練集和測試集,從而避免模型過度擬合,例如標準化測試集中的組數等。

本示例將幾個常見的scikit學習對象的行為可視化以進行比較。

from sklearn.model_selection import (TimeSeriesSplit, KFold, ShuffleSplit,
                                     StratifiedKFold, GroupShuffleSplit,
                                     GroupKFold, StratifiedShuffleSplit)
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
np.random.seed(1338)
cmap_data = plt.cm.Paired
cmap_cv = plt.cm.coolwarm
n_splits = 4

可視化我們的數據

首先,我們必須了解數據的結構。它包含100個隨機生成的輸入數據點,數據點之間標簽被不均勻地劃分為三類,同時我們均勻劃分了10個“組”。

正如我們將看到的,一些交叉驗證對象對帶有標簽的數據執行特定的操作,另一些對分組數據的處理方式有所不同,而另一些則不使用此信息。

首先,我們將可視化數據。

# 生成類別/組數據
n_points = 100
X = np.random.randn(10010)

percentiles_classes = [.1.3.6]
y = np.hstack([[ii] * int(100 * perc)
               for ii, perc in enumerate(percentiles_classes)])

# 間隔均勻的組重復一次
groups = np.hstack([[ii] * 10 for ii in range(10)])


def visualize_groups(classes, groups, name):
    # 可視化數據集組
    fig, ax = plt.subplots()
    ax.scatter(range(len(groups)),  [.5] * len(groups), c=groups, marker='_',
               lw=50, cmap=cmap_data)
    ax.scatter(range(len(groups)),  [3.5] * len(groups), c=classes, marker='_',
               lw=50, cmap=cmap_data)
    ax.set(ylim=[-15], yticks=[.53.5],
           yticklabels=['Data\ngroup''Data\nclass'], xlabel="Sample index")


visualize_groups(y, groups, 'no groups')

定義一個函數以可視化交叉驗證行為

我們將定義一個函數,使我們可以可視化每個交叉驗證對象的行為。 我們將對數據進行4次拆分。在每個分組中,我們將為訓練集(藍色)和測試集(紅色)可視化選擇的索引。

def plot_cv_indices(cv, X, y, group, ax, n_splits, lw=10):
    """為交叉驗證對象的索引創建樣本圖."""

    # 為每個交叉驗證分組生成訓練/測試可視化圖像
    for ii, (tr, tt) in enumerate(cv.split(X=X, y=y, groups=group)):
        # 與訓練/測試組一起填寫索引
        indices = np.array([np.nan] * len(X))
        indices[tt] = 1
        indices[tr] = 0

        # 可視化結果
        ax.scatter(range(len(indices)), [ii + .5] * len(indices),
                   c=indices, marker='_', lw=lw, cmap=cmap_cv,
                   vmin=-.2, vmax=1.2)

    # 將數據的分組情況和標簽情況放入圖像
    ax.scatter(range(len(X)), [ii + 1.5] * len(X),
               c=y, marker='_', lw=lw, cmap=cmap_data)

    ax.scatter(range(len(X)), [ii + 2.5] * len(X),
               c=group, marker='_', lw=lw, cmap=cmap_data)

    # 調整格式
    yticklabels = list(range(n_splits)) + ['class''group']
    ax.set(yticks=np.arange(n_splits+2) + .5, yticklabels=yticklabels,
           xlabel='Sample index', ylabel="CV iteration",
           ylim=[n_splits+2.2-.2], xlim=[0100])
    ax.set_title('{}'.format(type(cv).__name__), fontsize=15)
    return ax

現在看看K折交叉驗證對象可視化后效果如何:

fig, ax = plt.subplots()
cv = KFold(n_splits)
plot_cv_indices(cv, X, y, groups, ax, n_splits)

輸出:

<matplotlib.axes._subplots.AxesSubplot object at 0x7f96064f9190>

如您所見,默認情況下,K折交叉驗證迭代器不考慮數據點類或組。我們可以像這樣使用StratifiedKFold來改變它。

fig, ax = plt.subplots()
cv = StratifiedKFold(n_splits)
plot_cv_indices(cv, X, y, groups, ax, n_splits)
<matplotlib.axes._subplots.AxesSubplot object at 0x7f96042325b0>

在這種情況下,交叉驗證在每個CV劃分中保留相同的類比例。 接下來,我們將可視化許多CV迭代器的行為。

可視化許多CV對象的交叉驗證索引

讓我們直觀地比較許多scikit-learn交叉驗證對象的交叉驗證行為。下面,我們將循環瀏覽幾個常見的交叉驗證對象,以可視化每個對象的行為。

注意有些交叉驗證如何使用組/類信息,而有些交叉驗證則不使用。

cvs = [KFold, GroupKFold, ShuffleSplit, StratifiedKFold,
       GroupShuffleSplit, StratifiedShuffleSplit,
       TimeSeriesSplit]

for cv in cvs:
    this_cv = cv(n_splits=n_splits)
    fig, ax = plt.subplots(figsize=(63))
    plot_cv_indices(this_cv, X, y, groups, ax, n_splits)

    ax.legend([Patch(color=cmap_cv(.8)), Patch(color=cmap_cv(.02))],
              ['Testing set''Training set'], loc=(1.02.8))
    # Make the legend fit
    plt.tight_layout()
    fig.subplots_adjust(right=.7)
plt.show()

腳本的總運行時間:(0分鐘0.937秒)