在鳶尾花數據集中繪制不同的SVM分類器?

在鳶尾花數據集的二維投影上比較不同的線性支持向量機分類器。我們僅考慮此數據集的前兩個特征:

  • 花萼長度

  • 花萼寬度

此案例會說明如何繪制具有不同核函數的四個SVM分類器的決策平面。

線性模型LinearSVC()和SVC(kernel ='linear')得出的決策邊界略有不同。 這可能是由于以下差異造成的:

  • LinearSVC()最小化平方hinge損失,而SVC最小化普通的hinge損失。

  • LinearSVC()使用“一對全(One-vs-All)”(也稱為“一對多 One-vs-Rest”)多類歸約,而SVC使用“一對一(One-vs-One)”多類歸約。

兩種線性模型都有線性決策邊界(對多分類而言,相交的超平面),而非線性內核模型(多項式或高斯RBF)則具有更靈活的非線性決策邊界,其形狀取決于內核的種類及其參數。

注意:

雖然為玩具(toy)二維數據集繪制分類器的決策函數可以幫助直觀了解各個核函數的表達能力,但請注意,這些直覺并不總是會推廣到更現實的高維問題。

輸入:

import numpy as np
import matplotlib.pyplot as plt
from sklearn import svm, datasets


def make_meshgrid(x, y, h=.02):
    """創建要繪制的點網格

    參數
    ----------
    x: 創建網格x軸所需要的數據
    y: 創建網格y軸所需要的數據
    h: 網格大小的可選大小,可選填

    返回
    -------
    xx, yy : n維數組
    """

    x_min, x_max = x.min() - 1, x.max() + 1
    y_min, y_max = y.min() - 1, y.max() + 1
    xx, yy = np.meshgrid(np.arange(x_min, x_max, h),
                         np.arange(y_min, y_max, h))
    return xx, yy


def plot_contours(ax, clf, xx, yy, **params):
    """繪制分類器的決策邊界。

    參數
    ----------
    ax: matplotlib子圖對象
    clf: 一個分類器
    xx: 網狀網格meshgrid的n維數組
    yy: 網狀網格meshgrid的n維數組
    params: 傳遞給contourf的參數字典,可選填
    """

    Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])
    Z = Z.reshape(xx.shape)
    out = ax.contourf(xx, yy, Z, **params)
    return out


# 導入數據以便后續使用
iris = datasets.load_iris()
# 采用前兩個特征。我們可以通過使用二維數據集來避免使用切片。
X = iris.data[:, :2]
y = iris.target

# 我們創建一個SVM實例并擬合數據。由于要繪制支持向量,因此我們不縮放數據
C = 1.0  # SVM正則化參數
models = (svm.SVC(kernel='linear', C=C),
          svm.LinearSVC(C=C, max_iter=10000),
          svm.SVC(kernel='rbf', gamma=0.7, C=C),
          svm.SVC(kernel='poly', degree=3, gamma='auto', C=C))
models = (clf.fit(X, y) for clf in models)

# 為圖像設置標題
titles = ('SVC with linear kernel',
          'LinearSVC (linear kernel)',
          'SVC with RBF kernel',
          'SVC with polynomial (degree 3) kernel')

# 設置一個2x2結構的畫布
fig, sub = plt.subplots(22)
plt.subplots_adjust(wspace=0.4, hspace=0.4)

X0, X1 = X[:, 0], X[:, 1]
xx, yy = make_meshgrid(X0, X1)

for clf, title, ax in zip(models, titles, sub.flatten()):
    plot_contours(ax, clf, xx, yy,
                  cmap=plt.cm.coolwarm, alpha=0.8)
    ax.scatter(X0, X1, c=y, cmap=plt.cm.coolwarm, s=20, edgecolors='k')
    ax.set_xlim(xx.min(), xx.max())
    ax.set_ylim(yy.min(), yy.max())
    ax.set_xlabel('Sepal length')
    ax.set_ylabel('Sepal width')
    ax.set_xticks(())
    ax.set_yticks(())
    ax.set_title(title)

plt.show()

腳本的總運行時間:(0分鐘0.570秒)