from sklearn import datasets
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.naive_bayes import MultinomialNB
from sklearn.cross_validation import train_test_split
from sklearn.metrics import classification_report

# 垃圾邮件过滤

def spamTest_skl():
    
    # 加载email文件夹下的数据
    base_data = datasets.load_files('email/')
    print(len(base_data))
#     print(base_data.data)
    print(len(base_data.target))
    print(base_data.target)
    
    # 交叉验证选择 训练集和测试集
    train_data, test_data, train_y, test_y = train_test_split(base_data.data, base_data.target, 
                                                              test_size=0.2, train_size=0.8)
    
    # 生成文本的词频矩阵
    vectorizer = CountVectorizer(stop_words='english', decode_error='ignore')
    wordX = vectorizer.fit_transform(train_data)
    
    # 训练分类器
    clf = MultinomialNB().fit(wordX, train_y)
    
    # 预测测试集的分类结果
    test_wordX = vectorizer.transform(test_data).toarray()
#     newDoc_tfidf = transformer.transform(newDoc_wordX) # 得到新文档每个词的TF-IDF值
    predicted = clf.predict(test_wordX)
    print(predicted)
    
    # 在测试集上的性能评估
    print(classification_report(test_y, predicted, target_names=base_data.target_names))
    
    
    
    
    
spamTest_skl()

输出:

5
50
[1 0 0 1 0 1 1 1 0 0 1 1 1 1 0 0 0 1 1 1 0 1 1 0 1 0 1 0 0 1 0 0 1 1 0 0 1
 0 0 0 1 0 0 0 1 1 0 0 1 1]
[0 1 1 1 1 0 1 0 0 0]
             precision    recall  f1-score   support

     noSpam       1.00      1.00      1.00         5
       spam       1.00      1.00      1.00         5

avg / total       1.00      1.00      1.00        10

Logo

汇聚全球AI编程工具,助力开发者即刻编程。

更多推荐