Sklearn.metrics module model evaluation function
2022-07-24 15:17:00 【qq_ twenty-seven million three hundred and ninety thousand and 】
sklearn provides 3 different APIs for evaluating the quality of a model's predictions:
Estimator score method: each estimator has a score method that provides a default evaluation criterion for the problem it is designed to solve.
Scoring parameter: model-evaluation tools that use cross-validation (such as model_selection.cross_val_score and model_selection.GridSearchCV) rely on an internal scoring strategy, set via the scoring parameter.
Metric functions: the sklearn.metrics module implements functions that assess prediction error for specific purposes.
Examples of the model evaluation functions:
from sklearn import metrics
# List the functions available in the module
dir(metrics)
### 1. Accuracy score
import numpy as np
from sklearn.metrics import accuracy_score
y_pred = [0, 2, 1, 3]
y_true = [0, 1, 2, 3]
print(accuracy_score(y_true, y_pred))
print(accuracy_score(y_true, y_pred, normalize=False)) # The number of correct predictions
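As a quick sanity check (a sketch added here, not part of the original post), accuracy is simply the fraction of positions where the two label lists agree:

```python
import numpy as np
from sklearn.metrics import accuracy_score

y_pred = [0, 2, 1, 3]
y_true = [0, 1, 2, 3]

# Fraction of matching positions: 2 of the 4 labels agree -> 0.5
manual = np.mean(np.array(y_true) == np.array(y_pred))
print(manual)  # 0.5, same as accuracy_score(y_true, y_pred)
```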
### 2. Top-k accuracy score
# The top_k_accuracy_score function is a generalization of accuracy_score.
# The difference is that a prediction is considered correct as long as the true label
# is among the k highest predicted scores. accuracy_score is the special case k=1.
import numpy as np
from sklearn.metrics import top_k_accuracy_score
y_true = np.array([0, 1, 2, 2])
y_score = np.array([[0.5, 0.2, 0.2],
[0.3, 0.4, 0.2],
[0.2, 0.4, 0.3],
[0.7, 0.2, 0.1]])
top_k_accuracy_score(y_true, y_score, k=2)
# Not normalizing gives the number of "correctly" classified samples
top_k_accuracy_score(y_true, y_score, k=2, normalize=False)
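The same result can be reproduced by hand with np.argsort: take the indices of the k largest scores in each row and count how often the true label appears among them (a sketch added for illustration; note that exact ties in scores may be broken differently than sklearn does internally, though it makes no difference for this example):

```python
import numpy as np
from sklearn.metrics import top_k_accuracy_score

y_true = np.array([0, 1, 2, 2])
y_score = np.array([[0.5, 0.2, 0.2],
                    [0.3, 0.4, 0.2],
                    [0.2, 0.4, 0.3],
                    [0.7, 0.2, 0.1]])

k = 2
# Column indices of the k highest scores in each row
top_k = np.argsort(y_score, axis=1)[:, -k:]
# A prediction counts as correct if the true label appears among them
hits = np.array([t in row for t, row in zip(y_true, top_k)])
print(hits.mean())  # 3 of 4 samples -> 0.75
```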
### 3. confusion_matrix
from sklearn import datasets
from sklearn.svm import LinearSVC
from sklearn.model_selection import cross_validate
from sklearn.metrics import confusion_matrix,ConfusionMatrixDisplay
import matplotlib.pyplot as plt
y_true = [2, 0, 2, 2, 0, 1]
y_pred = [0, 0, 2, 2, 0, 2]
cm = confusion_matrix(y_true, y_pred)
print(cm)
print(confusion_matrix(y_true, y_pred, normalize='all'))
disp = ConfusionMatrixDisplay(confusion_matrix=cm)
disp.plot()
plt.show()
# Binary classification
y_true = [0, 0, 0, 1, 1, 1, 1, 1]
y_pred = [0, 1, 0, 1, 0, 1, 0, 1]
tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
print(tn, fp, fn, tp)
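The four counts from ravel() are all that precision and recall need; as a sanity check (added here as a sketch), the textbook formulas reproduce sklearn's own scores:

```python
from sklearn.metrics import confusion_matrix, precision_score, recall_score

y_true = [0, 0, 0, 1, 1, 1, 1, 1]
y_pred = [0, 1, 0, 1, 0, 1, 0, 1]

tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()  # 2, 1, 2, 3
precision = tp / (tp + fp)  # 3/4 = 0.75
recall = tp / (tp + fn)     # 3/5 = 0.6
print(precision, recall)
```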
# A sample toy binary classification dataset
X, y = datasets.make_classification(n_classes=2, random_state=0)
svm = LinearSVC(random_state=0)
def confusion_matrix_scorer(clf, X, y):
    y_pred = clf.predict(X)
    cm = confusion_matrix(y, y_pred)
    return {'tn': cm[0, 0], 'fp': cm[0, 1],
            'fn': cm[1, 0], 'tp': cm[1, 1]}
cv_results = cross_validate(svm, X, y, cv=5,
                            scoring=confusion_matrix_scorer)
# Getting the test set true positive scores
print(cv_results['test_tp'])
# Getting the test set false negative scores
print(cv_results['test_fn'])
# Getting the test set true negative scores
print(cv_results['test_tn'])
# Getting the test set false positive scores
print(cv_results['test_fp'])
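Because the scorer returns all four cells of the confusion matrix, any derived metric can be recovered per fold afterwards. As an illustration (a sketch under the same setup as above), per-fold accuracy is just (tp + tn) divided by the fold size:

```python
import numpy as np
from sklearn import datasets
from sklearn.svm import LinearSVC
from sklearn.model_selection import cross_validate
from sklearn.metrics import confusion_matrix

X, y = datasets.make_classification(n_classes=2, random_state=0)
svm = LinearSVC(random_state=0)

def confusion_matrix_scorer(clf, X, y):
    y_pred = clf.predict(X)
    cm = confusion_matrix(y, y_pred)
    return {'tn': cm[0, 0], 'fp': cm[0, 1],
            'fn': cm[1, 0], 'tp': cm[1, 1]}

cv_results = cross_validate(svm, X, y, cv=5,
                            scoring=confusion_matrix_scorer)

# Per-fold accuracy recovered from the four counts
tp, tn = cv_results['test_tp'], cv_results['test_tn']
fp, fn = cv_results['test_fp'], cv_results['test_fn']
fold_accuracy = (tp + tn) / (tp + tn + fp + fn)
print(fold_accuracy)
```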
### 4. classification_report
from sklearn.metrics import classification_report
y_true = [0, 1, 2, 2, 0]
y_pred = [0, 0, 2, 1, 0]
target_names = ['class 0', 'class 1', 'class 2']
print(classification_report(y_true, y_pred, target_names=target_names))
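Besides the printable text table, classification_report can return a nested dict via output_dict=True, which is handy when the numbers are needed programmatically (a small sketch of this option on the same data):

```python
from sklearn.metrics import classification_report

y_true = [0, 1, 2, 2, 0]
y_pred = [0, 0, 2, 1, 0]
target_names = ['class 0', 'class 1', 'class 2']

report = classification_report(y_true, y_pred,
                               target_names=target_names,
                               output_dict=True)
# e.g. recall of 'class 0': both class-0 samples were recovered -> 1.0
print(report['class 0']['recall'])
print(report['accuracy'])  # 3 of 5 labels correct -> 0.6
```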
### 5. hamming_loss
from sklearn.metrics import hamming_loss
y_pred = [1, 2, 3, 4]
y_true = [2, 2, 3, 4]
hamming_loss(y_true, y_pred)
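The Hamming loss is the complement of accuracy for single-label problems: the fraction of label positions that disagree. A manual check (added as a sketch):

```python
import numpy as np
from sklearn.metrics import hamming_loss

y_pred = [1, 2, 3, 4]
y_true = [2, 2, 3, 4]

# Fraction of positions that disagree: 1 of 4 -> 0.25
manual = np.mean(np.array(y_true) != np.array(y_pred))
print(manual)
```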
### 6. Precision, recall and F-measures
from sklearn import metrics
y_pred = [0, 1, 0, 0]
y_true = [0, 1, 0, 1]
metrics.precision_score(y_true, y_pred)
metrics.recall_score(y_true, y_pred)
metrics.f1_score(y_true, y_pred)
metrics.fbeta_score(y_true, y_pred, beta=0.5)
metrics.fbeta_score(y_true, y_pred, beta=1)
metrics.fbeta_score(y_true, y_pred, beta=2)
metrics.precision_recall_fscore_support(y_true, y_pred, beta=0.5)
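The F-beta scores above all come from one formula, F_beta = (1 + beta^2) * P * R / (beta^2 * P + R), which weights recall beta times as much as precision. Reproducing fbeta_score by hand on the same labels (a sketch added as a sanity check):

```python
from sklearn import metrics

y_pred = [0, 1, 0, 0]
y_true = [0, 1, 0, 1]

p = metrics.precision_score(y_true, y_pred)  # tp=1, fp=0 -> 1.0
r = metrics.recall_score(y_true, y_pred)     # tp=1, fn=1 -> 0.5

beta = 0.5
manual_fbeta = (1 + beta**2) * p * r / (beta**2 * p + r)
print(manual_fbeta)  # 0.8333...
```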
import numpy as np
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import average_precision_score
y_true = np.array([0, 0, 1, 1])
y_scores = np.array([0.1, 0.4, 0.35, 0.8])
precision, recall, threshold = precision_recall_curve(y_true, y_scores)
print(precision)
print(recall)
print(threshold)
print( average_precision_score(y_true, y_scores))
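Average precision summarizes the precision-recall curve as the sum of precisions weighted by the increase in recall at each threshold, AP = sum_n (R_n - R_(n-1)) * P_n. That sum can be computed directly from the curve returned above (a sketch added to connect the two functions):

```python
import numpy as np
from sklearn.metrics import precision_recall_curve, average_precision_score

y_true = np.array([0, 0, 1, 1])
y_scores = np.array([0.1, 0.4, 0.35, 0.8])

precision, recall, thresholds = precision_recall_curve(y_true, y_scores)
# recall is returned in decreasing order, hence the minus sign
manual_ap = -np.sum(np.diff(recall) * precision[:-1])
print(manual_ap)  # 0.8333...
```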
## Multiclass
from sklearn import metrics
y_true = [0, 1, 2, 0, 1, 2]
y_pred = [0, 2, 1, 0, 0, 1]
print(metrics.precision_score(y_true, y_pred, average='macro'))
print(metrics.recall_score(y_true, y_pred, average='micro'))
print(metrics.f1_score(y_true, y_pred, average='weighted'))
print(metrics.fbeta_score(y_true, y_pred, average='macro', beta=0.5))
print(metrics.precision_recall_fscore_support(y_true, y_pred, beta=0.5, average=None))
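Macro averaging means: compute the metric per class, then take the unweighted mean. Spelling that out for precision on the labels above reproduces sklearn's macro score (a sketch added for illustration):

```python
import numpy as np
from sklearn import metrics

y_true = np.array([0, 1, 2, 0, 1, 2])
y_pred = np.array([0, 2, 1, 0, 0, 1])

per_class = []
for c in [0, 1, 2]:
    predicted_c = y_pred == c
    tp = np.sum(predicted_c & (y_true == c))
    per_class.append(tp / predicted_c.sum())  # class precision
manual_macro = np.mean(per_class)  # (2/3 + 0 + 0) / 3
print(manual_macro)  # 0.2222...
```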
### 7. Regression prediction r2_score
# The r2_score function computes the coefficient of determination, usually denoted R².
# It represents the proportion of the variance of the dependent variable (y) that is
# explained by the independent variables in the model. It indicates goodness of fit and
# therefore measures how well unseen samples are likely to be predicted by the model.
from sklearn.metrics import r2_score
y_true = [3, -0.5, 2, 7]
y_pred = [2.5, 0.0, 2, 8]
r2_score(y_true, y_pred)
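The definition R² = 1 - SS_res / SS_tot can be checked by hand on the same numbers (a sketch added as a sanity check):

```python
import numpy as np
from sklearn.metrics import r2_score

y_true = np.array([3, -0.5, 2, 7])
y_pred = np.array([2.5, 0.0, 2, 8])

ss_res = np.sum((y_true - y_pred) ** 2)          # residual sum of squares
ss_tot = np.sum((y_true - y_true.mean()) ** 2)   # total sum of squares
manual_r2 = 1 - ss_res / ss_tot
print(manual_r2)  # 0.9486...
```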
### 8. Regression prediction mean_absolute_error
from sklearn.metrics import mean_absolute_error
y_true = [3, -0.5, 2, 7]
y_pred = [2.5, 0.0, 2, 8]
mean_absolute_error(y_true, y_pred)
### 9. Regression prediction mean_squared_error
from sklearn.metrics import mean_squared_error
y_true = [3, -0.5, 2, 7]
y_pred = [2.5, 0.0, 2, 8]
mean_squared_error(y_true, y_pred)
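Both MAE and MSE are plain averages over the residuals, so they are easy to verify manually (a sketch covering the two examples above):

```python
import numpy as np
from sklearn.metrics import mean_absolute_error, mean_squared_error

y_true = np.array([3, -0.5, 2, 7])
y_pred = np.array([2.5, 0.0, 2, 8])

mae = np.mean(np.abs(y_true - y_pred))   # (0.5 + 0.5 + 0 + 1) / 4 = 0.5
mse = np.mean((y_true - y_pred) ** 2)    # (0.25 + 0.25 + 0 + 1) / 4 = 0.375
print(mae, mse)
```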
### 10. Regression prediction mean_squared_log_error
from sklearn.metrics import mean_squared_log_error
y_true = [3, 5, 2.5, 7]
y_pred = [2.5, 5, 4, 8]
mean_squared_log_error(y_true, y_pred)
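MSLE is just the MSE computed on log(1 + x) transformed values, which penalizes under-prediction of large targets less harshly on the original scale (a manual sketch for verification):

```python
import numpy as np
from sklearn.metrics import mean_squared_log_error

y_true = np.array([3, 5, 2.5, 7])
y_pred = np.array([2.5, 5, 4, 8])

manual_msle = np.mean((np.log1p(y_true) - np.log1p(y_pred)) ** 2)
print(manual_msle)
```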
### 11. Unsupervised clustering Silhouette Coefficient
from sklearn import metrics
from sklearn import datasets
import numpy as np
X, y = datasets.load_iris(return_X_y=True)
from sklearn.cluster import KMeans
kmeans_model = KMeans(n_clusters=3, random_state=1).fit(X)
labels = kmeans_model.labels_
metrics.silhouette_score(X, labels, metric='euclidean')
References:
https://scikit-learn.org/stable/modules/model_evaluation.html#multilabel-ranking-metrics
https://scikit-learn.org/stable/modules/clustering.html#clustering-evaluation