当前位置:网站首页>通过tidymodels使用XGBOOST
通过tidymodels使用XGBOOST
2022-06-27 19:56:00 【王诗翔呀】
原文:https://www.r-bloggers.com/2020/05/using-xgboost-with-tidymodels/
XGBoost是一个最初用C++编写的机器学习库,通过XGBoost R包中移植到R。在过去的几年里,XGBoost在Kaggle竞赛中的有效性让它大受欢迎。在Tychobra, XGBoost是我们的首选机器学习库。
在2016年和2017年,Kaggle被两种方法所主导:梯度升压机和深度学习。具体来说,梯度增强用于结构化数据可用的问题,而深度学习用于图像分类等感知问题。前者的实践者几乎总是使用优秀的XGBoost库。
Max Kuhn和Rstudio的其他人最近将他们的注意力从caret转向了 tidymodels (caret的继承者)。“tidymodels”是一个R包的集合,它们一起工作来简化和加强模型训练和优化。随着最近发布的tidymodels.org[1],我们觉得是时候给tidymodels R包一个机会了。
概览
这篇文章中我们使用tidymodels包训练和优化XGBoost模型。我们使用的AmesHousing[2]数据集,其中包含来自艾奥瓦州艾姆斯的住房数据。我们的模型将预测房屋销售价格。
加载包:
# data
library(AmesHousing)
# data cleaning
library(janitor)
# data prep
library(dplyr)
# tidymodels
library(rsample)
library(recipes)
library(parsnip)
library(tune)
library(dials)
library(workflows)
library(yardstick)
# speed up computation with parrallel processing (optional)
library(doParallel)
all_cores <- parallel::detectCores(logical = FALSE)
registerDoParallel(cores = all_cores)
加载数据:
# set the random seed so we can reproduce any simulated results.
set.seed(1234)
# load the housing data and clean names
ames_data <- make_ames() %>%
janitor::clean_names()
Step 0:探索性数据分析
在这一点上,我们通常会对数据做一些简单的图表和总结,以获得对数据的高层次理解。为了简单起见,我们将从这篇文章中删除EDA过程,但是,在实际分析中,理解业务问题和执行有效的EDA通常是分析中最耗时和最关键的方面。
Step 1:初始数据划分
现在我们将数据分解为训练和测试数据。训练数据用于模型训练和超参数调优。训练后,可以根据测试数据对模型进行评估,以评估其准确性。
# split into training and testing datasets. Stratify by Sale price
ames_split <- rsample::initial_split(
ames_data,
prop = 0.2,
strata = sale_price
)
Step 2:预处理
预处理改变数据,使我们的模型更具预测性,训练过程的计算量更少。许多模型需要仔细和广泛的变量预处理来产生准确的预测。然而,XGBoost对于高度倾斜和/或相关的数据是稳健的,因此XGBoost所需的预处理量是最小的。尽管如此,我们仍然可以从一些预处理中获益。
在tidymodels中,我们使用recipes包来定义这些预处理步骤,也就是所谓的“recipe”。
# preprocessing "recipe"
preprocessing_recipe <-
recipes::recipe(sale_price ~ ., data = training(ames_split)) %>%
# convert categorical variables to factors
recipes::step_string2factor(all_nominal()) %>%
# combine low frequency factor levels
recipes::step_other(all_nominal(), threshold = 0.01) %>%
# remove no variance predictors which provide no predictive information
recipes::step_nzv(all_nominal()) %>%
prep()
Step 3:划分交叉验证
我们使用bake()
应用前面定义的预处理配方。然后我们使用交叉验证将训练数据随机分割成进一步的训练和测试集。在后面的步骤中,我们将使用这些额外的交叉验证折叠来调优超参数。
ames_cv_folds <-
recipes::bake(
preprocessing_recipe,
new_data = training(ames_split)
) %>%
rsample::vfold_cv(v = 5)
Step 4:XGBoost 模型制定
我们使用parsnip包来定义XGBoost模范。下面我们使用boost_tree()
和tune()
来定义超参数,以便在后续步骤中进行调优。
# XGBoost model specification
xgboost_model <-
parsnip::boost_tree(
mode = "regression",
trees = 1000,
min_n = tune(),
tree_depth = tune(),
learn_rate = tune(),
loss_reduction = tune()
) %>%
set_engine("xgboost", objective = "reg:squarederror")
Step 5:网格搜索
我们使用tidymodels的dials包指定参数集合。
# grid specification
xgboost_params <-
dials::parameters(
min_n(),
tree_depth(),
learn_rate(),
loss_reduction()
)
接下来我们设置网格空间。grid_*
函数支持几种定义网格空间的方法。我们使用的是dails::grid_max_entropy()
函数,它覆盖了超参数空间,这样空间的任何部分都有一个观察到的组合,相隔不是很远。
xgboost_grid <-
dials::grid_max_entropy(
xgboost_params,
size = 60
)
knitr::kable(head(xgboost_grid))
min_n | tree_depth | learn_rate | loss_reduction |
---|---|---|---|
22 | 4 | 0.0000002 | 0.0004020 |
28 | 1 | 0.0753898 | 0.4706658 |
31 | 5 | 0.0000000 | 0.0000001 |
21 | 3 | 0.0012828 | 0.0000837 |
28 | 9 | 0.0000043 | 0.0000000 |
12 | 8 | 0.0004554 | 0.0000024 |
为了优化我们的模型,我们在xgboost_grid的网格空间上执行网格搜索,以确定具有最低预测误差的超参数值。
Step 6:定义workflow
xgboost_wf <-
workflows::workflow() %>%
add_model(xgboost_model) %>%
add_formula(sale_price ~ .)
Step 7:优化模型
调优是包的tidymodels生态系统真正结合在一起的地方。下面是传递给我们调用tune_grid()
的前4个参数的对象的快速说明:
- “object”: xgboost_wf,它是我们在parsnip和workflows包中定义的工作流。
- “resamples”: ames_cv_folds 通过 rsample 和 recipes 包定义。
- “grid”: xgboost_grid 通过dials包定义的网格空间。
- “metric”: yardstick包定义的指标集合用于评估模型性能。
# hyperparameter tuning
xgboost_tuned <- tune::tune_grid(
object = xgboost_wf,
resamples = ames_cv_folds,
grid = xgboost_grid,
metrics = yardstick::metric_set(rmse, rsq, mae),
control = tune::control_grid(verbose = TRUE)
)
在上面的代码块中,tune_grid()
对我们在xgboost_grid中定义的所有60个网格参数组合执行网格搜索,并使用5倍交叉验证以及rmse(均方根误差)、rsq (R Squared)和mae(平均绝对误差)来测量预测精度。因此,我们的tidymodels优化执行构建60 X 5 = 300 XGBoost模型,每个模型都有1000棵树,都是为了寻找最佳的超参数。
这些是在最小化RMSE时表现最好的超参数值:
xgboost_tuned %>%
tune::show_best(metric = "rmse") %>%
knitr::kable()
min_n | tree_depth | learn_rate | loss_reduction | .metric | .estimator | mean | n | std_err | .config |
---|---|---|---|---|---|---|---|---|---|
3 | 2 | 0.0277572 | 0.0016397 | rmse | standard | 33249.41 | 5 | 5605.092 | Preprocessor1_Model15 |
24 | 13 | 0.0228652 | 0.0000177 | rmse | standard | 34754.87 | 5 | 4207.207 | Preprocessor1_Model27 |
28 | 1 | 0.0753898 | 0.4706658 | rmse | standard | 35182.62 | 5 | 3975.320 | Preprocessor1_Model02 |
8 | 2 | 0.0059440 | 1.7544081 | rmse | standard | 35269.42 | 5 | 4710.518 | Preprocessor1_Model57 |
40 | 6 | 0.0121354 | 0.0029648 | rmse | standard | 35362.95 | 5 | 3652.746 | Preprocessor1_Model44 |
下一步分离最佳的超参数值:
xgboost_best_params <- xgboost_tuned %>%
tune::select_best("rmse")
knitr::kable(xgboost_best_params)
min_n | tree_depth | learn_rate | loss_reduction | .config |
---|---|---|---|---|
3 | 2 | 0.0277572 | 0.0016397 | Preprocessor1_Model15 |
最终确定XGBoost模型以使用最佳的调优参数。
xgboost_model_final <- xgboost_model %>%
finalize_model(xgboost_best_params)
Step 8:在测试数据上评估性能
现在我们已经训练了我们的模型,我们需要评估模型的性能。我们使用第1步中的测试数据(模型训练中没有使用的数据)来评估性能。
我们使用rmse(均方根误差),rsq (R平方),和mae(平均绝对值)度量从尺度包在我们的模型评估。
首先,让我们评估训练数据的指标:
train_processed <- bake(preprocessing_recipe, new_data = training(ames_split))
train_prediction <- xgboost_model_final %>%
# fit the model on all the training data
fit(
formula = sale_price ~ .,
data = train_processed
) %>%
# predict the sale prices for the training data
predict(new_data = train_processed) %>%
bind_cols(training(ames_split))
xgboost_score_train <-
train_prediction %>%
yardstick::metrics(sale_price, .pred) %>%
mutate(.estimate = format(round(.estimate, 2), big.mark = ","))
knitr::kable(xgboost_score_train)
.metric | .estimator | .estimate |
---|---|---|
rmse | standard | 11,936.67 |
rsq | standard | 0.98 |
mae | standard | 8,762.44 |
接下来是测试数据:
test_processed <- bake(preprocessing_recipe, new_data = testing(ames_split))
test_prediction <- xgboost_model_final %>%
# fit the model on all the training data
fit(
formula = sale_price ~ .,
data = train_processed
) %>%
# use the training model fit to predict the test data
predict(new_data = test_processed) %>%
bind_cols(testing(ames_split))
# measure the accuracy of our model using `yardstick`
xgboost_score <-
test_prediction %>%
yardstick::metrics(sale_price, .pred) %>%
mutate(.estimate = format(round(.estimate, 2), big.mark = ","))
knitr::kable(xgboost_score)
.metric | .estimator | .estimate |
---|---|---|
rmse | standard | 28,459.15 |
rsq | standard | 0.88 |
mae | standard | 17,525.88 |
测试数据上的上述度量明显比我们的训练数据度量差,所以我们知道在我们的模型中存在一些过拟合。这突出了使用测试数据而不是训练数据来评估模型性能的重要性。
为了快速检查我们的模型预测是否存在明显的问题,让我们绘制测试数据的残差。
library(ggplot2)
house_prediction_residual <- test_prediction %>%
arrange(.pred) %>%
mutate(residual_pct = (sale_price - .pred) / .pred) %>%
select(.pred, residual_pct)
ggplot(house_prediction_residual, aes(x = .pred, y = residual_pct)) +
geom_point() +
xlab("Predicted Sale Price") +
ylab("Residual (%)") +
scale_x_continuous(labels = scales::dollar_format()) +
scale_y_continuous(labels = scales::percent)
上面的图表并没有显示出任何非常明显的残差趋势。这表明,在一个非常高的水平上,我们的模型并没有系统地对具有特定预期销售价格的房屋做出不准确的预测。我们会在这里做更多的模型验证以进行真实世界的分析,但是,为了这篇文章的目的,上面的图表对我们来说已经足够了。
总结
在这篇文章中,我们并没有过分关注我们模型的性能。我们的目标是简单地通过使用tidymodels训练XGBoost模型的过程,并学习tidymodels的基础知识。
Tidymodels为我们提供了一个标准的流程和词汇表来处理重采样(rsample)、数据预处理(recipes)、模型规范(parsnip)、调优(tune)和模型验证(yardstick)。tidymodels团队“整理”机器学习过程的工作是对R中机器学习可接近性的一步改进。使用tidymodels包,训练和(更重要的是)理解模型训练过程比以往任何时候都更容易。谢谢tidymodels团队!
参考资料
[1]
tidymodels.org: https://www.tidymodels.org/
[2]
AmesHousing: https://shixiangwang.github.io/blog/using-xgboost-with-tidymodels/
边栏推荐
- 爬虫笔记(2)- 解析
- Which method is called for OSS upload
- netERR_CONNECTION_REFUSED 解决大全
- Flask application case
- 结构化机器学习项目(一)- 机器学习策略
- 99 multiplication table - C language
- \W and [a-za-z0-9_], \Are D and [0-9] equivalent?
- Introduce you to ldbc SNB, a powerful tool for database performance and scenario testing
- Vue+MySQL实现登录注册案例
- Macro task and micro task understanding
猜你喜欢
Workflow automation low code is the key
YOLOv6:又快又准的目标检测框架开源啦
Use Fiddler to simulate weak network test (2g/3g)
7 jours d'apprentissage de la programmation simultanée go 7 jours de programmation simultanée go Language Atomic Atomic Atomic actual Operation contains ABA Problems
99 multiplication table - C language
Introduction to MySQL operation (IV) -- data sorting (ascending, descending, and multi field sorting)
"I make the world cooler" 2022 Huaqing vision R & D product launch was a complete success
Penetration learning - shooting range chapter -dvwa shooting range detailed introduction (continuous updating - currently only the SQL injection part is updated)
Structured machine learning project (I) - machine learning strategy
6G显卡显存不足出现CUDA Error:out of memory解决办法
随机推荐
Basic data type and complex data type
深度学习又有新坑了!悉尼大学提出全新跨模态任务,用文本指导图像进行抠图
Exclusive interview with millions of annual salary. What should developers do if they don't fix bugs?
Crontab scheduled task common commands
Flask application case
Vue+mysql login registration case
YOLOv6:又快又准的目标检测框架开源啦
Management system itclub (medium)
Crawler notes (1) - urllib
MONTHS_ Between function use
月薪3万的狗德培训,是不是一门好生意?
Penetration learning - shooting range chapter - detailed introduction to Pikachu shooting range (under continuous update - currently only the SQL injection part is updated)
Do280openshift access control -- Security Policy and chapter experiment
Penetration learning - shooting range chapter -dvwa shooting range detailed introduction (continuous updating - currently only the SQL injection part is updated)
PCIe knowledge point -008: structure of PCIe switch
改善深层神经网络:超参数调试、正则化以及优化(三)- 超参数调试、Batch正则化和程序框架
爬虫笔记(2)- 解析
CUDA error:out of memory caused by insufficient video memory of 6G graphics card
Acwing weekly contest 57- digital operation - (thinking + decomposition of prime factor)
Improving deep neural networks: hyperparametric debugging, regularization and optimization (III) - hyperparametric debugging, batch regularization and program framework