当前位置:网站首页>Pytorch learning 09: basic matrix operations

PyTorch learning 09: basic matrix operations

2022-06-22 00:58:00 HMTT

arithmetic

import torch

# Operands for the element-wise arithmetic demo. `a` has shape (2, 2) and
# `b` has shape (1, 2), so every operation below broadcasts `b` across the
# rows of `a`.
a = torch.tensor([
    [1, 2],
    [3, 4],
])

b = torch.tensor([
    [10, 20],
])

# Addition: the `+` operator and torch.add produce the same result.
print("torch.all(torch.eq(a+b, torch.add(a,b))):",
      torch.all(torch.eq(a+b, torch.add(a,b))))
print("a+b:\n{}\n".format(a+b))

# Subtraction: the `-` operator and torch.sub produce the same result.
print("torch.all(torch.eq(a-b, torch.sub(a,b))):",
      torch.all(torch.eq(a-b, torch.sub(a,b))))
# The label now matches the expression being printed (it used to say "a*b").
print("a-b:\n{}\n".format(a-b))

# Multiplication: `*` / torch.mul work element by element; this is NOT a
# matrix product.
print("torch.all(torch.eq(a*b, torch.mul(a,b))):",
      torch.all(torch.eq(a*b, torch.mul(a,b))))
print("a*b:\n{}\n".format(a*b))

# Division: the `/` operator and torch.div produce the same result.
print("torch.all(torch.eq(a/b, torch.div(a,b))):",
      torch.all(torch.eq(a/b, torch.div(a,b))))
# The label now matches the expression being printed (it used to say "a*b").
print("a/b:\n{}\n".format(a/b))

 Please add a picture description

matrix multiplication

import torch

# A (2, 1) column vector times a (1, 2) row vector gives a (2, 2) matrix.
a = torch.tensor([
    [1],
    [3],
])

b = torch.tensor([
    [10, 20],
])

# torch.mm only accepts 2-D matrices.
print("torch.mm(a, b):\n{}\n".format(torch.mm(a, b)))
# torch.matmul also handles higher-dimensional (batched) inputs.
print("torch.matmul(a, b):\n{}\n".format(torch.matmul(a, b)))
# The `@` operator is the infix spelling of matrix multiplication. This line
# was corrupted to "[email protected]" by e-mail obfuscation in the scraped page.
print("a@b:\n{}\n".format(a@b))

 Please add a picture description

Matrix multiplication for tensors with more than 2 dimensions

import torch

# Case 1: both operands share the same leading (batch) dimensions (4, 3).
a1 = torch.rand((4, 3, 28, 64))
b1 = torch.rand((4, 3, 64, 32))

# Only the last two dimensions take part in the product,
# (28, 64) x (64, 32) -> (28, 32); the leading dimensions act as a batch
# of independent matrix multiplications carried out in parallel.
c1 = a1.matmul(b1)
print("c1.shape: ", c1.shape)

# Case 2: the second batch dimension differs (1 vs. 3).
a2 = torch.rand((4, 1, 28, 64))
b2 = torch.rand((4, 3, 64, 32))

# The size-1 batch dimension of a2 is broadcast to match b2.
c2 = a2.matmul(b2)
print("c2.shape: ", c2.shape)

 Please add a picture description

Power operation

import torch

# Integer input shared by all power / root examples below.
a = torch.tensor([[1, 2], [3, 4]])

# Squaring: method form and operator form.
squared_method = a.pow(2)
print("a.pow(2):\n{}\n".format(squared_method))
squared_operator = a ** 2
print("a**2:\n{}\n".format(squared_operator))

# Square root: fractional power and the dedicated method.
root_by_pow = a.pow(0.5)
print("a.pow(0.5):\n{}\n".format(root_by_pow))
root_by_sqrt = a.sqrt()
print("a.sqrt():\n{}\n".format(root_by_sqrt))

# rsqrt is the reciprocal of the square root, i.e. 1 / sqrt(x).
reciprocal_root = a.rsqrt()
print("a.rsqrt():\n{}\n".format(reciprocal_root))

# Fractional power again, this time with the `**` operator.
root_by_operator = a ** 0.5
print("a**0.5:\n{}\n".format(root_by_operator))

 Please add a picture description

Exponential and logarithm (exp, log)

import torch

# Input for the exponential / logarithm round trip.
a = torch.tensor([[1, 2], [3, 4]])

# Element-wise e^x.
a_exp = torch.exp(a)
print("torch.exp(a):\n{}\n".format(a_exp))

# torch.log is the natural logarithm (ln x), so it undoes torch.exp.
# For other bases use torch.log2 (base 2) or torch.log10 (base 10).
a_log = torch.log(a_exp)
print("torch.log(a_exp):\n{}\n".format(a_log))

 Please add a picture description

Rounding (approximate values)

import torch

# Scalar (0-dimensional) tensor used for all rounding examples.
a = torch.tensor(1.67)

# floor: round down.
floored = a.floor()
print("a.floor():", floored)

# ceil: round up.
ceiled = a.ceil()
print("a.ceil():", ceiled)

# trunc: keep only the integer part.
integer_part = a.trunc()
print("a.trunc():", integer_part)

# frac: keep only the fractional part.
fractional_part = a.frac()
print("a.frac():", fractional_part)

# round: round to the nearest integer.
rounded = a.round()
print("a.round():", rounded)

 Please add a picture description

Maximum, minimum, and median

import torch

# Random (2, 3) tensor scaled from [0, 1) up to [0, 20).
a = 20 * torch.rand(2, 3)

print("a:\n{}\n".format(a))

# Largest element of the whole tensor.
largest = a.max()
print("a.max(): ", largest)

# Median of all elements. With an even number of elements the two middle
# values are NOT averaged; the lower of the two is returned.
middle = a.median()
print("a.median(): ", middle)

# Smallest element of the whole tensor.
smallest = a.min()
print("a.min(): ", smallest)

 Please add a picture description

Clamping values to an interval

import torch

# Random (2, 3) tensor scaled from [0, 1) up to [0, 20).
a = 20 * torch.rand(2, 3)

print("a:\n{}\n".format(a))

# clamp(min): every value below `min` is replaced by `min`.
lower_bounded = a.clamp(min=10)
print("a.clamp(10):\n{}\n".format(lower_bounded))

# clamp(min, max): values below `min` are replaced by `min`, and values
# above `max` are replaced by `max`.
bounded = a.clamp(min=5, max=10)
print("a.clamp(5, 10):\n{}\n".format(bounded))

 Please add a picture description

原网站

版权声明
本文为[HMTT]所创,转载请带上原文链接,感谢
https://yzsam.com/2022/173/202206212347164104.html