Machine Learning Notes 6: Multiclass Classification with Logistic Regression: Digit Recognition
2022-06-22 05:30:00 【Amyniez】
Recognizing handwritten digits (0 to 9) with logistic regression.
1. Functions we will need:
1. The sigmoid function:
2. The logistic regression cost function:
3. Gradient descent:
4. The cost function of the regularized logistic regression model:
5. How digit recognition works (vectorized labels: 1 means true, 0 means false):
2. Code implementation
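The formula images from the original post did not survive extraction. As a reference, the standard forms of the five items above (using the usual Coursera-style notation, with m examples, n features, and regularization strength λ) are:

```latex
% 1. sigmoid function
\sigma(z) = \frac{1}{1 + e^{-z}}, \qquad h_\theta(x) = \sigma(\theta^T x)

% 2. logistic regression cost function
J(\theta) = -\frac{1}{m}\sum_{i=1}^{m}\Big[ y^{(i)} \log h_\theta(x^{(i)})
  + (1 - y^{(i)}) \log\big(1 - h_\theta(x^{(i)})\big) \Big]

% 3. gradient descent update rule
\theta_j := \theta_j - \frac{\alpha}{m} \sum_{i=1}^{m}
  \big(h_\theta(x^{(i)}) - y^{(i)}\big)\, x_j^{(i)}

% 4. regularized cost (the intercept \theta_0 is not penalized)
J(\theta) = -\frac{1}{m}\sum_{i=1}^{m}\Big[ y^{(i)} \log h_\theta(x^{(i)})
  + (1 - y^{(i)}) \log\big(1 - h_\theta(x^{(i)})\big) \Big]
  + \frac{\lambda}{2m}\sum_{j=1}^{n} \theta_j^2
```

Item 5 (one-vs-all labeling) has no single formula: for each class k, the label vector y is mapped to 1 where y equals k and 0 elsewhere, and one binary classifier is trained per class.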
import numpy as np
from scipy.io import loadmat
from scipy.optimize import minimize
""" date: 2022/6/1 author: amyniez name: machine learning one vs all """
# sigmoid function
def sigmoid(z):
    return 1 / (1 + np.exp(-z))
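A quick sanity check of the sigmoid function (a toy example, not from the original post): it is exactly 0.5 at zero and saturates toward 0 and 1 for large negative and positive inputs.

```python
import numpy as np

def sigmoid(z):
    return 1 / (1 + np.exp(-z))

# sigmoid(0) is exactly 0.5; large |z| saturates toward 0 or 1
print(sigmoid(0))                             # 0.5
print(sigmoid(np.array([-10.0, 0.0, 10.0])))
```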
# regularized logistic regression cost function J(theta)
# (the learningRate argument here is the regularization strength, lambda)
def cost(cost_theta, X, y, learningRate):
    cost_theta = np.matrix(cost_theta)
    X = np.matrix(X)
    y = np.matrix(y)
    first = np.multiply(-y, np.log(sigmoid(X * cost_theta.T)))
    second = np.multiply(1 - y, np.log(1 - sigmoid(X * cost_theta.T)))
    reg = (learningRate / (2 * len(X))) * np.sum(np.power(cost_theta[:, 1:cost_theta.shape[1]], 2))
    return np.sum(first - second) / len(X) + reg
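One easy way to check the cost function: with all-zero parameters every prediction is 0.5 and the regularization term vanishes, so the cost must be exactly ln 2 ≈ 0.6931 regardless of the data. A minimal sketch with made-up data:

```python
import numpy as np

def sigmoid(z):
    return 1 / (1 + np.exp(-z))

def cost(theta, X, y, lam):
    theta = np.matrix(theta)
    X = np.matrix(X)
    y = np.matrix(y)
    first = np.multiply(-y, np.log(sigmoid(X * theta.T)))
    second = np.multiply(1 - y, np.log(1 - sigmoid(X * theta.T)))
    reg = (lam / (2 * len(X))) * np.sum(np.power(theta[:, 1:], 2))
    return np.sum(first - second) / len(X) + reg

# with theta = 0, h = 0.5 everywhere, so J = -log(0.5) = ln 2
X = np.array([[1.0, 2.0], [1.0, 3.0]])
y = np.array([[0.0], [1.0]])
print(cost(np.zeros(2), X, y, 1.0))
```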
# gradient descent function
# Method 1: gradient computed with a for loop
def gradient_with_loop(theta, X, y, learningRate):
    theta = np.matrix(theta)
    X = np.matrix(X)
    y = np.matrix(y)
    parameters = int(theta.ravel().shape[1])
    gradients = np.zeros(parameters)
    error = sigmoid(X * theta.T) - y
    for i in range(parameters):
        term = np.multiply(error, X[:, i])
        if i == 0:
            # the intercept term is not regularized
            gradients[i] = np.sum(term) / len(X)
        else:
            gradients[i] = (np.sum(term) / len(X)) + ((learningRate / len(X)) * theta[0, i])
    return gradients
# Method 2: vectorized gradient
def gradient(theta, X, y, learningRate):
    theta = np.matrix(theta)
    X = np.matrix(X)
    y = np.matrix(y)
    error = sigmoid(X * theta.T) - y
    grad = ((X.T * error) / len(X)).T + ((learningRate / len(X)) * theta)
    # the intercept gradient is not regularized
    grad[0, 0] = np.sum(np.multiply(error, X[:, 0])) / len(X)
    return np.array(grad).ravel()
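A standard way to verify a hand-written gradient is to compare it against a central-difference numerical gradient of the cost. The sketch below (toy data and parameter values invented for the check) should agree to many decimal places:

```python
import numpy as np

def sigmoid(z):
    return 1 / (1 + np.exp(-z))

def cost(theta, X, y, lam):
    theta = np.matrix(theta)
    X = np.matrix(X)
    y = np.matrix(y)
    first = np.multiply(-y, np.log(sigmoid(X * theta.T)))
    second = np.multiply(1 - y, np.log(1 - sigmoid(X * theta.T)))
    reg = (lam / (2 * len(X))) * np.sum(np.power(theta[:, 1:], 2))
    return np.sum(first - second) / len(X) + reg

def gradient(theta, X, y, lam):
    theta = np.matrix(theta)
    X = np.matrix(X)
    y = np.matrix(y)
    error = sigmoid(X * theta.T) - y
    grad = ((X.T * error) / len(X)).T + ((lam / len(X)) * theta)
    # the intercept gradient is not regularized
    grad[0, 0] = np.sum(np.multiply(error, X[:, 0])) / len(X)
    return np.array(grad).ravel()

# toy data: 3 examples, intercept column plus one feature
X = np.array([[1.0, 0.5], [1.0, -1.2], [1.0, 2.0]])
y = np.array([[1.0], [0.0], [1.0]])
theta = np.array([0.1, -0.3])

g = gradient(theta, X, y, 1.0)

# central-difference numerical gradient
eps = 1e-5
num = np.array([
    (cost(theta + eps * np.eye(2)[j], X, y, 1.0)
     - cost(theta - eps * np.eye(2)[j], X, y, 1.0)) / (2 * eps)
    for j in range(2)
])
print(np.max(np.abs(g - num)))  # should be tiny, near machine precision
```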
# one-vs-all classifier training
def one_vs_all(X, y, num_labels, learning_rate):
    rows = X.shape[0]
    params = X.shape[1]
    # k classifiers, n parameters each:
    # k x (n + 1) array for the parameters of each of the k classifiers
    all_theta = np.zeros((num_labels, params + 1))
    # insert a column of ones at the beginning for the intercept term
    # np.insert(arr, obj, values, axis): obj is the insertion position,
    # values the inserted content, axis=1 inserts as a column
    X = np.insert(X, 0, values=np.ones(rows), axis=1)
    # labels are 1-indexed instead of 0-indexed
    for i in range(1, num_labels + 1):
        theta = np.zeros(params + 1)
        y_i = np.array([1 if label == i else 0 for label in y])
        y_i = np.reshape(y_i, (rows, 1))
        # minimize the objective with an advanced optimizer, which returns theta directly
        # fun: the cost function to minimize (our cost)
        # x0: the initial value of theta (zeros here)
        # args: the remaining arguments of cost/gradient, in order
        # method: the optimization method to use, e.g. TNC or BFGS
        # jac: the function that computes the gradient
        # options: e.g. {'maxiter': 100} limits iterations; {'disp': True} prints details
        fmin = minimize(fun=cost, x0=theta, args=(X, y_i, learning_rate), method='TNC', jac=gradient)
        # store the parameters found by the optimizer in the parameter array
        all_theta[i - 1, :] = fmin.x
    return all_theta
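The key step inside the loop is turning the multiclass labels into a binary target for classifier i. A toy illustration (labels invented for the example, classes in {1, 2, 3}):

```python
import numpy as np

# build the binary target for the classifier of class 2:
# 1 where the label equals 2, 0 everywhere else
y = np.array([1, 2, 3, 2, 1])
y_2 = np.array([1 if label == 2 else 0 for label in y])
print(y_2)  # [0 1 0 1 0]
```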
# prediction function
def predict_all(X, all_theta):
    rows = X.shape[0]
    num_labels = all_theta.shape[0]
    # same as before, insert ones to match the shape
    X = np.insert(X, 0, values=np.ones(rows), axis=1)
    # convert to matrices
    X = np.matrix(X)
    all_theta = np.matrix(all_theta)
    # compute the probability of each class for each training instance
    h = sigmoid(X * all_theta.T)
    # find the index with the maximum probability, i.e. the index of the 1
    # (axis=0 searches along columns, axis=1 along rows)
    h_argmax = np.argmax(h, axis=1)
    # our array is zero-indexed, so add one for the true label prediction
    h_argmax = h_argmax + 1
    return h_argmax
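The argmax step can be seen on a tiny hand-made probability matrix (values invented for the example): each row holds one probability per class, and `axis=1` picks the most probable class per row.

```python
import numpy as np

# h: 2 examples (rows) x 3 classes (columns)
h = np.array([[0.1, 0.7, 0.2],
              [0.6, 0.3, 0.1]])
# argmax along axis=1 gives the column index of the winner; labels are 1-indexed
print(np.argmax(h, axis=1) + 1)  # [2 1]
```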
data = loadmat("E:\\Download\\ex3data1.mat")
print(data)
print(data['X'].shape, data['y'].shape)
rows = data['X'].shape[0]
params = data['X'].shape[1]
all_theta = np.zeros((10, params + 1))
X = np.insert(data['X'], 0, values=np.ones(rows), axis=1)
theta = np.zeros(params + 1)
y_0 = np.array([1 if label == 0 else 0 for label in data['y']])
y_0 = np.reshape(y_0, (rows, 1))
print(X.shape, y_0.shape, theta.shape, all_theta.shape)
# check how many distinct labels there are
print(np.unique(data['y']))
all_theta = one_vs_all(data['X'], data['y'], 10, 1)
print(all_theta)
y_pred = predict_all(data['X'], all_theta)
correct = [1 if a == b else 0 for (a, b) in zip(y_pred, data['y'])]
accuracy = sum(map(int, correct)) / float(len(correct))
print('accuracy = {0}%'.format(accuracy * 100))
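The final accuracy computation is just the fraction of matching (prediction, label) pairs. A standalone illustration with toy labels (not the dataset's real output):

```python
# accuracy = fraction of (prediction, label) pairs that agree
y_pred = [1, 2, 3, 3]
y_true = [1, 2, 3, 1]
correct = [1 if a == b else 0 for (a, b) in zip(y_pred, y_true)]
accuracy = sum(correct) / len(correct)
print('accuracy = {0}%'.format(accuracy * 100))  # accuracy = 75.0%
```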
