具有異構數據源的列變形器?

數據集通常可以包含需要不同特征提取和處理管道的組件。 在以下情況下可能會發生這種情況:

  1. 您的數據集包含異構數據類型(例如,光柵圖像和文字標題),

  2. 您的數據集存儲在pandas.DataFrame中,不同的列需要不同的處理管道。

本示例演示如何在包含不同類型要素的數據集上使用ColumnTransformer。 功能的選擇并不是特別有幫助,但是可以用來說明該技術。

# Author: Matt Terry <matt.terry@gmail.com>
#
# License: BSD 3 clause

import numpy as np

from sklearn.preprocessing import FunctionTransformer
from sklearn.datasets import fetch_20newsgroups
from sklearn.decomposition import TruncatedSVD
from sklearn.feature_extraction import DictVectorizer
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics import classification_report
from sklearn.pipeline import Pipeline
from sklearn.compose import ColumnTransformer
from sklearn.svm import LinearSVC

20個新聞組數據集

我們將使用20個新聞組數據集,其中包含來自20個主題的新聞組中的帖子。 該數據集根據特定日期之前和之后發布的消息分為訓練和測試子集。 我們將只使用2個類別的帖子來加快運行時間。

categories = ['sci.med''sci.space']
X_train, y_train = fetch_20newsgroups(random_state=1,
                                      subset='train',
                                      categories=categories,
                                      remove=('footers''quotes'),
                                      return_X_y=True)
X_test, y_test = fetch_20newsgroups(random_state=1,
                                    subset='test',
                                    categories=categories,
                                    remove=('footers''quotes'),
                                    return_X_y=True)

print(X_train[0])

每個功能都包含有關該帖子的元信息,例如主題和新聞帖子的主體。

輸出:

Subject: Re: Metric vs English
Article-I.D.: mksol.1993Apr6.131900.8407
Organization: Texas Instruments Inc
Lines: 31




American, perhaps, but nothing military about it.  I learned (mostly)
slugs when we talked English units in high school physics and while
the teacher was an ex-Navy fighter jock the book certainly wasn't
produced by the military.

[Poundals were just too flinking small and made the math come out
funny; sort of the same reason proponents of SI give for using that.]

--
"Insisting on perfect safety is for people who don'
t have the balls to live in the real world."   -- Mary Shafer, NASA Ames Dryden

(譯者注:這里輸出的是英文稿件中的內容,若翻譯成中文用戶將感到非常困惑,故保留英文內容,不予翻譯。)

創建轉換器

首先,我們需要一個轉換器來提取每個帖子的主題和正文。 由于這是無狀態轉換(不需要訓練數據中的狀態信息),因此我們可以定義一個執行數據轉換的函數,然后使用FunctionTransformer創建scikit-learn轉換器。

def subject_body_extractor(posts):
 #用兩列構造對象dtype數組
    #第一列=“主題”,第二列=“主體”
    features = np.empty(shape=(len(posts), 2), dtype=object)
    for i, text in enumerate(posts):
        # 臨時變量“ _”存儲“ \ n \ n”
        headers, _, body = text.partition('\n\n')
        # 將正文存儲在第二欄中
        features[i, 1] = body

        prefix = 'Subject:'
        sub = ''
        # 在第一欄中的“主題:”之后保存文本
        for line in headers.split('\n'):
            if line.startswith(prefix):
                sub = line[len(prefix):]
                break
        features[i, 0] = sub

    return features


subject_body_transformer = FunctionTransformer(subject_body_extractor)

我們還將創建一個轉換器,以提取文本的長度和句子的數量。

def text_stats(posts):
    return [{'length': len(text),
             'num_sentences': text.count('.')}
            for text in posts]


text_stats_transformer = FunctionTransformer(text_stats)

分類管道

下面的管道使用SubjectBodyExtractor從每個帖子中提取主題和正文,生成(n_samples,2)數組。 然后,使用ColumnTransformer,將此數組用于計算主題和正文的標準詞袋特征以及正文的文本長度和句子數。 我們將它們與權重結合在一起,然后根據結合的特征集訓練分類器。

pipeline = Pipeline([
    # 提取標題和文字內容主體
    ('subjectbody', subject_body_transformer),
    # 使用ColumnTransformer組合標題和主體特征
    ('union', ColumnTransformer(
        [
            # 標題詞袋(col 0)
            ('subject', TfidfVectorizer(min_df=50), 0),
            # 文章主體分解的詞袋(col 1)
            ('body_bow', Pipeline([
                ('tfidf', TfidfVectorizer()),
                ('best', TruncatedSVD(n_components=50)),
            ]), 1),
            # 從帖子的正文中提取文本統計信息的管道
            ('body_stats', Pipeline([
                ('stats', text_stats_transformer),  # 返回字典列表
                ('vect', DictVectorizer()),  # 字典列表->特征矩陣
            ]), 1),
        ],
        # ColumnTransformer功能上的權重
        transformer_weights={
            'subject'0.8,
            'body_bow'0.5,
            'body_stats'1.0,
        }
    )),
    # 在組合功能上使用SVC分類器
    ('svc', LinearSVC(dual=False)),
], verbose=True)

最后,我們將培訓數據擬合到管道中,并使用它來預測X_test的主題。 然后打印我們管道的性能指標。

pipeline.fit(X_train, y_train)
y_pred = pipeline.predict(X_test)
print('Classification report:\n\n{}'.format(
    classification_report(y_test, y_pred))
)

輸出:

[Pipeline] ....... (step 1 of 3) Processing subjectbody, total=   0.0s
[Pipeline] ............. (step 2 of 3) Processing union, total=   0.6s
[Pipeline] ............... (step 3 of 3) Processing svc, total=   0.0s
Classification report:

              precision    recall  f1-score   support

           0       0.84      0.88      0.86       396
           1       0.87      0.83      0.85       394

    accuracy                           0.85       790
   macro avg       0.85      0.85      0.85       790
weighted avg       0.85      0.85      0.85       790

腳本的總運行時間:(0分鐘3.380秒)