在scikit-learn中可視化交叉驗證行為?
選擇正確的交叉驗證對象是正確擬合模型的關鍵部分。有很多方法可以將數據分為訓練集和測試集,從而避免模型過度擬合,例如標準化測試集中的組數等。
本示例將幾個常見的scikit學習對象的行為可視化以進行比較。
from sklearn.model_selection import (TimeSeriesSplit, KFold, ShuffleSplit,
StratifiedKFold, GroupShuffleSplit,
GroupKFold, StratifiedShuffleSplit)
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
np.random.seed(1338)
cmap_data = plt.cm.Paired
cmap_cv = plt.cm.coolwarm
n_splits = 4
可視化我們的數據
首先,我們必須了解數據的結構。它包含100個隨機生成的輸入數據點,數據點之間標簽被不均勻地劃分為三類,同時我們均勻劃分了10個“組”。
正如我們將看到的,一些交叉驗證對象對帶有標簽的數據執行特定的操作,另一些對分組數據的處理方式有所不同,而另一些則不使用此信息。
首先,我們將可視化數據。
# 生成類別/組數據
n_points = 100
X = np.random.randn(100, 10)
percentiles_classes = [.1, .3, .6]
y = np.hstack([[ii] * int(100 * perc)
for ii, perc in enumerate(percentiles_classes)])
# 間隔均勻的組重復一次
groups = np.hstack([[ii] * 10 for ii in range(10)])
def visualize_groups(classes, groups, name):
# 可視化數據集組
fig, ax = plt.subplots()
ax.scatter(range(len(groups)), [.5] * len(groups), c=groups, marker='_',
lw=50, cmap=cmap_data)
ax.scatter(range(len(groups)), [3.5] * len(groups), c=classes, marker='_',
lw=50, cmap=cmap_data)
ax.set(ylim=[-1, 5], yticks=[.5, 3.5],
yticklabels=['Data\ngroup', 'Data\nclass'], xlabel="Sample index")
visualize_groups(y, groups, 'no groups')

定義一個函數以可視化交叉驗證行為
我們將定義一個函數,使我們可以可視化每個交叉驗證對象的行為。 我們將對數據進行4次拆分。在每個分組中,我們將為訓練集(藍色)和測試集(紅色)可視化選擇的索引。
def plot_cv_indices(cv, X, y, group, ax, n_splits, lw=10):
"""為交叉驗證對象的索引創建樣本圖."""
# 為每個交叉驗證分組生成訓練/測試可視化圖像
for ii, (tr, tt) in enumerate(cv.split(X=X, y=y, groups=group)):
# 與訓練/測試組一起填寫索引
indices = np.array([np.nan] * len(X))
indices[tt] = 1
indices[tr] = 0
# 可視化結果
ax.scatter(range(len(indices)), [ii + .5] * len(indices),
c=indices, marker='_', lw=lw, cmap=cmap_cv,
vmin=-.2, vmax=1.2)
# 將數據的分組情況和標簽情況放入圖像
ax.scatter(range(len(X)), [ii + 1.5] * len(X),
c=y, marker='_', lw=lw, cmap=cmap_data)
ax.scatter(range(len(X)), [ii + 2.5] * len(X),
c=group, marker='_', lw=lw, cmap=cmap_data)
# 調整格式
yticklabels = list(range(n_splits)) + ['class', 'group']
ax.set(yticks=np.arange(n_splits+2) + .5, yticklabels=yticklabels,
xlabel='Sample index', ylabel="CV iteration",
ylim=[n_splits+2.2, -.2], xlim=[0, 100])
ax.set_title('{}'.format(type(cv).__name__), fontsize=15)
return ax
現在看看K折交叉驗證對象可視化后效果如何:
fig, ax = plt.subplots()
cv = KFold(n_splits)
plot_cv_indices(cv, X, y, groups, ax, n_splits)
輸出:

<matplotlib.axes._subplots.AxesSubplot object at 0x7f96064f9190>
如您所見,默認情況下,K折交叉驗證迭代器不考慮數據點類或組。我們可以像這樣使用StratifiedKFold來改變它。
fig, ax = plt.subplots()
cv = StratifiedKFold(n_splits)
plot_cv_indices(cv, X, y, groups, ax, n_splits)

<matplotlib.axes._subplots.AxesSubplot object at 0x7f96042325b0>
在這種情況下,交叉驗證在每個CV劃分中保留相同的類比例。 接下來,我們將可視化許多CV迭代器的行為。
可視化許多CV對象的交叉驗證索引
讓我們直觀地比較許多scikit-learn交叉驗證對象的交叉驗證行為。下面,我們將循環瀏覽幾個常見的交叉驗證對象,以可視化每個對象的行為。
注意有些交叉驗證如何使用組/類信息,而有些交叉驗證則不使用。
cvs = [KFold, GroupKFold, ShuffleSplit, StratifiedKFold,
GroupShuffleSplit, StratifiedShuffleSplit,
TimeSeriesSplit]
for cv in cvs:
this_cv = cv(n_splits=n_splits)
fig, ax = plt.subplots(figsize=(6, 3))
plot_cv_indices(this_cv, X, y, groups, ax, n_splits)
ax.legend([Patch(color=cmap_cv(.8)), Patch(color=cmap_cv(.02))],
['Testing set', 'Training set'], loc=(1.02, .8))
# Make the legend fit
plt.tight_layout()
fig.subplots_adjust(right=.7)
plt.show()








腳本的總運行時間:(0分鐘0.937秒)