PyTorch learning notes -- Summary of common PyTorch functions 1
2022-07-25 15:41:00 [whut_L]
Contents
1-torch.randn() function
2-set() and sorted() functions
3-DataLoader() function and Dataset class
4-.t() function
5-Max pooling (max_pool2d) and average pooling (avg_pool2d) functions
1-torch.randn() function
import torch
batch_size = 1
seq_len = 3
input_size = 4
inputs = torch.randn(seq_len, batch_size, input_size)
torch.randn() generates a tensor of random numbers drawn from the standard normal distribution (mean 0, variance 1). For example:
import torch
print(torch.randn(3, 2, 3, 3))
In torch.randn(seq_len, batch_size, input_size), the arguments only set the shape of the returned tensor. In the RNN layout used above, the first argument seq_len is the sequence length (3), the second argument batch_size is the batch size (1), and the third argument input_size is the dimension of each input vector (4), so inputs has shape (3, 1, 4). The four-argument example torch.randn(3, 2, 3, 3) returns a tensor of shape (3, 2, 3, 3); read in the same RNN layout, it describes a sequence of length 3 with a batch size of 2, where each input is a 3×3 matrix instead of a vector.
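A quick check of both points above: the arguments of torch.randn() only determine the shape, and with a large sample the mean and standard deviation come out close to 0 and 1. This is a minimal sketch; the fixed seed (torch.manual_seed(0)) and the sample size of one million are arbitrary choices for illustration.

```python
import torch

torch.manual_seed(0)  # fixed seed so the run is reproducible

seq_len, batch_size, input_size = 3, 1, 4
inputs = torch.randn(seq_len, batch_size, input_size)
print(inputs.shape)  # torch.Size([3, 1, 4])

x = torch.randn(3, 2, 3, 3)
print(x.shape)  # torch.Size([3, 2, 3, 3])

# Draw many samples to compare against the standard normal distribution
samples = torch.randn(1_000_000)
print(samples.mean().item(), samples.std().item())  # close to 0 and 1
```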
#####################################
#####################################
2-set() and sorted() functions
self.country_list = list(sorted(set(self.countries)))  # set() removes duplicates; sorted() sorts the result
set() builds a set from the elements, which removes duplicate data; sorted() returns the elements in ascending order as a new list (so the outer list() is optional). For example:
a = ['china', 'china', 'japan']
print(list(set(a)))
print(list(sorted(set(a))))
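Because 'c' < 'j', 'china' is sorted before 'japan'. Note that a set has no defined order, so the order printed by list(set(a)) is not guaranteed; applying sorted() makes it deterministic.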
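In a classification task, a sorted, de-duplicated label list such as country_list is usually turned into a label-to-index mapping, so every label gets a stable integer id. A minimal sketch, assuming a hypothetical list of raw labels; the names countries, country_list, country_to_index and first_seen are illustrative only. It also contrasts sorted(set(...)) with dict.fromkeys(), which removes duplicates but keeps first-seen order.

```python
# Hypothetical raw labels, one per training sample (duplicates expected)
countries = ['japan', 'china', 'china', 'france', 'japan']

# Deduplicate with set(), then sort for a deterministic order
country_list = sorted(set(countries))
print(country_list)  # ['china', 'france', 'japan']

# Map each label to an integer index (e.g. a classifier's target class)
country_to_index = {name: idx for idx, name in enumerate(country_list)}
print(country_to_index)  # {'china': 0, 'france': 1, 'japan': 2}

# Alternative: dict.fromkeys() removes duplicates but keeps first-seen order
first_seen = list(dict.fromkeys(countries))
print(first_seen)  # ['japan', 'china', 'france']
```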
#####################################
#####################################
3-DataLoader() function and Dataset class
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
batch_size = 64
transform = transforms.Compose([
    transforms.ToTensor(),  # convert a PIL image / ndarray of shape (H, W, C) to a tensor of shape (C, H, W), scaling uint8 values to [0, 1]
    transforms.Normalize((0.1307, ), (0.3081, ))  # standardize per channel: (x - mean) / std, with the commonly used MNIST mean 0.1307 and std 0.3081
])
train_dataset = datasets.MNIST(root = '../dataset/mnist/', train = True, download = True, transform = transform)
train_loader = DataLoader(train_dataset, shuffle = True, batch_size = batch_size)
test_dataset = datasets.MNIST(root = '../dataset/mnist/', train = False, download = True, transform = transform)
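test_loader = DataLoader(test_dataset, shuffle = False, batch_size = batch_size)
DataLoader() takes a Dataset object (here the MNIST datasets) and yields its samples in batches of batch_size; shuffle controls whether the samples are shuffled at every epoch (typically True for training and False for testing).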
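The MNIST example uses a ready-made dataset. To feed your own data to DataLoader(), subclass Dataset and implement __len__() and __getitem__(). A minimal sketch; SquaresDataset and its synthetic (x, x²) pairs are hypothetical, made up only to keep the example self-contained and free of downloads.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class SquaresDataset(Dataset):
    """Hypothetical toy dataset: sample i is the pair (i, i**2) as float tensors."""

    def __init__(self, n):
        self.x = torch.arange(n, dtype=torch.float32).unsqueeze(1)  # shape (n, 1)
        self.y = self.x ** 2                                        # shape (n, 1)

    def __len__(self):
        # DataLoader uses this to know how many samples exist
        return len(self.x)

    def __getitem__(self, index):
        # Return one (input, target) pair; DataLoader stacks them into batches
        return self.x[index], self.y[index]


dataset = SquaresDataset(10)
loader = DataLoader(dataset, batch_size=4, shuffle=False)

for inputs, targets in loader:
    print(inputs.shape, targets.shape)
# batches of 4, 4 and 2 samples: the last batch is smaller
```

With shuffle=True (as for train_loader above) the sample order changes every epoch, and passing drop_last=True would discard the final incomplete batch.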
#####################################
#####################################
4-.t() function
.t() transposes a 2-D Tensor (swaps its rows and columns). For example:
import torch
input = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
print(input)
print(input.t())
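A few properties of .t() worth knowing: for a 2-D tensor it gives the same result as transpose(0, 1) and the .T property, it returns a view that shares memory with the original rather than a copy, a 1-D tensor is returned unchanged, and tensors with more than 2 dimensions raise an error (use transpose() or permute() for those). A small sketch of each:

```python
import torch

m = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
mt = m.t()
print(mt)
# tensor([[1, 4, 7],
#         [2, 5, 8],
#         [3, 6, 9]])

# Same result as transpose(0, 1) and the .T property for a 2-D tensor
print(torch.equal(mt, m.transpose(0, 1)), torch.equal(mt, m.T))  # True True

# It is a view: writing through mt changes m as well
mt[0, 1] = 100
print(m[1, 0])  # tensor(100)

# A 1-D tensor is returned unchanged
print(torch.tensor([1, 2, 3]).t())  # tensor([1, 2, 3])

# More than 2 dimensions raises an error; use transpose()/permute() instead
x3 = torch.zeros(2, 3, 4)
try:
    x3.t()
except RuntimeError as err:
    print("RuntimeError:", err)
print(x3.transpose(0, 1).shape)  # torch.Size([3, 2, 4])
```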
#####################################
#####################################
5-Max pooling (max_pool2d) and average pooling (avg_pool2d) functions
import torch
import torch.nn.functional as F
input = torch.tensor([[[1, 2, 3, 1], [4, 5, 6, 1], [7, 8, 9, 1]]]).unsqueeze(0).float()  # unsqueeze(0) inserts a new dimension at position 0; .float() converts to float32
print(input.size())
output = F.max_pool2d(input, kernel_size = (1, 4))
print(output)
max_pool2d(): max pooling. The kernel window slides over the input (by default with a stride equal to the kernel size), and each output element is the largest value inside the window. In this example the kernel size is (1, 4), which covers one whole row at a time, so the output holds the maximum of each row: 3, 6 and 9.
Note: unsqueeze(0) inserts a new dimension at position 0, so input has size (1, 1, 3, 4), i.e. (N, C, H, W) with one sample and one channel.
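To make the sliding-window mechanics concrete, here is a plain-Python sketch of 2-D pooling on a single channel. It assumes no padding and a stride equal to the kernel size (the default stride of F.max_pool2d() and F.avg_pool2d()); under those assumptions the output size along each axis is floor(H / kH) and floor(W / kW), and windows that would run past the edge are dropped. The helper name pool2d is hypothetical, not a PyTorch API.

```python
def pool2d(matrix, kernel_size, reduce):
    """Hypothetical helper: pool a 2-D list with stride == kernel_size, no padding.

    reduce is applied to the list of values inside each window (e.g. max).
    """
    kh, kw = kernel_size
    rows, cols = len(matrix), len(matrix[0])
    out = []
    # Only windows that fit completely inside the input are used
    for i in range(0, rows - kh + 1, kh):
        out_row = []
        for j in range(0, cols - kw + 1, kw):
            window = [matrix[i + di][j + dj] for di in range(kh) for dj in range(kw)]
            out_row.append(reduce(window))
        out.append(out_row)
    return out


x = [[1, 2, 3, 1],
     [4, 5, 6, 1],
     [7, 8, 9, 1]]

# Same input and kernel as the max_pool2d example: (1, 4) -> max of each row
print(pool2d(x, (1, 4), max))  # [[3], [6], [9]]

# A (2, 2) kernel: the third row does not fill a full window and is dropped
print(pool2d(x, (2, 2), max))  # [[5, 6]]
```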
###
import torch
import torch.nn.functional as F
input = torch.randn(1, 1, 4, 4)
print(input.size())
print(input)
output = F.avg_pool2d(input, kernel_size = (2, 2))
print(output)
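avg_pool2d(): average pooling. With the given kernel size (and, by default, a stride equal to the kernel size), each output element is the mean of the elements inside the window; here a 4×4 input with a (2, 2) kernel gives a 2×2 output.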
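Because the example above uses torch.randn(), its output changes on every run. With a fixed input the averages can be checked by hand; this sketch builds a 4×4 input with torch.arange() (an arbitrary choice) and compares F.avg_pool2d() with F.max_pool2d() over the same 2×2 windows.

```python
import torch
import torch.nn.functional as F

# Fixed 4x4 input shaped (N, C, H, W) = (1, 1, 4, 4):
# [[ 0,  1,  2,  3],
#  [ 4,  5,  6,  7],
#  [ 8,  9, 10, 11],
#  [12, 13, 14, 15]]
x = torch.arange(16, dtype=torch.float32).reshape(1, 1, 4, 4)

avg = F.avg_pool2d(x, kernel_size=(2, 2))  # stride defaults to kernel_size
mx = F.max_pool2d(x, kernel_size=(2, 2))

print(avg)  # top-left window: (0 + 1 + 4 + 5) / 4 = 2.5
print(mx)   # top-left window: max(0, 1, 4, 5) = 5
```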
Roles of pooling: reducing the spatial dimensions of feature maps; suppressing noise and reducing redundant information; making the model more robust to small changes in scale and rotation; reducing the amount of computation in later layers; and helping to prevent overfitting.
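Two of these points, dimension reduction and lower computation, can be read directly from tensor shapes: a 2×2 pooling layer with stride 2 halves the height and width, so later layers see 4× fewer values, and the pooling layer itself has no learnable parameters. A minimal sketch; the 8-channel, 28×28 (MNIST-sized) feature map is an assumption for illustration.

```python
import torch
import torch.nn as nn

pool = nn.MaxPool2d(kernel_size=2)  # stride defaults to kernel_size

x = torch.randn(1, 8, 28, 28)  # (N, C, H, W): a hypothetical 8-channel feature map
y = pool(x)

print(x.shape, x.numel())  # torch.Size([1, 8, 28, 28]) 6272
print(y.shape, y.numel())  # torch.Size([1, 8, 14, 14]) 1568

# Pooling has nothing to learn, so it adds no parameters to the model
print(sum(p.numel() for p in pool.parameters()))  # 0
```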