帶有可視化API的ROC曲線?

Scikit-learn定義了一個簡單的API,用于創建用于機器學習的可視化。 該API的主要功能是無需重新計算即可進行快速繪圖和視覺調整。 在此示例中,我們將通過比較ROC曲線來演示如何使用可視化API。

導入數據并訓練一個支持向量機

首先,我們加載紅酒數據集并將其轉換為二分類分類問題。 然后,我們在訓練數據集上訓練支持向量分類器。

import matplotlib.pyplot as plt
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import plot_roc_curve
from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split

X, y = load_wine(return_X_y=True)
y = y == 2

X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
svc = SVC(random_state=42)
svc.fit(X_train, y_train)

輸出:

繪制ROC曲線

接下來,我們通過一次調用sklearn.metrics.plot_roc_curve繪制ROC曲線。 返回的svc_disp對象使我們可以在以后的圖中繼續使用已經計算出的SVC ROC曲線。

svc_disp = plot_roc_curve(svc, X_test, y_test)
plt.show()

輸出:

訓練隨機森林并繪制ROC曲線

我們訓練一個隨機森林分類器,并創建一個將其與SVC ROC曲線進行比較的圖。 注意svc_disp如何使用plot來繪制SVC ROC曲線,而無需重新計算roc曲線本身的值。 此外,我們將alpha = 0.8傳遞給繪圖函數以調整曲線的alpha值。

rfc = RandomForestClassifier(n_estimators=10, random_state=42)
rfc.fit(X_train, y_train)
ax = plt.gca()
rfc_disp = plot_roc_curve(rfc, X_test, y_test, ax=ax, alpha=0.8)
svc_disp.plot(ax=ax, alpha=0.8)
plt.show()

輸出:

腳本的總運行時間:0分鐘0.233秒