[PyTorch] Calculating the derivative of sin(x) with automatic differentiation
2022-06-23 04:37:00 【Sin (Hao)】
1. Problem analysis
1.1 Problem description
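The problem is actually very simple; this post is just a memo. What we need is the following: plot f(x) = sin(x) together with its derivative, computing the derivative with automatic differentiation rather than from the known formula f'(x) = cos(x).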

Image source: Li Mu, "Dive into Deep Learning" (动手学深度学习), PyTorch edition
1.2 Solution
Using PyTorch's automatic differentiation mechanism, this is what I wrote at first:
import torch
x = torch.arange(-5, 5, 0.01, requires_grad=True)
y = torch.sin(x)
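# Fails: y is a vector, and backward() without a gradient argument raises
# "RuntimeError: grad can be implicitly created only for scalar outputs"
y.backward()
print("The derivative of f(x)=sin(x) is:{}".format(x.grad))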
It raised an error:


The reason is that y is a vector, but y.backward() can only be called without the gradient argument when y is a scalar. When x and y are both vectors, the derivative of y with respect to x is a matrix (the Jacobian matrix); I covered this in another post ("Vector-to-vector derivatives and the Jacobian matrix").
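To make the Jacobian point concrete, here is a small sketch (an addition, not part of the original post) that builds the full Jacobian with torch.autograd.functional.jacobian. The five sample points are arbitrary assumptions chosen only to keep the matrix small. Because sin is applied element-wise, the Jacobian should come out diagonal, with cos(x) on the diagonal.

```python
import torch
from torch.autograd.functional import jacobian

# A few arbitrary sample points (assumption) so the full Jacobian stays small
x = torch.linspace(-1.0, 1.0, steps=5)

# Full Jacobian of y = sin(x) with respect to x, shape (5, 5)
J = jacobian(torch.sin, x)
print(J)

# sin is applied element-wise, so y_i depends only on x_i:
# J is diagonal, with cos(x_i) on the diagonal
print(torch.allclose(J, torch.diag(torch.cos(x))))  # True

# Summing y before calling backward() amounts to multiplying J from the
# left by a vector of ones, i.e. adding up the rows of J
print(torch.allclose(torch.ones(5) @ J, torch.cos(x)))  # True
```

The last check is why the sum trick described next works: adding up the rows of a diagonal Jacobian simply collects its diagonal, which is the element-wise derivative.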
The fix is simple. The essential idea is to reduce the vector y to a scalar and then call backward() on that scalar.
How do we do the reduction? We just sum all the elements of y. The derivative of a scalar with respect to a vector is a vector. For a vector x = [x1, x2, …, xn], we have y = [sin(x1), sin(x2), …, sin(xn)]. Summing the elements of y gives the scalar y1 = sin(x1) + sin(x2) + … + sin(xn). Because each xi appears in only one term of that sum, differentiating y1 with respect to x gives the vector dy1/dx = [cos(x1), cos(x2), …, cos(xn)], which is exactly the element-wise derivative of sin(x). The code is as follows:
import torch
x = torch.arange(-5, 5, 0.1, requires_grad=True)
y = torch.sin(x)
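# Sum the elements of y to get a scalar
y1 = y.sum()
# Backpropagate from the scalar; afterwards x.grad[i] = d(y1)/d(x_i) = cos(x_i)
y1.backward()
print("The derivative of f(x)=sin(x) is:\n{}".format(x.grad))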
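As noted above, backward() can also handle a non-scalar output if you pass the gradient argument yourself. A minimal sketch of that alternative, assuming a tensor of ones shaped like y as the upstream gradient: backward(gradient=v) computes the vector-Jacobian product, so with v all ones it gives the same result as y.sum().backward().

```python
import torch

x = torch.arange(-5, 5, 0.1, requires_grad=True)
y = torch.sin(x)

# Pass the upstream gradient explicitly instead of reducing y to a scalar.
# With a vector of ones, the vector-Jacobian product equals the gradient
# of y.sum(), i.e. cos(x) element-wise.
y.backward(gradient=torch.ones_like(y))

print(torch.allclose(x.grad, torch.cos(x.detach())))  # True
```

Both approaches do the same work; summing first is slightly more explicit about what is being differentiated, while the gradient argument avoids building the extra sum node.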

2. Code implementation
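import numpy as np
import torch
import matplotlib.pyplot as plt

def get_derivative_tensor(x):
    # Wrap the NumPy array in a tensor that tracks gradients
    x = torch.tensor(x, requires_grad=True)
    y1 = torch.sin(x)
    # Reduce to a scalar, then backpropagate; x.grad now holds cos(x)
    y1.sum().backward()
    return x.grad

x = np.arange(-5, 5, 0.01)
y = np.sin(x)
plt.plot(x, y, label='sin(x)')
# Derivative computed by autograd, not from the formula cos(x)
dx = get_derivative_tensor(x)
plt.plot(x, dx.numpy(), label='dx')
plt.xlabel('x')
plt.ylabel('f(x)')
plt.legend()
plt.show()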
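The plot only checks the result by eye, so a quick numerical check may also be useful. The sketch below (an addition, not from the original post) compares the autograd derivative with cos(x) and with a central finite difference (f(x+h) - f(x-h)) / (2h); the step h = 1e-5 is an arbitrary assumption. The helper is repeated so the snippet runs on its own.

```python
import numpy as np
import torch

def get_derivative_tensor(x):
    # Same approach as above: sum sin(x), then backpropagate
    x = torch.tensor(x, requires_grad=True)
    torch.sin(x).sum().backward()
    return x.grad

x = np.arange(-5, 5, 0.01)
dx = get_derivative_tensor(x).numpy()

# Analytic derivative, used here only as a reference
err_analytic = np.max(np.abs(dx - np.cos(x)))

# Central finite difference with an assumed step size h
h = 1e-5
fd = (np.sin(x + h) - np.sin(x - h)) / (2 * h)
err_fd = np.max(np.abs(dx - fd))

print(f"max |autograd - cos(x)|      = {err_analytic:.2e}")
print(f"max |autograd - finite diff| = {err_fd:.2e}")
```

Since NumPy arrays default to float64, the tensor is float64 as well, so the first error should be at the level of rounding, and the second is dominated by the finite-difference approximation itself.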

END :)
I went back and forth about whether something this simple deserved a write-up, but decided I should not look down on problems I didn't know how to solve. As Han Yu wrote in "On Teachers" (师说): "What I learn from is the Way; why should I care whether my teacher was born before or after me?" By the same reasoning, what I am after is progress, and it does not matter whether the problem itself is hard. If you dismiss simple problems as not worth summarizing and just muddle along, it is hard to improve in the long run.