当前位置:网站首页>Machine learning perceptron and k-nearest neighbor
Machine learning: the perceptron and k-nearest neighbors
2022-06-24 10:10:00 【Cpsu】
1. Perceptron
import numpy as np
from matplotlib import pyplot as plt
np.random.seed(3)
# Number of samples drawn for each of the two classes; every array below is
# sized from this one constant so features and labels cannot get out of sync.
n_per_class = 50
# X-axis values 0.0, 0.1, ..., shared by both classes. This is the same grid as
# the original np.arange(0, 5, 0.1) for 50 points, but always has exactly
# n_per_class entries (a float-step arange can gain an extra point from rounding).
x1 = list(np.arange(n_per_class) * 0.1)
# Positive-class second coordinate: |N(0, 1)| samples
x2 = np.abs(np.random.randn(n_per_class))
# Negative-class second coordinate: |N(0, 1) + 8| samples
x3 = np.abs(np.random.randn(n_per_class) + 8)
plt.scatter(x1, x2)
plt.scatter(x1, x3)

# Build the data matrix with columns [feature 1, feature 2, label].
# Feature 1: the x positions, once for each class.
x_1 = x1 + x1
# Feature 2: positive-class values followed by negative-class values.
x_2 = list(x2) + list(x3)
# Label column: +1 for the positive class, -1 for the negative class.
y = [1] * n_per_class + [-1] * n_per_class
data = np.c_[x_1, x_2, y]
# data.shape == (2 * n_per_class, 3), i.e. (100, 3)
class Perceptron():
    """Primal-form perceptron trained on the misclassified rows of each epoch.

    :param data: ndarray of shape (N, P). The first P-1 columns are features
        and the last column is the label (+1 / -1).
    :param lr: learning rate.
    :param maxiter: maximum number of epochs.
    :param w_vect: initial weight column vector of shape (P, 1). A constant-1
        column is appended to the features, so its last entry acts as the bias.
    """

    def __init__(self, data, lr, maxiter, w_vect):
        self.data = data
        self.w = w_vect
        self.lr = lr
        self.maxiter = maxiter

    def get_wrong(self):
        """Return the misclassified rows of ``self.data`` and their row indices.

        A row counts as misclassified when ``y * (w . [x, 1]) <= 0``, so points
        lying exactly on the hyperplane are also treated as wrong.

        :return: (misclassified rows of ``self.data``, their indices in ``self.data``)
        """
        x = np.c_[self.data[:, :-1], np.ones(self.data.shape[0])]
        labels = self.data[:, -1].reshape(-1, 1)
        wrong_index = np.where(x.dot(self.w) * labels <= 0)[0]
        # Index this instance's data, not a module-level ``data`` global.
        return self.data[wrong_index, :], wrong_index

    def fit(self):
        """Train for up to ``maxiter`` epochs, stopping once nothing is misclassified.

        Each epoch collects the rows misclassified by the current weights and
        applies ``w <- w + lr * y_i * [x_i, 1]`` once for every one of them.
        """
        for _ in range(self.maxiter):
            wrong_data, _ = self.get_wrong()
            n_wrong = wrong_data.shape[0]
            if n_wrong == 0:
                break
            x = np.c_[wrong_data[:, :-1], np.ones(n_wrong)]
            y = wrong_data[:, -1:]
            # Permute features and labels together so every x_i stays paired
            # with its own y_i (shuffling x alone mismatched them). The updates
            # of one epoch are simply summed into w, so the visiting order only
            # affects float rounding; the permutation keeps the global RNG
            # advancing once per epoch as before.
            order = np.random.permutation(n_wrong)
            x, y = x[order], y[order]
            for i in range(n_wrong):
                gradient = (-y[i] * x[i, :]).reshape(-1, 1)
                self.w = self.w - self.lr * gradient
# Train from the all-zero weight column vector [w1, w2, bias]^T.
w_vect = np.array([[0], [0], [0]])
a = Perceptron(data, 0.01, 200, w_vect)
a.fit()

# Unpack the learned weights; the last entry is the bias term.
weights = a.w
w1, w2, bias = weights[0][0], weights[1][0], weights[-1][0]
print(a.w)

# Decision boundary w1*x + w2*y + bias = 0, solved for y at each x in x1.
x6 = -w1 / w2 * np.array(x1) - bias / w2
plt.scatter(x1, x2)
plt.scatter(x1, x3)
plt.plot(x1, x6)

2. KNN (k-nearest neighbors)
# Build a kd-tree
import numpy as np
import matplotlib.pyplot as plt
class kdTree():
    """Node of a k-d tree; the root is created with ``parent_node=None``.

    Build a tree with ``kdTree(None).createNextNode(points)``, where ``points``
    is an (N, D) ndarray. Each node stores one point, chosen at the median of the
    axis with the largest variance. The remaining points go to the left subtree
    (coordinate strictly smaller than the median) or the right subtree (all
    others).
    """

    def __init__(self, parent_node):
        self.nodedata = None       # point stored at this node
        self.split = None          # split-axis index (0 = x axis, 1 = y axis, ...)
        self.range = None          # split threshold: median coordinate on ``split``
        self.left = None           # left child node
        self.right = None          # right child node
        self.parent = parent_node  # parent node (None for the root)
        self.leftdata = None       # all points assigned to the left subtree
        self.rightdata = None      # all points assigned to the right subtree
        self.isinvted = False      # visited flag (attribute name kept as-is for callers)

    def print(self):
        """Print this node's point, split axis and split threshold."""
        print(self.nodedata, self.split, self.range)

    def getSplitAxis(self, all_data):
        """Return the index of the axis along which ``all_data`` has the largest variance.

        Ties go to the highest-index axis. For 2-D data this matches the rule
        ``0 if var[0] > var[1] else 1``.
        """
        var_all_data = np.var(all_data, axis=0)
        # argmax over the reversed variances finds the *last* maximum.
        return int(len(var_all_data) - 1 - np.argmax(var_all_data[::-1]))

    def getRange(self, split_axis, all_data):
        """Return the median value on ``split_axis`` (upper median for even counts)."""
        split_all_data = all_data[:, split_axis]
        med_index = split_all_data.shape[0] // 2
        return np.sort(split_all_data)[med_index]

    def getNodeLeftRigthData(self, all_data):
        """Choose this node's point and partition the remaining points.

        Points whose coordinate is below ``self.range`` go left. The first point
        equal to ``self.range`` becomes ``self.nodedata``. Every other point,
        including later ties, goes right.
        """
        ls_leftdata = []
        ls_rightdata = []
        for now_data in all_data:
            value = now_data[self.split]
            if value < self.range:
                ls_leftdata.append(now_data)
            elif value == self.range and self.nodedata is None:
                # ``is None``: once nodedata holds an ndarray, ``== None`` is
                # element-wise and raised on any further tie with the median.
                self.nodedata = now_data
            else:
                ls_rightdata.append(now_data)
        self.leftdata = np.array(ls_leftdata)
        self.rightdata = np.array(ls_rightdata)

    def createNextNode(self, all_data):
        """Recursively build the subtree rooted at this node from ``all_data``.

        If ``all_data`` is empty, prints a message and returns None without
        building anything.
        """
        if all_data.shape[0] == 0:
            print("create kd tree finished!")
            return None
        self.split = self.getSplitAxis(all_data)
        self.range = self.getRange(self.split, all_data)
        self.getNodeLeftRigthData(all_data)
        if self.leftdata.shape[0] != 0:
            self.left = kdTree(self)
            self.left.createNextNode(self.leftdata)
        if self.rightdata.shape[0] != 0:
            self.right = kdTree(self)
            self.right.createNextNode(self.rightdata)

    def plotKdTree(self):
        """Draw every parent->child edge of this subtree, each with a half-length arrow.

        The root call opens a new 300-dpi figure with both axes limited to
        [0, 10]. Each node draws its child edges in one random colour and then
        recurses, left child first. Only the first two coordinates are plotted.
        """
        if self.parent is None:
            plt.figure(dpi=300)
            plt.xlim([0.0, 10.0])
            plt.ylim([0.0, 10.0])
        color = np.random.random(3)
        for child in (self.left, self.right):
            if child is None:
                continue
            px, py = self.nodedata[0], self.nodedata[1]
            cx, cy = child.nodedata[0], child.nodedata[1]
            plt.plot([px, cx], [py, cy], '-o', color=color)
            plt.arrow(x=px, y=py, dx=(cx - px) / 2.0, dy=(cy - py) / 2.0,
                      color=color, head_width=0.2)
            child.plotKdTree()
# Demo: 30 random 2-D points (uniform samples scaled by 10), built into a
# kd-tree and drawn.
test_array = np.random.random([30, 2]) * 10.0
my_kd_tree = kdTree(None)
my_kd_tree.createNextNode(test_array)
my_kd_tree.plotKdTree()
边栏推荐
- Can the long-term financial products you buy be shortened?
- 机器学习——感知机及K近邻
- 一群骷髅在飞canvas动画js特效
- Three ways to use applicationcontextinitializer
- 416 binary tree (first, middle and last order traversal iteration method)
- Record the range of data that MySQL update will lock
- Queue queue
- vim的使用
- 学习整理在php中使用KindEditor富文本编辑器
- 被困英西中学的师生安全和食物有保障
猜你喜欢
随机推荐
Floating point notation (summarized from cs61c and CMU CSAPP)
读取csv(tsv)文件出错
物联网?快来看 Arduino 上云啦
np.float32()
Nvisual digital infrastructure operation management software platform
为什么 JSX 语法这么香?
小程序 rich-text中图片点击放大与自适应大小问题
JS proxy mode
美国电子烟巨头 Juul 遭遇灭顶之灾,所有产品强制下架
静态链接库和动态链接库的区别
How large and medium-sized enterprises build their own monitoring system
SQL sever基本数据类型详解
[input method] so far, there are so many Chinese character input methods!
NVIDIA's CVPR 2022 oral is on fire! 2D images become realistic 3D objects in seconds! Here comes the virtual jazz band!
linux下oracle服务器打开允许远程连接
MySQL data advanced
JS singleton mode
js代理模式
Groovy obtains Jenkins credentials through withcredentials
整理接口性能优化技巧,干掉慢代码









