当前位置:网站首页>Pytorch的旅程一:线性模型
Pytorch的旅程一:线性模型
2022-07-24 18:15:00 【Kang|King】
代码解析
内层循环:核心计算内容:
从数据集中,按数据对儿取出自变量x_val和真实值y_val;先调用forward函数,计算预测值 w*x(y_hat);调用loss函数,计算单个数据的损失数值;累加损失,并记下来(此处要提前初始化一个值为0的变量,后面才能不报错);随意打印想要看到的内容,一般是打印x_val、y_val、loss_val;在外层循环中(也就是每一个数据对儿计算的时候),都要把计算的结果,放进之前的空列表,用于绘图。
全部代码
import numpy as np
import matplotlib.pyplot as plt
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]
def forward(x):
return x * w
# loss function 是 均方根误差 loss = (y_hat - y) ** 2
def loss(x, y):
y_pred = forward(x)
return (y_pred - y) * (y_pred - y)
w_list = []
mse_list = []
for w in np.arange(0.0, 4.0, 0.1):
print('w=', w)
l_sum = 0
for x_val, y_val in zip(x_data, y_data):
# A zip object yielding tuples until an input is exhausted;
y_pred_val = forward(x_val)
loss_val = loss(x_val, y_val)
# 传入的是x_val,但是经过loss中的forward计算后,已经是y_hat(估计值)了;
l_sum += loss_val
print('\t', x_val, y_val, y_pred_val, loss_val)
print('MSE=', l_sum / len(x_data))
# 求一下 损失的均值
w_list.append(w)
mse_list.append(l_sum / len(x_data))
plt.plot(w_list, mse_list)
plt.ylabel('Loss')
plt.xlabel('w')
plt.show()
边栏推荐
- undefined reference to H5PTopen
- How to read "STL source code analysis"?
- Common methods of array (2)
- Pay close attention! List of the latest agenda of 2022 open atom open source Summit
- IO multiplexing
- Ship new idea 2022.2 was officially released, and the new features are really fragrant!
- 6126. 设计食物评分系统
- jmeter --静默运行
- 运维小白成长记——架构第8周
- 【“码”力全开,“章”显实力】2022年第1季Task挑战赛贡献者榜单
猜你喜欢

0625~<config>-<bus>

Shanghai Jiaotong University team used joint deep learning to optimize metabonomics research

Bib | mol2context vec: context aware deep network model learning molecular representation for drug discovery

Handwritten blog platform ~ the next day

Brats18 - Multimodal MR image brain tumor segmentation challenge continued
Go to bed capacity exchange

Codeforces Round #794 (Div. 2)(A.B.C)

Interview assault 66: what is the difference between request forwarding and request redirection?

【刷题记录】20. 有效的括号

Mozilla foundation released 2022 Internet health report: AI will contribute 15.7 trillion yuan to the global economy in 2030, and the investment in AI in the United States last year was nearly three t
随机推荐
【OpenCV】—阈值化
Interview assault 66: what is the difference between request forwarding and request redirection?
In depth analysis of the famous Alibaba cloud log4j vulnerability
【校验】只能输入数字(正负数)
2022 the latest short video de watermarking analysis API interface sharing
Ship new idea 2022.2 was officially released, and the new features are really fragrant!
odoo中的bom理解
0627~ holiday knowledge summary
Go to bed capacity exchange
Pay close attention! List of the latest agenda of 2022 open atom open source Summit
再见收费的Navicat!这款开源的数据库管理工具界面更炫酷!
【obs】依赖库: x264 vs 构建
[verification] only numbers (positive and negative numbers) can be entered
BOM understanding in odoo
JumpServer的使用
What are the pitfalls from single architecture to distributed architecture?
《STL源码剖析》应该怎样读?
The ability to detect movement in vivo and build a safe and reliable payment level "face brushing" experience
Is header file required? Follow the compilation process~~~
Go language file operation