当前位置:网站首页>pytorch学习09:矩阵基本运算
pytorch学习09:矩阵基本运算
2022-06-21 23:47:00 【HMTT】
四则运算
import torch
a = torch.tensor([
[1,2],
[3,4]
])
b = torch.tensor([
[10, 20]
])
# 加
print("torch.all(torch.eq(a+b, torch.add(a,b))):",
torch.all(torch.eq(a+b, torch.add(a,b))))
print("a+b:\n{}\n".format(a+b))
# 减
print("torch.all(torch.eq(a-b, torch.sub(a,b))):",
torch.all(torch.eq(a-b, torch.sub(a,b))))
print("a*b:\n{}\n".format(a-b))
# 乘(是点乘)
print("torch.all(torch.eq(a*b, torch.mul(a,b))):",
torch.all(torch.eq(a*b, torch.mul(a,b))))
print("a*b:\n{}\n".format(a*b))
# 除
print("torch.all(torch.eq(a/b, torch.div(a,b))):",
torch.all(torch.eq(a/b, torch.div(a,b))))
print("a*b:\n{}\n".format(a/b))

矩阵相乘
import torch
a = torch.tensor([
[1],
[3]
])
b = torch.tensor([
[10, 20]
])
# mm只能运算至多二维矩阵
print("torch.mm(a, b):\n{}\n".format(torch.mm(a, b)))
# matmul可运算更高维矩阵
print("torch.matmul(a, b):\n{}\n".format(torch.matmul(a, b)))
print("[email protected]:\n{}\n".format([email protected]))

大于2维的矩阵相乘
import torch
a1 = torch.rand(4, 3, 28, 64)
b1 = torch.rand(4, 3, 64, 32)
c1 = torch.matmul(a1, b1)
# 对最后两维进行乘法运算
# 可以理解为多个矩阵并行相乘
print("c1.shape: ", c1.shape)
a2 = torch.rand(4, 1, 28, 64)
b2 = torch.rand(4, 3, 64, 32)
c2 = torch.matmul(a2, b2)
# 这里用到了广播机制
print("c2.shape: ", c2.shape)

幂运算
import torch
a = torch.tensor([
[1, 2],
[3, 4]
])
print("a.pow(2):\n{}\n".format(a.pow(2)))
print("a**2:\n{}\n".format(a**2))
print("a.pow(0.5):\n{}\n".format(a.pow(0.5)))
print("a.sqrt():\n{}\n".format(a.sqrt()))
# 平方根的倒数
print("a.rsqrt():\n{}\n".format(a.rsqrt()))
print("a**0.5:\n{}\n".format(a**0.5))

exp log
import torch
a = torch.tensor([
[1, 2],
[3, 4]
])
a_exp = torch.exp(a)
# e^x
print("torch.exp(a):\n{}\n".format(a_exp))
# ln x
# 以2为底:log2
# 以10为底:log10
print("torch.log(a_exp):\n{}\n".format(torch.log(a_exp)))

近似值
import torch
a = torch.tensor(1.67)
# 向下取整
print("a.floor():", a.floor())
# 向上取整
print("a.ceil():", a.ceil())
# 取整数部分
print("a.trunc():", a.trunc())
# 取小数部分
print("a.frac():", a.frac())
# 四舍五入
print("a.round():", a.round())

最大值、最小值、中位数
import torch
a = torch.rand(2,3)*20
print("a:\n{}\n".format(a))
# 最大值
print("a.max(): ", a.max())
# 中位数,偶数时不取平均,取从小到大第 length/2 个
print("a.median(): ", a.median())
# 最小值
print("a.min(): ", a.min())

限制区间
import torch
a = torch.rand(2,3)*20
print("a:\n{}\n".format(a))
# clamp(min),当有值小于 min 时,用 min 替换
print("a.clamp(10):\n{}\n".format(a.clamp(10)))
# clamp(min, max),当有值小于 min 时,用 min 替换
# 当有值大于 max 时,用 max 替换
print("a.clamp(5, 10):\n{}\n".format(a.clamp(5, 10)))

边栏推荐
猜你喜欢

MySQL 8.0 新特性梳理汇总

Meetup03期回顾:Linkis新版本介绍以及DSS的应用实践

How to use through-hole conductive slip ring

eslint:错误

Bit operation bit and

对面积的曲面积分中dS与dxdy的转换

Lecture 3 of Data Engineering Series: characteristic engineering of data centric AI

Hongmeng OS learning (rotation chart, list, icon)

NS32F103VBT6软硬件替代STM32F103VBT6

过孔式导电滑环怎么用
随机推荐
How to gracefully count code time
[2023 approval in advance] China Singapore SECCO
旋转框目标检测————关于旋转框定义和解决方案
Mathematical knowledge: greatest common divisor divisor
Record a small JSP bug
程序员坐牢了,会被安排去写代码吗?
Acwing game 56
[sword finger offer] 43 Number of occurrences of 1 in integers 1 to n
note
How to use through-hole conductive slip ring
[set static route] "WiFi for private internal network and external network“
The importance of rational selection of seal clearance of hydraulic slip ring
[wechat applet] 40029 invalid code solution set
Web应用系统开发的两种流行架构
Introduction and use of pytest fixture, confitest and mark
QT qmediaplayer get audio playback end status
导电滑环是如何工作的
Brief idea and simple case of JVM tuning - space allocation guarantee mechanism in the old age
[wechat applet] wechat applet uses pop-up box for multi-level linkage (example)
Use of MySQL performance analysis tools