当前位置:网站首页>Pytorch的旅程一:线性模型
Pytorch的旅程一:线性模型
2022-07-24 18:15:00 【Kang|King】
代码解析
内层循环:核心计算内容:
从数据集中,按数据对儿取出自变量x_val和真实值y_val;先调用forward函数,计算预测值 w*x(y_hat);调用loss函数,计算单个数据的损失数值;累加损失,并记下来(此处要提前初始化一个值为0的变量,后面才能不报错);随意打印想要看到的内容,一般是打印x_val、y_val、loss_val;在外层循环中(也就是每一个数据对儿计算的时候),都要把计算的结果,放进之前的空列表,用于绘图。
全部代码
import numpy as np
import matplotlib.pyplot as plt
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]
def forward(x):
return x * w
# loss function 是 均方根误差 loss = (y_hat - y) ** 2
def loss(x, y):
y_pred = forward(x)
return (y_pred - y) * (y_pred - y)
w_list = []
mse_list = []
for w in np.arange(0.0, 4.0, 0.1):
print('w=', w)
l_sum = 0
for x_val, y_val in zip(x_data, y_data):
# A zip object yielding tuples until an input is exhausted;
y_pred_val = forward(x_val)
loss_val = loss(x_val, y_val)
# 传入的是x_val,但是经过loss中的forward计算后,已经是y_hat(估计值)了;
l_sum += loss_val
print('\t', x_val, y_val, y_pred_val, loss_val)
print('MSE=', l_sum / len(x_data))
# 求一下 损失的均值
w_list.append(w)
mse_list.append(l_sum / len(x_data))
plt.plot(w_list, mse_list)
plt.ylabel('Loss')
plt.xlabel('w')
plt.show()
边栏推荐
- New can also create objects. Why do you need factory mode?
- Has polardb for PostgreSQL entered the list of Xinchuang database?
- Alibaba /166 obtains the API instructions for all products in the store
- Custom web framework
- Go language interface and type
- Blackmagic Fusion Studio 18
- 0611~自习课
- Go language file operation
- 【刷题记录】20. 有效的括号
- 0612~quartz定时器框架
猜你喜欢

Mozilla foundation released 2022 Internet health report: AI will contribute 15.7 trillion yuan to the global economy in 2030, and the investment in AI in the United States last year was nearly three t

Inheritance and Derive

Brats18 - Multimodal MR image brain tumor segmentation challenge continued

T245982 "kdoi-01" drunken flower Yin

【OpenCV】—阈值化
![[opencv] - thresholding](/img/4e/88c8c8063de7cb10e44e76e77dbb8e.png)
[opencv] - thresholding

Number of times a number appears in an ascending array

pycharm配置opencv库

Laravel notes - RSA encryption of user login password (improve system security)

SSM framework learning
随机推荐
0701~ holiday summary
Goodbye Navicat! This open source database management tool has a cooler interface!
Shengxin commonly used analysis graphics drawing 02 -- unlock the essence of volcano map!
Alibaba /166 obtains the API instructions for all products in the store
How does win11 enhance the microphone? Win11 enhanced microphone settings
Go to bed capacity exchange
Go language interface and type
[OBS] cooperation between video and audio coding and RTMP transmission
0616项目二结束~~总总结
【OpenCV】—阈值化
Use of jumpserver
0623~ holiday self study
How to render millions of 2D objects smoothly with webgpu?
In depth analysis of the famous Alibaba cloud log4j vulnerability
Stream, file, IO
运维小白成长记——架构第8周
字符串常用方法(2)
头文件是必须的吗?跟一跟编译过程~~~
排序的几种方式for while 还有sort
JMeter -- silent operation