[sklearn] RF: Cross-Validation, Out-of-Bag Data, Parameter Learning Curves, Grid Search
2022-07-24 05:15:00 【rejudge】
Random forest: cross-validation, out-of-bag data, parameter learning curves, and grid search
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split
%matplotlib inline
sklearn.ensemble.RandomForestClassifier
wine = load_wine()
Xtrain, Xtest, Ytrain, Ytest = train_test_split(wine.data, wine.target, test_size=0.3)
rfc = RandomForestClassifier(random_state=0)
rfc = rfc.fit(Xtrain, Ytrain)
score_r = rfc.score(Xtest, Ytest)
print(score_r)
# 0.9444444444444444
Train/test split and cross-validation: cross_val_score and cross_validate
sklearn.model_selection.cross_val_score
from sklearn.model_selection import cross_val_score
import matplotlib.pyplot as plt
rfc = RandomForestClassifier(n_estimators=100)
rfc_s = cross_val_score(rfc, wine.data, wine.target, cv=10)
print(rfc_s)
''' array([1. , 0.94444444, 0.94444444, 0.94444444, 1. , 1. , 1. , 1. , 1. , 1. ]) '''
plt.plot(range(1, 11), rfc_s, label="RandomForest")
#plt.plot(range(1, 11), clf_s, label="DecisionTree")
plt.legend() # show the legend labels
plt.show()
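The commented-out DecisionTree line above refers to clf_s, which is never computed in this post. A minimal sketch of how that comparison could be produced, assuming clf_s holds the 10-fold cross-validation scores of a single decision tree (this part is an addition, not the author's code):
from sklearn.tree import DecisionTreeClassifier
# Hypothetical comparison: score one decision tree the same way as the forest
clf = DecisionTreeClassifier()
clf_s = cross_val_score(clf, wine.data, wine.target, cv=10)
plt.plot(range(1, 11), rfc_s, label="RandomForest")
plt.plot(range(1, 11), clf_s, label="DecisionTree")
plt.legend()
plt.show()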

sklearn.model_selection.cross_validate
cross_validate returns cross-validation scores on both the training and test folds, which makes it easy to check whether the model is over- or under-fitting.
from sklearn.model_selection import cross_validate, KFold
# Instantiate the cross-validation splitter; shuffle controls whether the data are shuffled first
cv = KFold(n_splits=5, shuffle=True, random_state=1412)
result = cross_validate(RandomForestClassifier()            # estimator
                        , wine.data, wine.target            # data
                        , cv=cv                             # cross-validation scheme
                        , scoring='neg_mean_squared_error'  # metric: MSE (returned as a negative value)
                        # the parameters above also exist in cross_val_score()
                        # the parameters below are specific to cross_validate()
                        , return_train_score=True           # also return training-fold scores
                        , verbose=True                      # print progress
                        , n_jobs=1                          # number of parallel workers; -1 uses all of them
                        )
'''
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done 5 out of 5 | elapsed: 0.6s finished
'''
result
# A small gap between the training-fold and test-fold cross-validation scores indicates little overfitting
'''
{'fit_time': array([0.13620615, 0.14765525, 0.13985324, 0.14356041, 0.13722539]),
 'score_time': array([0.01013279, 0.01257372, 0.01011801, 0.0101881 , 0.01026726]),
 'test_score': array([-0.02777778, -0.02777778, -0.05555556, -0. , -0.02857143]),
 'train_score': array([-0., -0., -0., -0., -0.])}
'''
# The cross-validated MSE is usually on an awkwardly large scale (it is in squared units),
# so report RMSE instead (the square root of MSE)
abs(result['train_score'])**0.5
abs(result['test_score'])**0.5
# Plot the fold-by-fold scores to compare
plt.figure(figsize=(8,6), dpi=80)
plt.plot(range(1, 6), abs(result['test_score'])**0.5, color='green',label='RandomForestTest')
plt.plot(range(1, 6), abs(result['train_score'])**0.5, color='red',label='RandomForestTrain')
plt.xticks([1,2,3,4,5])
plt.xlabel('CVcounts', fontsize=16)
plt.ylabel('RMSE', fontsize=16)
plt.legend()
plt.show()
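To put a number on the train/test gap mentioned above, a minimal sketch using the result dict returned by cross_validate (this summary is an addition, not from the original post):
import numpy as np
# Mean RMSE on the training and test folds, and the gap between them
train_rmse = np.mean(abs(result['train_score']) ** 0.5)
test_rmse = np.mean(abs(result['test_score']) ** 0.5)
print(train_rmse, test_rmse, test_rmse - train_rmse)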

No train/test split needed: use the bootstrap (out-of-bag) samples instead
# oob_score = True
rfc = RandomForestClassifier(n_estimators=25, oob_score=True)
rfc = rfc.fit(wine.data, wine.target)
# Accuracy on the out-of-bag samples, which serve as a built-in test set
rfc.oob_score_
''' 0.9606741573033708 '''
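The OOB score is obtained by predicting each sample only with the trees whose bootstrap draw did not include it. A minimal sketch that recomputes it from oob_decision_function_ (an extra sanity check, not part of the original post):
import numpy as np
# Each row of oob_decision_function_ holds class probabilities averaged over
# the trees whose bootstrap sample excluded that observation
oob_pred = np.argmax(rfc.oob_decision_function_, axis=1)
print((oob_pred == wine.target).mean())  # should match rfc.oob_score_
# Note: with very few trees some samples may never be out-of-bag;
# sklearn warns in that case and the OOB estimate becomes unreliable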
# Learning curve over n_estimators
scorel = []
for i in range(0, 200, 10):
    rfc = RandomForestClassifier(n_estimators=i+1
                                 , n_jobs=-1
                                 , random_state=0)
    score = cross_val_score(rfc, wine.data, wine.target, cv=10).mean()
    scorel.append(score)
print(max(scorel), (scorel.index(max(scorel))*10)+1)
plt.figure(figsize=[20,5])
plt.plot(range(1, 201, 10), scorel, label='n_estimators')
plt.legend()
plt.show()
# 0.9833333333333332 31
Parameter learning curve, refined around the best n_estimators
scorel = []
for i in range(20, 40):
    rfc = RandomForestClassifier(n_estimators=i
                                 , n_jobs=-1
                                 , random_state=0)
    score = cross_val_score(rfc, wine.data, wine.target, cv=10).mean()
    scorel.append(score)
# the loop starts at 20, so map the index of the best score back to its n_estimators value
print(max(scorel), scorel.index(max(scorel)) + 20)
plt.figure(figsize=[20,5])
plt.plot(range(20, 40), scorel, label='n_estimators')
plt.legend()
plt.show()
# 0.9833333333333334 21

Grid search with GridSearchCV
from sklearn.model_selection import GridSearchCV
import numpy as np
# Grid search: tune max_depth (together with criterion), keeping n_estimators fixed at 31
param_grid = {'max_depth': np.arange(1, 20, 1), 'criterion': ['gini', 'entropy']}
rfc = RandomForestClassifier(n_estimators=31
                             , random_state=0)
GS = GridSearchCV(rfc, param_grid, cv=10)
GS.fit(wine.data, wine.target)
# If the generalization error rises after tuning this parameter, keep the default instead
GS.best_params_, GS.best_score_
# ({'criterion': 'gini', 'max_depth': 4}, 0.9833333333333332)
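As a follow-up (not in the original post), the fitted GridSearchCV object also exposes the full grid of results and a best estimator refitted on all the data; a minimal sketch:
import pandas as pd
# Mean CV accuracy for every (criterion, max_depth) combination
cv_results = pd.DataFrame(GS.cv_results_)
print(cv_results[['param_criterion', 'param_max_depth', 'mean_test_score']].head())
# With the default refit=True, best_estimator_ is already retrained on the full dataset
best_rfc = GS.best_estimator_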