sklearn: GridSearchCV (grid search), a powerful model-tuning tool in sklearn
2022-06-23 04:40:00 【郭庆汝】

Code:
import pandas as pd # data analysis toolkit
import numpy as np # numerical computing
import matplotlib.pyplot as plt # plotting
import seaborn as sns # high-level API on top of matplotlib
from sklearn.model_selection import StratifiedKFold # stratified cross-validation
from sklearn.model_selection import GridSearchCV # grid search
from sklearn.model_selection import train_test_split # split the data into training and test sets
from xgboost import XGBClassifier # xgboost classifier
pima = pd.read_csv("pima_indians-diabetes.csv")
print(pima.head())
x = pima.iloc[:,0:8]
y = pima.iloc[:,8]
seed = 7 # fixed seed so the random split is reproducible
test_size = 0.33 # 33% test, 67% training
X_train, X_test, Y_train, Y_test = train_test_split(x, y, test_size=test_size, random_state=seed)
model = XGBClassifier()
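learning_rate = [0.0001, 0.001, 0.01, 0.1, 0.2, 0.3] # candidate learning rates
gamma = [1, 0.1, 0.01, 0.001] # candidate gamma values
param_grid = dict(learning_rate=learning_rate, gamma=gamma) # grid search expects the candidates as a dict
kfold = StratifiedKFold(n_splits=10, shuffle=True, random_state=7) # split the training data into 10 mutually exclusive stratified folds
grid_search = GridSearchCV(model, param_grid, scoring='neg_log_loss', n_jobs=-1, cv=kfold)
# scoring selects the evaluation metric (negative log loss here, so values closer to 0 are better),
# n_jobs=-1 uses all CPU cores, cv sets the cross-validation strategy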
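grid_result = grid_search.fit(X_train, Y_train) # run the grid search
print("Best: %f using %s" % (grid_result.best_score_, grid_result.best_params_))
# best_params_: the parameter combination that achieved the best result
# best_score_: the best mean cross-validated score observed during the search
# cv_results_: the evaluation results for every parameter combination (it replaces the
# older grid_scores_ attribute, which was removed in scikit-learn 0.20). It is a dict with
# keys as column headers and values as columns, so it can be loaded into a DataFrame.
# Note that the "params" key stores the list of parameter settings for all candidates.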
means = grid_result.cv_results_['mean_test_score']
params = grid_result.cv_results_['params']
for mean, param in zip(means, params):
    print("%f with: %r" % (mean, param))
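
To get a feel for how much work the search above does, the sketch below enumerates the same grid using only the standard library. It is an illustration, not how GridSearchCV is implemented: `expand_grid` is a hypothetical helper introduced here (scikit-learn provides `ParameterGrid` for the same purpose), and the count assumes the 10 folds and the two parameter lists shown above.

```python
from itertools import product

# Same candidate values as the grid above.
param_grid = {
    "learning_rate": [0.0001, 0.001, 0.01, 0.1, 0.2, 0.3],
    "gamma": [1, 0.1, 0.01, 0.001],
}
n_splits = 10  # number of folds used by StratifiedKFold above


def expand_grid(grid):
    """Hypothetical helper: return every parameter combination as a list of dicts."""
    names = sorted(grid)  # fixed key order so the result is deterministic
    value_lists = [grid[name] for name in names]
    return [dict(zip(names, values)) for values in product(*value_lists)]


candidates = expand_grid(param_grid)
total_fits = len(candidates) * n_splits

print(len(candidates))  # 24 candidates: 6 learning rates x 4 gamma values
print(total_fits)       # 240 cross-validation fits
```

With the default `refit=True`, GridSearchCV trains one more model on the whole training set using the best combination, so the cost grows quickly as more parameters or values are added to the grid.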
