当前位置:网站首页>Machine learning perceptron and k-nearest neighbor
Machine learning perceptron and k-nearest neighbor
2022-06-24 10:10:00 【Cpsu】
1. Perceptron
import numpy as np
from matplotlib import pyplot as plt
# Fix the seed so the generated samples are reproducible.
np.random.seed(3)
# x-axis values shared by both classes. x1 stays a plain list because the
# feature column below is built by list concatenation (x1 + x1), not by
# element-wise addition. np.linspace(0, 5, 50) would be the more robust way
# to get exactly 50 points, but it includes the endpoint and changes the
# spacing, so np.arange is kept.
x1 = list(np.arange(0, 5, 0.1))
n_per_class = len(x1)
# Positive samples: |N(0, 1)|
x2 = np.abs(np.random.randn(n_per_class))
# Negative samples: |N(0, 1) + 8|
x3 = np.abs(np.random.randn(n_per_class) + 8)
plt.scatter(x1, x2)
plt.scatter(x1, x3)

# Build the data matrix from the samples above.
# Feature 1: the x values, once per class.
x_1 = x1 + x1
# Feature 2: positive samples followed by negative samples.
x_2 = list(x2) + list(x3)
# Label column: +1 for the positive class, -1 for the negative class.
y = [1] * n_per_class + [-1] * n_per_class
# Columns are [feature 1, feature 2, label]; shape is (2 * n_per_class, 3).
data = np.c_[x_1, x_2, y]
class Perceptron():
    """Linear perceptron updated from the samples misclassified in each pass.

    :param data: ndarray of shape (N, P + 1); the first P columns are the
        features and the last column is the label (+1 or -1).
    :param lr: learning rate.
    :param maxiter: maximum number of passes over the misclassified samples.
    :param w_vect: initial weights, P + 1 values (feature weights followed by
        the bias). Stored in ``self.w`` as a float column vector of shape
        (P + 1, 1).
    """

    def __init__(self, data, lr, maxiter, w_vect):
        self.data = np.asarray(data)
        # A column vector keeps x.dot(w) at shape (N, 1); a 1-D weight vector
        # would broadcast against the (N, 1) label column into an (N, N)
        # matrix and corrupt the misclassification test.
        self.w = np.asarray(w_vect, dtype=float).reshape(-1, 1)
        self.lr = lr
        self.maxiter = maxiter

    def get_wrong(self):
        """Return the misclassified rows and their indices in ``self.data``.

        A sample counts as misclassified when ``label * (w . [x, 1]) <= 0``,
        so points lying exactly on the hyperplane are included.

        :return: tuple ``(rows, indices)``; ``rows`` is a copy of the
            misclassified rows of ``self.data``.
        """
        # Append a constant 1 column so the last weight acts as the bias.
        x = np.c_[self.data[:, :-1], np.ones(self.data.shape[0])]
        labels = self.data[:, -1].reshape(-1, 1)
        wrong_index = np.where(x.dot(self.w) * labels <= 0)[0]
        # Index the instance's own data, not a module-level variable.
        return self.data[wrong_index, :], wrong_index

    def fit(self):
        """Update ``self.w`` until no sample is misclassified or ``maxiter`` is hit.

        Each pass collects the currently misclassified samples and applies
        one update ``w <- w + lr * label * [x, 1]`` per sample.
        """
        for _ in range(self.maxiter):
            wrong_data, _ = self.get_wrong()
            if wrong_data.shape[0] == 0:
                break
            # Shuffle whole rows (features and label together) so every
            # feature vector stays paired with its own label. wrong_data is
            # a copy, so self.data is not reordered.
            np.random.shuffle(wrong_data)
            x = np.c_[wrong_data[:, :-1], np.ones(wrong_data.shape[0])]
            for features, label in zip(x, wrong_data[:, -1]):
                gradient = (-label * features).reshape(-1, 1)
                self.w = self.w - self.lr * gradient
# Start from all-zero integer weights: two feature weights plus the bias.
w_vect = np.zeros((3, 1), dtype=int)
a = Perceptron(data, 0.01, 200, w_vect)
a.fit()
weights = a.w
# Unpack the learned column vector into scalars.
w1, w2, bias = weights[0][0], weights[1][0], weights[-1][0]
print(a.w)
# Decision boundary w1*x + w2*y + bias = 0, solved for y.
x6 = (-w1 / w2) * np.array(x1) - bias / w2
# Redraw both classes and overlay the separating line.
for samples in (x2, x3):
    plt.scatter(x1, samples)
plt.plot(x1, x6)

2. KNN
# Build a k-d tree
import numpy as np
import matplotlib.pyplot as plt
class kdTree():
    """Node of a k-d tree over two-dimensional points.

    Each node stores one data point, splits the remaining points on the axis
    with the larger variance at the median value, and hands the two halves
    to its left and right child nodes.
    """

    def __init__(self, parent_node):
        """Create an empty node whose parent is ``parent_node`` (None for the root)."""
        self.nodedata = None  # Data point stored at this node (two-dimensional)
        self.split = None  # Split axis: 0 splits along x, 1 splits along y
        self.range = None  # Split threshold on the split axis
        self.left = None  # Left child node
        self.right = None  # Right child node
        self.parent = parent_node  # Parent node
        self.leftdata = None  # All points assigned to the left subtree
        self.rightdata = None  # All points assigned to the right subtree
        self.isinvted = False  # Visited flag; not read or set elsewhere in this class

    def print(self):
        """Print this node's data point, split axis and split threshold."""
        print(self.nodedata, self.split, self.range)

    def getSplitAxis(self, all_data):
        """Return the axis (0 or 1) with the larger variance; ties give 1."""
        var_all_data = np.var(all_data, axis=0)
        if var_all_data[0] > var_all_data[1]:
            return 0
        return 1

    def getRange(self, split_axis, all_data):
        """Return the median value of ``all_data`` on ``split_axis``.

        For an even number of points the upper of the two middle values is
        used (index ``n // 2`` of the sorted column).
        """
        sort_split_all_data = np.sort(all_data[:, split_axis])
        med_index = sort_split_all_data.shape[0] // 2
        return sort_split_all_data[med_index]

    def getNodeLeftRigthData(self, all_data):
        """Split ``all_data`` into this node's point, left data and right data.

        Points below the threshold go left. The first point that equals the
        threshold becomes this node's point; every other point, including
        further ties on the threshold, goes right.
        """
        ls_leftdata = []
        ls_rightdata = []
        for now_data in all_data:
            if now_data[self.split] < self.range:
                ls_leftdata.append(now_data)
            # Identity check: once nodedata holds an array, `== None` would
            # compare elementwise and make this condition raise ValueError.
            elif now_data[self.split] == self.range and self.nodedata is None:
                self.nodedata = now_data
            else:
                ls_rightdata.append(now_data)
        self.leftdata = np.array(ls_leftdata)
        self.rightdata = np.array(ls_rightdata)

    def createNextNode(self, all_data):
        """Recursively build the subtree rooted at this node from ``all_data``.

        With no data the node is left empty and a message is printed.
        """
        if all_data.shape[0] == 0:
            print("create kd tree finished!")
            return None
        self.split = self.getSplitAxis(all_data)
        self.range = self.getRange(self.split, all_data)
        self.getNodeLeftRigthData(all_data)
        if self.leftdata.shape[0] != 0:
            self.left = kdTree(self)
            self.left.createNextNode(self.leftdata)
        if self.rightdata.shape[0] != 0:
            self.right = kdTree(self)
            self.right.createNextNode(self.rightdata)

    def plotKdTree(self):
        """Recursively draw the tree: a line and an arrow from each node to each child.

        The root call creates the figure and fixes both axes to [0, 10].
        Split lines are not drawn.
        """
        if self.parent is None:
            plt.figure(dpi=300)
            plt.xlim([0.0, 10.0])
            plt.ylim([0.0, 10.0])
        # One random colour per node, shared by the edges to both children.
        color = np.random.random(3)
        for child in (self.left, self.right):
            if child is None:
                continue
            plt.plot([self.nodedata[0], child.nodedata[0]],
                     [self.nodedata[1], child.nodedata[1]], '-o', color=color)
            # The arrow covers half of the edge to show the parent -> child direction.
            plt.arrow(x=self.nodedata[0], y=self.nodedata[1],
                      dx=(child.nodedata[0] - self.nodedata[0]) / 2.0,
                      dy=(child.nodedata[1] - self.nodedata[1]) / 2.0,
                      color=color, head_width=0.2)
            child.plotKdTree()
# 30 random 2-D points drawn uniformly from [0, 10) on each axis.
test_array = np.random.random([30, 2]) * 10.0
# Build the tree from a parentless root node, then draw it.
my_kd_tree = kdTree(None)
my_kd_tree.createNextNode(test_array)
my_kd_tree.plotKdTree()
边栏推荐
猜你喜欢

植物生长h5动画js特效

SQL Server AVG函数取整问题

Arbre binaire partie 1

Cicflowmeter source code analysis and modification to meet requirements

Basic operations on binary tree

canvas管道动画js特效

Use of vim

利用pandas读取SQL Sever数据表

Queue queue

How to solve multi-channel customer communication problems in independent stations? This cross-border e-commerce plug-in must be known!
随机推荐
Groovy obtains Jenkins credentials through withcredentials
Go language development environment setup +goland configuration under the latest Windows
oracle池式连接请求超时问题排查步骤
416 binary tree (first, middle and last order traversal iteration method)
Floating point notation (summarized from cs61c and CMU CSAPP)
解决Deprecated: Methods with the same name as their class will not be constructors in报错方案
操作符详解
Geogebra instance clock
p5.js千纸鹤动画背景js特效
微信小程序學習之 實現列錶渲染和條件渲染.
Troubleshooting steps for Oracle pool connection request timeout
100 GIS practical application cases (XIV) -arcgis attribute connection and using Excel
植物生长h5动画js特效
小程序 rich-text中图片点击放大与自适应大小问题
414-二叉树的递归遍历
tf.errors
Wechat applet learning to achieve list rendering and conditional rendering
Engine localization adaptation & Reconstruction notes
Is there a reliable and low commission futures account opening channel in China? Is it safe to open an account online?
el-table表格的拖拽 sortablejs