当前位置:网站首页>Pytorch 通过 Tensor 某一维的值将 Tensor 分开的方法(简易)
Pytorch 通过 Tensor 某一维的值将 Tensor 分开的方法(简易)
2022-07-25 09:27:00 【Haulyn5】
需求与方法
场景是我现在有一个 4 * 2 的数据(Tensor),每一行是一个样本,第一列代表分类置信度,第二列是真实标签,我想根据真实标签的值将这个数据分成两个 Tensor,一个只包括 Label=1的样本,一个只包括 Label=0 的样本。
假设 Tensor 内容如下
Tensor a = [[92,1], [23,0],[67,1],[33,0]]
我们预期的结果是两个 Tensor 分别是 [[92, 1], [67, 1]] 与 [[23, 0], [33, 0]]。
代码如下:
a = torch.Tensor([[92,1], [23,0],[67,1],[33,0]])
tensor_label_1 = a[a[:,1]== 1, :]
tensor_label_0 = a[a[:,1]== 0, :]
print(tensor_label_1)
print(tensor_label_0)
# 执行结果如下
# tensor([[92., 1.],
# [67., 1.]])
# tensor([[23., 0.],
# [33., 0.]])理解
其实这里是用到了 PyTorch 的一些功能,为了理解逻辑,可以分开执行代码的部分。
单独执行 `a[:,1]== 1` 可以得到一个 Bool 类型的 Tensor:
tensor([ True, False, True, False])
而使用 Bool 类型的 Tensor 可以作为 Tensor 的索引使用,所以 `a[a[:,1]== 1, :]` 即表示,对于 Tensor a,按照 True, False, True, False 的顺序选取 行(row),而列的索引是 ";" ,表示选中行的每一列都要放入结果中返回。
边栏推荐
- Usage of string slicing
- 1、 Initial mysql, MySQL installation, environment configuration, initialization
- Download and installation of QT 6.2
- oracle 解析同名xml 问题
- [tensorflow2 installation] tensorflow2.3-cpu installation pit avoidance guide!!!
- mysql 解决不支持中文的问题
- 一文学会,三款黑客必备的抓包工具教学
- SSM整合(简单的图书管理系统来整合SSM)
- 线程池的设计和原理
- 腾讯云之错误[100007] this env is not enable anonymous login
猜你喜欢
随机推荐
Copy the old project into a web project
Common methods of JS digital thousand bit segmentation
IO流中的输入流
将 conda 虚拟环境 env 加入 jupyter kernel
js数字千位分割的常用方法
Filter filter details (listeners and their applications)
C3D模型pytorch源码逐句详析(一)
RedisUtil
SQL 题目整理
Record of deep learning segment error (segment core/ exit code 139)
ES6 detailed explanation
oracle 解析同名xml 问题
[recommended collection] with these learning methods, I joined the world's top 500 - the "fantastic skills and extravagance" in the Internet age
字典树的使用
【无标题】
集合的创建,及常用方法
CentOs安装redis
GUI窗口
NPM详解
OSPF协议的配置(以华为eNSP为例)

![[deployment of deep learning model] deploy the deep learning model using tensorflow serving + tornado](/img/62/78abf16bb6c66726c6e394c9fb4f81.png)







