当前位置:网站首页>Pytorch 通过 Tensor 某一维的值将 Tensor 分开的方法(简易)
Pytorch 通过 Tensor 某一维的值将 Tensor 分开的方法(简易)
2022-07-25 09:27:00 【Haulyn5】
需求与方法
场景是我现在有一个 4 * 2 的数据(Tensor),每一行是一个样本,第一列代表分类置信度,第二列是真实标签,我想根据真实标签的值将这个数据分成两个 Tensor,一个只包括 Label=1的样本,一个只包括 Label=0 的样本。
假设 Tensor 内容如下
Tensor a = [[92,1], [23,0],[67,1],[33,0]]
我们预期的结果是两个 Tensor 分别是 [[92, 1], [67, 1]] 与 [[23, 0], [33, 0]]。
代码如下:
a = torch.Tensor([[92,1], [23,0],[67,1],[33,0]])
tensor_label_1 = a[a[:,1]== 1, :]
tensor_label_0 = a[a[:,1]== 0, :]
print(tensor_label_1)
print(tensor_label_0)
# 执行结果如下
# tensor([[92., 1.],
# [67., 1.]])
# tensor([[23., 0.],
# [33., 0.]])理解
其实这里是用到了 PyTorch 的一些功能,为了理解逻辑,可以分开执行代码的部分。
单独执行 `a[:,1]== 1` 可以得到一个 Bool 类型的 Tensor:
tensor([ True, False, True, False])
而使用 Bool 类型的 Tensor 可以作为 Tensor 的索引使用,所以 `a[a[:,1]== 1, :]` 即表示,对于 Tensor a,按照 True, False, True, False 的顺序选取 行(row),而列的索引是 ";" ,表示选中行的每一列都要放入结果中返回。
边栏推荐
- 腾讯云之错误[100007] this env is not enable anonymous login
- [tensorflow2 installation] tensorflow2.3-cpu installation pit avoidance guide!!!
- 数据库MySQL详解
- 记录一些JS工具函数
- Summary of most consistency problems
- 字符串切片的用法
- 多线程——死锁和synchronized
- CCF 201509-2 date calculation
- nodejs版本升级或切换的常用方式
- Probabilistic robot learning notes Chapter 2
猜你喜欢
随机推荐
Reflection 反射
TCP传输
鼠标监听,画笔
CentOS install redis
NPM details
小程序调起微信支付
struct2的原理
Round to the nearest
Subtotal of rospy odometry sinkhole
链表相关(设计链表及环链表问题)
Swing组件之单选与多选按钮
DHCP的配置(以华为eNSP为例)
几个常用的网络诊断命令
CCF 201503-3 Festival
ThreadLocal&Fork/Join
Pow(x,n)
UE4 LoadingScreen动态加载启动动画
vscode插件开发
线程池的设计和原理
修改mysql的分组报错Expression #1 of SELECT list is not in GROUP









