Current location: Home > Machine learning: perceptron and k-nearest neighbors
Machine learning: perceptron and k-nearest neighbors
2022-06-24 10:10:00 【Cpsu】
1. Perceptron
import numpy as np
from matplotlib import pyplot as plt
# Fix the RNG so the generated point clouds are reproducible.
np.random.seed(3)
# X-axis values shared by both classes: 0.0, 0.1, ..., 4.9 (50 points).
# (The original note suggests np.linspace(0, 5, 50) instead; that gives a
# slightly different grid, so arange is kept here to preserve the data.)
x1 = list(np.arange(0, 5, 0.1))
# Positive-class y values: absolute values of N(0, 1) samples (near 0).
x2 = np.abs(np.random.randn(50))
# Negative-class y values: absolute values of N(8, 1) samples (near 8).
x3 = np.abs(np.random.randn(50) + 8)
plt.scatter(x1, x2)
plt.scatter(x1, x3)

# Build the data matrix from the samples above: [feature 1, feature 2, label].
x_1 = x1 + x1  # feature 1: x coordinates for both classes (list concatenation)
x_2 = list(x2) + list(x3)  # feature 2: positive-class y values, then negative-class
y = [1] * 50 + [-1] * 50  # labels: +1 for the first 50 rows, -1 for the last 50
data = np.c_[x_1, x_2, y]
data.shape  # (100, 3) -- notebook-style display; a no-op when run as a script
class Perceptron():
    """Primal-form perceptron that, each epoch, updates on every misclassified sample.

    :param data: ndarray of shape (N, P); the first P-1 columns are features and
        the last column is the class label (+1 / -1 in this file).
    :param lr: learning rate.
    :param maxiter: maximum number of training epochs.
    :param w_vect: initial weight column vector of shape (P, 1); its last entry
        is the bias, which multiplies an appended column of ones.
    """

    def __init__(self, data, lr, maxiter, w_vect):
        self.data = data
        self.w = w_vect
        self.lr = lr
        self.maxiter = maxiter

    def _augmented(self, rows):
        """Return the feature columns of ``rows`` with a trailing column of ones (bias term)."""
        return np.c_[rows[:, :-1], np.ones(rows.shape[0])]

    def get_wrong(self):
        """Find the samples misclassified by the current weights.

        A sample counts as misclassified when ``y * (x . w) <= 0``, so points lying
        exactly on the hyperplane are also treated as wrong.

        :return: (misclassified rows of ``self.data``, their row indices in ``self.data``)
        """
        x = self._augmented(self.data)
        margins = x.dot(self.w) * self.data[:, -1].reshape(-1, 1)
        wrong_index = np.where(margins <= 0)[0]
        # Index self.data (not a module-level ``data``) so the instance is self-contained.
        return self.data[wrong_index, :], wrong_index

    def fit(self):
        """Train for up to ``maxiter`` epochs, stopping once nothing is misclassified.

        Each epoch takes one gradient step per misclassified sample. The gradient
        of the perceptron loss for sample (x, y) is ``-y * x``. ``self.w`` is
        rebound to the updated weights.
        """
        for _ in range(self.maxiter):
            wrong_data, _ = self.get_wrong()
            if wrong_data.shape[0] == 0:
                break
            # Shuffle whole rows so every feature vector stays paired with its own
            # label (shuffling only the features would mismatch x and y).
            wrong_data = wrong_data[np.random.permutation(wrong_data.shape[0])]
            x = self._augmented(wrong_data)
            for xi, yi in zip(x, wrong_data[:, -1]):
                gradient = (-yi * xi).reshape(-1, 1)
                self.w = self.w - self.lr * gradient
# Initial weights: two feature weights followed by the bias, all zero.
w_vect = np.array([[0], [0], [0]])
a = Perceptron(data, 0.01, 200, w_vect)
a.fit()
weights = a.w
# Unpack the learned column vector: feature weights first, bias last.
w1, w2, bias = weights[0][0], weights[1][0], weights[-1][0]
print(a.w)
# Decision boundary w1*x + w2*y + bias = 0, solved for y over the x1 grid.
x6 = -w1 / w2 * np.array(x1) - bias / w2
# Redraw both classes, then overlay the learned separating line.
for ys in (x2, x3):
    plt.scatter(x1, ys)
plt.plot(x1, x6)

2. KNN (k-nearest neighbors)
# establish kd Trees
import numpy as np
import matplotlib.pyplot as plt
class kdTree():
    """One node of a 2-D kd-tree; the root is the node built with ``parent_node=None``."""

    def __init__(self, parent_node):
        # Node initialization
        self.nodedata = None  # data point stored at this node (one row of the input array)
        self.split = None  # split axis: 0 = split along x, 1 = split along y
        self.range = None  # split threshold: the median value on the split axis
        self.left = None  # left child node
        self.right = None  # right child node
        self.parent = parent_node  # parent node (None for the root)
        self.leftdata = None  # all points routed to the left subtree
        self.rightdata = None  # all points routed to the right subtree
        self.isinvted = False  # "visited" flag; not read by the methods shown here

    def print(self):
        """Print this node's point, split axis and split threshold."""
        print(self.nodedata, self.split, self.range)

    def getSplitAxis(self, all_data):
        """Return the axis (0 or 1) with the larger variance; ties choose axis 1.

        Only columns 0 and 1 are compared, so the data is assumed to be 2-D.
        """
        var_all_data = np.var(all_data, axis=0)
        if var_all_data[0] > var_all_data[1]:
            return 0
        else:
            return 1

    def getRange(self, split_axis, all_data):
        """Return the median of ``all_data`` along ``split_axis``.

        This is the element at index ``len // 2`` of the sorted column, i.e. the
        upper median when the count is even.
        """
        split_all_data = all_data[:, split_axis]
        med_index = int(split_all_data.shape[0] / 2)
        return np.sort(split_all_data)[med_index]

    def getNodeLeftRigthData(self, all_data):
        """Partition ``all_data`` around ``self.split`` / ``self.range``.

        Points strictly below the threshold go left. The first point equal to the
        threshold becomes this node's ``nodedata``. Every other point (including
        later duplicates of the threshold) goes right. The results are stored in
        ``self.leftdata`` and ``self.rightdata``.
        """
        ls_leftdata = []
        ls_rightdata = []
        for now_data in all_data:
            value = now_data[self.split]
            if value < self.range:
                ls_leftdata.append(now_data)
            # Use `is None`, not `== None`: once nodedata holds an ndarray, `== None`
            # is element-wise and its truth value raises ValueError whenever several
            # points share the median value.
            elif value == self.range and self.nodedata is None:
                self.nodedata = now_data
            else:
                ls_rightdata.append(now_data)
        self.leftdata = np.array(ls_leftdata)
        self.rightdata = np.array(ls_rightdata)

    def createNextNode(self, all_data):
        """Recursively build the subtree rooted at this node from the rows of ``all_data``.

        If ``all_data`` is empty, prints a message and builds nothing. Always returns None.
        """
        if all_data.shape[0] == 0:
            print("create kd tree finished!")
            return None
        self.split = self.getSplitAxis(all_data)
        self.range = self.getRange(self.split, all_data)
        self.getNodeLeftRigthData(all_data)
        if self.leftdata.shape[0] != 0:
            self.left = kdTree(self)
            self.left.createNextNode(self.leftdata)
        if self.rightdata.shape[0] != 0:
            self.right = kdTree(self)
            self.right.createNextNode(self.rightdata)

    def plotKdTree(self):
        """Draw the tree recursively: each parent-child edge is a line plus a half-length arrow.

        The root opens a new figure with axes hard-coded to [0, 10] x [0, 10], so
        the points are assumed to lie in that square. Each node picks a random
        color for its outgoing edges.
        """
        if self.parent is None:
            plt.figure(dpi=300)
            plt.xlim([0.0, 10.0])
            plt.ylim([0.0, 10.0])
        color = np.random.random(3)
        for child in (self.left, self.right):
            if child is None:
                continue
            plt.plot([self.nodedata[0], child.nodedata[0]], [self.nodedata[1], child.nodedata[1]], '-o', color=color)
            plt.arrow(x=self.nodedata[0], y=self.nodedata[1], dx=(child.nodedata[0] - self.nodedata[0]) / 2.0, dy=(child.nodedata[1] - self.nodedata[1]) / 2.0, color=color, head_width=0.2)
            child.plotKdTree()
        # if self.split == 0:
        #     x = self.range
        #     plt.vlines(x, 0, 10, color=color, linestyles='--')
        # else:
        #     y = self.range
        #     plt.hlines(y, 0, 10, color=color, linestyles='--')
# Demo: build a kd-tree over 30 random 2-D points in [0, 10) x [0, 10) and draw it.
test_array = np.random.random([30, 2]) * 10.0
my_kd_tree = kdTree(None)
my_kd_tree.createNextNode(test_array)
my_kd_tree.plotKdTree()
边栏推荐
- El table Click to add row style
- 有关二叉树 的基本操作
- [db2] sql0805n solution and thinking
- Yolov6: the fast and accurate target detection framework is open source
- 时尚的弹出模态登录注册窗口
- Thinkphp5 multi language switching project practice
- Endgame P.O.O
- Arbre binaire partie 1
- 学习使用phpstripslashe函数去除反斜杠
- 411 stack and queue (20. valid parentheses, 1047. delete all adjacent duplicates in the string, 150. inverse Polish expression evaluation, 239. sliding window maximum, 347. the first k high-frequency
猜你喜欢

大中型企业如何构建自己的监控体系

JS singleton mode

How large and medium-sized enterprises build their own monitoring system

SVG+js拖拽滑块圆形进度条

Queue queue

numpy.linspace()

Tutorial (5.0) 08 Fortinet security architecture integration and fortixdr * fortiedr * Fortinet network security expert NSE 5

NVIDIA's CVPR 2022 oral is on fire! 2D images become realistic 3D objects in seconds! Here comes the virtual jazz band!

p5.js千纸鹤动画背景js特效

Impdp leading schema message ora-31625 exception handling
随机推荐
PostgreSQL DBA quick start - source compilation and installation
有关二叉树 的基本操作
JS singleton mode
被困英西中学的师生安全和食物有保障
记录一下MySql update会锁定哪些范围的数据
微信小程序rich-text图片宽高自适应的方法介绍(rich-text富文本)
How to solve multi-channel customer communication problems in independent stations? This cross-border e-commerce plug-in must be known!
买的长期理财产品,可以转短吗?
JS proxy mode
SQL sever基本数据类型详解
How to standardize data center infrastructure management process
Is there a reliable and low commission futures account opening channel in China? Is it safe to open an account online?
Arbre binaire partie 1
CVPR 2022 Oral | 英伟达提出自适应token的高效视觉Transformer网络A-ViT,不重要的token可以提前停止计算
p5.js千纸鹤动画背景js特效
Impdp leading schema message ora-31625 exception handling
涂鸦智能携多款重磅智能照明解决方案,亮相2022美国国际照明展
上升的气泡canvas破碎动画js特效
413-二叉树基础
Tutorial (5.0) 08 Fortinet security architecture integration and fortixdr * fortiedr * Fortinet network security expert NSE 5