【sklearn】tree.DecisionTreeClassifier
2022-07-24 07:34:00 【rejudge】
from sklearn import tree
from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split
wine = load_wine()
wine.data
'''
array([[1.423e+01, 1.710e+00, 2.430e+00, ..., 1.040e+00, 3.920e+00, 1.065e+03],
       [1.320e+01, 1.780e+00, 2.140e+00, ..., 1.050e+00, 3.400e+00, 1.050e+03],
       [1.316e+01, 2.360e+00, 2.670e+00, ..., 1.030e+00, 3.170e+00, 1.185e+03],
       ...,
       [1.327e+01, 4.280e+00, 2.260e+00, ..., 5.900e-01, 1.560e+00, 8.350e+02],
       [1.317e+01, 2.590e+00, 2.370e+00, ..., 6.000e-01, 1.620e+00, 8.400e+02],
       [1.413e+01, 4.100e+00, 2.740e+00, ..., 6.100e-01, 1.600e+00, 5.600e+02]])
'''
# View the features and the target side by side with pandas
import pandas as pd
pd.concat([pd.DataFrame(wine.data),pd.DataFrame(wine.target)], axis=1)
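The table below labels its columns only 0..12 (features) and 0 (target), which is hard to read. A minimal sketch that uses `wine.feature_names` as column headers and counts the samples per class; the column name `target` is our own choice, not something `load_wine` provides.

```python
import pandas as pd
from sklearn.datasets import load_wine

wine = load_wine()

# Use the real feature names instead of 0..12, and add the class label as a named column
df = pd.DataFrame(wine.data, columns=wine.feature_names)
df["target"] = wine.target  # "target" is an arbitrary column name

print(df.shape)                                  # (178, 14)
print(df["target"].value_counts().sort_index())  # number of samples in each class
```

The class counts are worth a glance before splitting: the three wine classes are not equally sized, which matters for how the test set is drawn.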
|  | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 0 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 14.23 | 1.71 | 2.43 | 15.6 | 127.0 | 2.80 | 3.06 | 0.28 | 2.29 | 5.64 | 1.04 | 3.92 | 1065.0 | 0 |
| 1 | 13.20 | 1.78 | 2.14 | 11.2 | 100.0 | 2.65 | 2.76 | 0.26 | 1.28 | 4.38 | 1.05 | 3.40 | 1050.0 | 0 |
| 2 | 13.16 | 2.36 | 2.67 | 18.6 | 101.0 | 2.80 | 3.24 | 0.30 | 2.81 | 5.68 | 1.03 | 3.17 | 1185.0 | 0 |
| 3 | 14.37 | 1.95 | 2.50 | 16.8 | 113.0 | 3.85 | 3.49 | 0.24 | 2.18 | 7.80 | 0.86 | 3.45 | 1480.0 | 0 |
| 4 | 13.24 | 2.59 | 2.87 | 21.0 | 118.0 | 2.80 | 2.69 | 0.39 | 1.82 | 4.32 | 1.04 | 2.93 | 735.0 | 0 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 173 | 13.71 | 5.65 | 2.45 | 20.5 | 95.0 | 1.68 | 0.61 | 0.52 | 1.06 | 7.70 | 0.64 | 1.74 | 740.0 | 2 |
| 174 | 13.40 | 3.91 | 2.48 | 23.0 | 102.0 | 1.80 | 0.75 | 0.43 | 1.41 | 7.30 | 0.70 | 1.56 | 750.0 | 2 |
| 175 | 13.27 | 4.28 | 2.26 | 20.0 | 120.0 | 1.59 | 0.69 | 0.43 | 1.35 | 10.20 | 0.59 | 1.56 | 835.0 | 2 |
| 176 | 13.17 | 2.59 | 2.37 | 20.0 | 120.0 | 1.65 | 0.68 | 0.53 | 1.46 | 9.30 | 0.60 | 1.62 | 840.0 | 2 |
| 177 | 14.13 | 4.10 | 2.74 | 24.5 | 96.0 | 2.05 | 0.76 | 0.56 | 1.35 | 9.20 | 0.61 | 1.60 | 560.0 | 2 |
178 rows × 14 columns
# Split into training and test sets (30% of the samples are held out for testing)
Xtrain, Xtest, Ytrain, Ytest = train_test_split(wine.data, wine.target, test_size=0.3)
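Because `train_test_split` above is called without `random_state`, every run draws a different split, so the test scores reported later will vary from run to run. A sketch of a reproducible, stratified split; the seed `42` is arbitrary and `stratify` is an optional addition not used in the original.

```python
from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split

wine = load_wine()

# random_state fixes the shuffle (42 is an arbitrary seed);
# stratify keeps the class proportions similar in the training and test sets
Xtrain, Xtest, Ytrain, Ytest = train_test_split(
    wine.data, wine.target, test_size=0.3, random_state=42, stratify=wine.target
)

# test_size=0.3 -> ceil(0.3 * 178) = 54 test rows; the remaining 124 rows are for training
print(Xtrain.shape, Xtest.shape)  # (124, 13) (54, 13)
```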
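'''
Decision tree parameters
random_state: seeds the estimator's randomness so that results are reproducible
splitter='best': pick the best (most important) feature and threshold for each split;
    splitter='random' picks splits more randomly, which can reduce overfitting
Pruning (prevents overfitting):
    max_depth: limits the maximum depth of the tree
    min_samples_leaf=0.05: a split is made only if each child keeps at least
        ceil(0.05 * number of training samples) samples; otherwise the node is not branched
    min_samples_split=5: a node containing fewer than 5 samples is not split
Limiting features (reduces dimensionality, prevents overfitting):
    max_features: number of features considered when looking for the best split
    min_impurity_decrease: a node is split only if the impurity decrease
        (information gain) is at least this value
'''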
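The models in this post use `criterion='gini'` and, in the learning curve, `criterion='entropy'`, and `min_impurity_decrease` is measured with the same impurity. A hand-written sketch of the two measures and of the weighted impurity decrease described in the scikit-learn user guide; the helper names `gini`, `entropy` and `impurity_decrease` and the toy node are hypothetical, and scikit-learn computes all of this internally rather than through these functions.

```python
import numpy as np

def gini(labels):
    """Gini impurity: 1 - sum_k p_k^2, where p_k is the share of class k."""
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return 1.0 - np.sum(p ** 2)

def entropy(labels):
    """Entropy in bits: -sum_k p_k * log2(p_k)."""
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return float(-np.sum(p * np.log2(p)))

def impurity_decrease(parent, left, right, n_total, criterion=gini):
    """Weighted decrease N_t/N * (I(t) - N_L/N_t * I(L) - N_R/N_t * I(R))."""
    n_t, n_l, n_r = len(parent), len(left), len(right)
    return n_t / n_total * (criterion(parent)
                            - n_l / n_t * criterion(left)
                            - n_r / n_t * criterion(right))

# Hypothetical toy node: 6 samples of class 0 and 2 samples of class 1
node = np.array([0] * 6 + [1] * 2)
print(gini(node))     # 1 - (0.75**2 + 0.25**2) = 0.375
print(entropy(node))  # about 0.811

# A split that separates the two classes perfectly removes all of the impurity
left, right = node[node == 0], node[node == 1]
print(impurity_decrease(node, left, right, n_total=len(node)))  # 0.375
```

Here the node is the whole (toy) dataset, so `n_total` equals the node size; deeper in a real tree the factor `N_t/N` shrinks, so splits on small nodes need a larger raw gain to pass a given `min_impurity_decrease`.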
clf = tree.DecisionTreeClassifier(criterion='gini'
,random_state=3
,max_depth=3
,min_samples_leaf=0.05
,min_samples_split=2
)
clf = clf.fit(Xtrain, Ytrain)
score = clf.score(Xtest, Ytest)
score
''' 0.8888888888888888 '''
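To see what the pruning parameters actually change, a sketch comparing an unpruned tree with one using the same settings as above, via `get_depth()` and `get_n_leaves()`. The split seed `random_state=0` is an assumption added for reproducibility, so the numbers will not match the 0.889 above exactly.

```python
from sklearn import tree
from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split

wine = load_wine()
Xtrain, Xtest, Ytrain, Ytest = train_test_split(
    wine.data, wine.target, test_size=0.3, random_state=0  # arbitrary seed
)

# Unpruned tree: keeps splitting until the leaves are pure
full = tree.DecisionTreeClassifier(random_state=3).fit(Xtrain, Ytrain)

# Pruned tree with the same parameters as the model above
pruned = tree.DecisionTreeClassifier(
    criterion="gini", random_state=3, max_depth=3,
    min_samples_leaf=0.05, min_samples_split=2,
).fit(Xtrain, Ytrain)

for name, model in [("full", full), ("pruned", pruned)]:
    print(name,
          "depth:", model.get_depth(),
          "leaves:", model.get_n_leaves(),
          "train acc: %.3f" % model.score(Xtrain, Ytrain),
          "test acc: %.3f" % model.score(Xtest, Ytest))
```

A large gap between training and test accuracy for the full tree, shrinking for the pruned one, is the overfitting that the docstring's pruning parameters are meant to control.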
# clf.feature_importances_ gives the importance of each feature; zip() pairs it with the feature name and [*...] unpacks the pairs into a list
[*zip(wine.feature_names, clf.feature_importances_)]
''' [('alcohol', 0.44232962250700664), ('malic_acid', 0.0), ('ash', 0.0), ('alcalinity_of_ash', 0.0), ('magnesium', 0.0), ('total_phenols', 0.0), ('flavanoids', 0.4126765806025175), ('nonflavanoid_phenols', 0.0), ('proanthocyanins', 0.0), ('color_intensity', 0.0), ('hue', 0.0030484226577791127), ('od280/od315_of_diluted_wines', 0.1419453742326969), ('proline', 0.0)] '''
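Most entries in the list above are 0.0, meaning the tree never split on that feature. A sketch that keeps only the features actually used and sorts them by importance; the split seed is an assumption, so the exact values will differ from the output above.

```python
from sklearn import tree
from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split

wine = load_wine()
Xtrain, Xtest, Ytrain, Ytest = train_test_split(
    wine.data, wine.target, test_size=0.3, random_state=0  # arbitrary seed
)
clf = tree.DecisionTreeClassifier(criterion="gini", random_state=3, max_depth=3,
                                  min_samples_leaf=0.05).fit(Xtrain, Ytrain)

# Keep only features the tree split on, most important first
ranked = sorted(
    ((name, imp) for name, imp in zip(wine.feature_names, clf.feature_importances_) if imp > 0),
    key=lambda pair: pair[1],
    reverse=True,
)
for name, imp in ranked:
    print(f"{name:30s} {imp:.3f}")

# Importances are normalised, so they sum to 1 for a tree with at least one split
print(clf.feature_importances_.sum())
```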
# Hyperparameter learning curve: test accuracy for max_depth = 1..10
import matplotlib.pyplot as plt
test = []
for i in range(10):
    # Fit one tree per depth (1..10) and record its test accuracy
    clf = tree.DecisionTreeClassifier(max_depth=i+1
                                      ,criterion='entropy'
                                      ,random_state=30
                                      ,splitter='random'
                                      )
    clf = clf.fit(Xtrain, Ytrain)
    score = clf.score(Xtest, Ytest)
    test.append(score)
plt.plot(range(1, 11), test, color='red', label='max_depth')
plt.legend() # Add legend
plt.show()
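A test-only curve does not show where overfitting starts. A sketch that also records training accuracy and reports the depth with the highest test accuracy; the split seed and the names `train_scores`, `test_scores`, `best_depth` are our own additions.

```python
import matplotlib.pyplot as plt
from sklearn import tree
from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split

wine = load_wine()
Xtrain, Xtest, Ytrain, Ytest = train_test_split(
    wine.data, wine.target, test_size=0.3, random_state=0  # arbitrary seed
)

depths = range(1, 11)
train_scores, test_scores = [], []
for d in depths:
    clf = tree.DecisionTreeClassifier(max_depth=d, criterion="entropy",
                                      random_state=30, splitter="random")
    clf.fit(Xtrain, Ytrain)
    train_scores.append(clf.score(Xtrain, Ytrain))
    test_scores.append(clf.score(Xtest, Ytest))

# Depth with the highest test accuracy (the smallest such depth wins on ties)
best_depth = depths[test_scores.index(max(test_scores))]
print("best max_depth:", best_depth, "test acc: %.3f" % max(test_scores))

plt.plot(depths, train_scores, color="blue", label="train")
plt.plot(depths, test_scores, color="red", label="test")
plt.xlabel("max_depth")
plt.ylabel("accuracy")
plt.legend()
plt.show()
```

Picking the depth from a single test split, as the original curve does, is sensitive to which rows landed in the test set; with only 178 samples, cross-validation (for example `sklearn.model_selection.cross_val_score`) gives a steadier estimate.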
