PyTorch (environment, TensorBoard, transforms, torchvision, DataLoader)
2022-06-26 05:30:00 【月屯】
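Software used: Anaconda and PyCharm
Environment setup and installation
Fixing slow pip installs
Install Anaconda, then create a conda environment named pytorch with Python 3.6: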
conda create -n pytorch python=3.6
Activate the environment:
conda activate pytorch
List the packages installed in the environment (for example with conda list or pip list).
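Install PyTorch
Get the installation command from the official PyTorch website, choosing the options that match your OS, package manager and CUDA version. The command used here was:
conda install pytorch torchvision cudatoolkit=9.2 -c pytorch -c defaults -c numba/label/dev
If conda reports the error below, the subcommand was misspelled ("creat" instead of "create"):
CommandNotFoundError: No command 'conda creat'. Did you mean 'conda create'?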
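Enter the Python interpreter (run python in the activated environment).

Check whether PyTorch can use the local GPU; True means CUDA is usable, False means PyTorch will run on the CPU: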
>>> import torch
>>> torch.cuda.is_available()
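torch.cuda.is_available() returns True only when PyTorch was built with CUDA support and can see a GPU. A minimal sketch of the usual next step, picking a device so the same script runs on machines with or without a GPU (pick_device is a hypothetical helper name, not part of PyTorch):

```python
import torch


def pick_device() -> torch.device:
    """Return the CUDA device when PyTorch can use a GPU, otherwise the CPU."""
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


if __name__ == "__main__":
    device = pick_device()
    # torch.version.cuda is None on CPU-only builds
    print("torch:", torch.__version__, "| built for CUDA:", torch.version.cuda)
    print("using:", device)
    if device.type == "cuda":
        print("GPU name:", torch.cuda.get_device_name(0))
    x = torch.ones(2, 2, device=device)  # allocate a tensor directly on the chosen device
    print(x.device)
```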
Install PyCharm
Configure Anaconda as the project interpreter


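Error: Cannot run program "D:…\venv\Scripts\python.exe" (in directory ): CreateProcess error=2 (error=2 is Windows' "file not found": the project still points at an interpreter path that does not exist, so select the Anaconda environment's interpreter instead)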
Jupyter

Open the Anaconda Prompt terminal and activate the environment:
conda activate pytorch
then start Jupyter:
jupyter notebook


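Learning Python
Two handy tools for learning Python
dir(): shows what a toolbox (package) contains and what is inside each of its compartments (submodules).
help(): shows how a particular tool is used, i.e. its usage instructions.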
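Applied to PyTorch, the two tools work together: dir() to walk down from torch into torch.cuda until you reach a function, then help() on that function. A small sketch using torch.cuda.is_available from the GPU check above:

```python
import torch

# dir(): list what a package contains; torch.cuda is one "compartment" of the torch toolbox
cuda_names = [n for n in dir(torch.cuda) if not n.startswith("_")]
print("is_available" in cuda_names)  # True

# dir() on a single tool (a function) lists mostly __dunder__ attributes:
# there are no further compartments to open, so switch to help()
print(dir(torch.cuda.is_available)[:3])

# help(): print the signature and docstring of one tool
help(torch.cuda.is_available)
```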
PyCharm

Data

Download the dataset
Using Jupyter

PyCharm
Using PIL's Image


Example: a custom Dataset whose label is the name of the image folder
from torch.utils.data import Dataset
from PIL import Image
import os


class MyData(Dataset):
    def __init__(self, root_dir, label_dir):
        self.root_dir = root_dir      # e.g. "dataset/train"
        self.label_dir = label_dir    # e.g. "ants"; also used as the label
        self.path = os.path.join(self.root_dir, self.label_dir)
        self.img_path = os.listdir(self.path)  # file names of all images in the folder

    def __getitem__(self, idx):
        img_name = self.img_path[idx]
        img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)
        img = Image.open(img_item_path)
        label = self.label_dir
        return img, label

    def __len__(self):
        return len(self.img_path)


ants_dataset = MyData("dataset/train", "ants")

Concatenating datasets
Building on the code above:
ants_dataset = MyData("dataset/train", "ants")
bees_dataset = MyData("dataset/train", "bees")
train_dataset = ants_dataset + bees_dataset  # all ants samples followed by all bees samples
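Dataset defines __add__, so ants_dataset + bees_dataset returns a torch.utils.data.ConcatDataset whose indices run through the first dataset and then continue into the second. A self-contained sketch with a hypothetical ToyData class standing in for MyData, so it runs without the image folders:

```python
from torch.utils.data import ConcatDataset, Dataset


class ToyData(Dataset):
    """Hypothetical stand-in for MyData: n fake samples that all carry one label."""

    def __init__(self, label, n):
        self.label = label
        self.n = n

    def __getitem__(self, idx):
        return f"{self.label}_{idx}.jpg", self.label  # (fake image name, label)

    def __len__(self):
        return self.n


ants = ToyData("ants", 3)
bees = ToyData("bees", 2)
train = ants + bees                      # Dataset.__add__ builds a ConcatDataset
print(type(train).__name__, len(train))  # ConcatDataset 5
print(train[2], train[3])                # last ants sample, then first bees sample
```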

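Creating label files (one .txt per image, containing its label)

import os

root_dir = "dataset/train"
target_dir = "ants_image"
img_path = os.listdir(os.path.join(root_dir, target_dir))
label = target_dir.split("_")[0]  # "ants_image" -> "ants"
out_dir = "ants_label"
os.makedirs(os.path.join(root_dir, out_dir), exist_ok=True)  # create the label folder if missing
for i in img_path:
    file_name = os.path.splitext(i)[0]  # drop the extension: "xxx.jpg" -> "xxx"
    with open(os.path.join(root_dir, out_dir, "{}.txt".format(file_name)), 'w') as f:
        f.write(label)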
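The label files can then be read back in a Dataset, pairing each image with the .txt of the same name. A minimal sketch, assuming the ants_image / ants_label layout written above; LabelFileData is a hypothetical class, and it returns the image path rather than opening it so the demo runs with empty placeholder files:

```python
import os
import tempfile

from torch.utils.data import Dataset


class LabelFileData(Dataset):
    """Hypothetical Dataset: pairs each file in <root>/<image_dir> with the label
    stored in <root>/<label_dir>/<file stem>.txt (the layout written above)."""

    def __init__(self, root_dir, image_dir, label_dir):
        self.image_root = os.path.join(root_dir, image_dir)
        self.label_root = os.path.join(root_dir, label_dir)
        self.image_names = sorted(os.listdir(self.image_root))

    def __getitem__(self, idx):
        name = self.image_names[idx]
        stem = os.path.splitext(name)[0]  # "5650366_e22b7e1065.jpg" -> "5650366_e22b7e1065"
        with open(os.path.join(self.label_root, stem + ".txt")) as f:
            label = f.read().strip()
        # Returns the path so the sketch runs without real images;
        # MyData above returns Image.open(path) instead.
        return os.path.join(self.image_root, name), label

    def __len__(self):
        return len(self.image_names)


# Demo on a throwaway directory with empty placeholder "images"
with tempfile.TemporaryDirectory() as root:
    os.makedirs(os.path.join(root, "ants_image"))
    os.makedirs(os.path.join(root, "ants_label"))
    for stem in ("0013035", "5650366_e22b7e1065"):
        open(os.path.join(root, "ants_image", stem + ".jpg"), "w").close()
        with open(os.path.join(root, "ants_label", stem + ".txt"), "w") as f:
            f.write("ants")
    data = LabelFileData(root, "ants_image", "ants_label")
    demo = [(os.path.basename(data[i][0]), data[i][1]) for i in range(len(data))]

print(demo)  # [('0013035.jpg', 'ants'), ('5650366_e22b7e1065.jpg', 'ants')]
```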
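TensorBoard
TensorBoard was originally the visualization tool of Google's TensorFlow. It can record training data, evaluation metrics, network structure, images and more, and displays them in a web page, which is very helpful for observing how a neural network trains.
Installation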
pip install tensorboard
SummaryWriter
add_scalar
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter("logs")  # event files are written to ./logs
for i in range(100):
    writer.add_scalar("y=2x", 2 * i, i)  # tag, scalar value (y axis), global step (x axis)
writer.close()
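To view the curve, run tensorboard --logdir=logs in the activated environment and open the printed address (by default http://localhost:6006). What was logged can also be checked from Python by reading the event files back. A minimal sketch, assuming the tensorboard package installed above; EventAccumulator lives in an internal tensorboard module rather than a stable public API, and log_and_read_back is a hypothetical helper:

```python
import tempfile

from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
from torch.utils.tensorboard import SummaryWriter


def log_and_read_back(log_dir, n_steps=5):
    """Log two scalar curves to log_dir, then return the (step, value) pairs of "y=2x"."""
    writer = SummaryWriter(log_dir)
    for i in range(n_steps):
        writer.add_scalar("y=2x", 2 * i, i)   # tag, scalar value, global step
        writer.add_scalar("y=x^2", i * i, i)  # a second curve under its own tag
    writer.close()                            # flush the event file to disk

    acc = EventAccumulator(log_dir)
    acc.Reload()                              # parse the event files found in log_dir
    return [(event.step, event.value) for event in acc.Scalars("y=2x")]


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as d:
        print(log_and_read_back(d))  # (step, value) pairs of the y=2x curve
```

Each tag becomes its own chart in the TensorBoard web page, which is why the two curves are logged under different tags.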

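If importing SummaryWriter fails with AttributeError: module 'distutils' has no attribute 'version', the installed setuptools is usually too new for the installed PyTorch; installing an older setuptools (e.g. pip install setuptools==59.5.0) or upgrading PyTorch typically resolves it.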
add_image
from torch.utils.tensorboard import SummaryWriter
import numpy as np
from PIL import Image

writer = SummaryWriter("logs")
image_path = 'dataset/train/bees_image/17209602_fe5a5a746f.jpg'
img_PIL = Image.open(image_path)
img_array = np.array(img_PIL)
print(type(img_array))   # <class 'numpy.ndarray'>
print(img_array.shape)   # (height, width, 3): channels last, hence dataformats='HWC'
writer.add_image("test", img_array, 3, dataformats='HWC')  # tag, image, global step, data format
for i in range(100):
    writer.add_scalar("y=2x", 2 * i, i)  # tag, scalar value (y), global step (x)
writer.close()
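np.array(img_PIL) produces a channels-last (H, W, C) array, while add_image expects channels-first (C, H, W) unless dataformats says otherwise. A small sketch with a synthetic red image (assumed here, not from the dataset) showing both options: declare dataformats='HWC', or transpose the array to CHW first:

```python
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter

# Synthetic 4x6 RGB image in HWC layout, like np.array(img_PIL) above
img_hwc = np.zeros((4, 6, 3), dtype=np.uint8)
img_hwc[:, :, 0] = 255  # fill the red channel: a pure red image

# Move the channel axis to the front: (H, W, C) -> (C, H, W)
img_chw = np.transpose(img_hwc, (2, 0, 1))
print(img_hwc.shape, "->", img_chw.shape)  # (4, 6, 3) -> (3, 4, 6)

writer = SummaryWriter("logs")
writer.add_image("red_hwc", img_hwc, 0, dataformats="HWC")  # declare the layout explicitly
writer.add_image("red_chw", img_chw, 0)                     # "CHW" is the default
writer.add_image("red_tensor", torch.tensor(img_chw), 0)    # torch tensors are accepted too
writer.close()
```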

transforms
transforms.ToTensor: converting an image to a tensor
from torchvision import transforms
from PIL import Image
img_path="dataset/train/ants_image/7759525_1363d24e88.jpg"
img=Image.open(img_path)
print(img,"\n")
tensor_trans=transforms.ToTensor()
tensor_img=tensor_trans(img)
print(tensor_img)

Installing opencv-python
If the download from the default index fails with
pip._vendor.urllib3.exceptions.ReadTimeoutError: HTTPSConnectionPool(host='files.pythonhosted.org', port=443): Read timed out
raise the timeout and install from a mirror:
pip --default-timeout=300 install opencv-python -i https://pypi.douban.com/simple
Displaying a torch.Tensor with add_image
from torchvision import transforms
from PIL import Image
from torch.utils.tensorboard import SummaryWriter
img_path="dataset/train/ants_image/7759525_1363d24e88.jpg"
img=Image.open(img_path)
writer=SummaryWriter("logs")
tensor_trans=transforms.ToTensor()
tensor_img=tensor_trans(img)
writer.add_image("Tensor_img", tensor_img)  # ToTensor output is (C, H, W), add_image's default layout
writer.close()

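Aside: __call__
class Person:
    def __call__(self, name):
        # runs when the instance itself is called like a function: person("zhangsan")
        print("__call__" + " Hello " + name)

    def hello(self, name):
        # an ordinary method, called with dot syntax: person.hello("lisi")
        print("hello " + name)


person = Person()
person("zhangsan")    # __call__ Hello zhangsan
person.hello("lisi")  # hello lisi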
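This is why transform objects can be used like functions: transforms.ToTensor() creates an object, and tensor_trans(img) invokes its __call__. A minimal sketch of how a Compose-style pipeline can build on __call__; MiniCompose and Scale are hypothetical classes written for illustration, not torchvision's implementation:

```python
class MiniCompose:
    """Hypothetical, simplified stand-in for transforms.Compose."""

    def __init__(self, steps):
        self.steps = steps

    def __call__(self, x):
        # Calling the object runs each step in order, feeding each output to the next
        for step in self.steps:
            x = step(x)
        return x


class Scale:
    """A callable 'transform': instances behave like functions via __call__."""

    def __init__(self, factor):
        self.factor = factor

    def __call__(self, values):
        return [v * self.factor for v in values]


pipeline = MiniCompose([Scale(2), Scale(10), sum])
print(pipeline([1, 2, 3]))  # [1, 2, 3] -> [2, 4, 6] -> [20, 40, 60] -> 120
```

Any callable fits in the list, including a built-in function such as sum, which mirrors how Compose only needs each step to be callable.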
Basic usage of common transforms
from PIL import Image
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms

writer = SummaryWriter("logs")
img = Image.open("dataset/train/ants_image/5650366_e22b7e1065.jpg")
print(img)

# 1. ToTensor
trans_totensor = transforms.ToTensor()
img_tensor = trans_totensor(img)
writer.add_image("ToTensor", img_tensor)

# 2. Normalize
print(img_tensor[0][0][0])
# output = (input - mean) / std, computed separately for each channel
# mean: one mean per channel; std: one standard deviation per channel (R, G, B)
# these are arbitrary demo values, not statistics of the dataset
trans_norm = transforms.Normalize([2, 6, 5], [3, 2, 1])
img_norm = trans_norm(img_tensor)
writer.add_image("Normalize", img_norm, 2)
print(img_norm[0][0][0])

# 3. Resize
print(img.size)
trans_resize = transforms.Resize((512, 512))
# img PIL -> resize -> img_resize PIL
img_resize = trans_resize(img)
# img_resize PIL -> totensor -> img_resize tensor
img_resize = trans_totensor(img_resize)
writer.add_image("Resize", img_resize, 0)
print(img_resize)

# 4. Compose: Resize with a single int, then ToTensor
trans_resize_2 = transforms.Resize(256)  # scales the shorter side to 256, keeping the aspect ratio
# PIL -> PIL -> tensor
trans_compose = transforms.Compose([trans_resize_2, trans_totensor])
img_resize_2 = trans_compose(img)
writer.add_image("Resize", img_resize_2, 1)

# 5. RandomCrop
trans_random = transforms.RandomCrop((100, 200))  # a random 100 (height) x 200 (width) patch
trans_compose_2 = transforms.Compose([trans_random, trans_totensor])
for i in range(10):
    writer.add_image("RandomCrop", trans_compose_2(img), i)
writer.close()
Using torchvision datasets
import torchvision

# download the dataset
train_set = torchvision.datasets.CIFAR10(root="./data_set_train", train=True, download=True)
test_set = torchvision.datasets.CIFAR10(root="./data_set_test", train=False, download=True)
# the first sample: a (PIL image, class index) tuple
print(train_set[0])
# the class names of the dataset
print(train_set.classes)
# unpack the image and its target (class index)
img, target = train_set[0]
print(img)
print(target)
print(train_set.classes[target])  # map the index to its class name
img.show()

Converting the data to tensors
import torchvision

dataset_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor()
])
train_set = torchvision.datasets.CIFAR10(root="./data_set_train", train=True, transform=dataset_transform, download=True)
test_set = torchvision.datasets.CIFAR10(root="./data_set_test", train=False, transform=dataset_transform, download=True)
print(test_set[0])  # (tensor of shape [3, 32, 32], class index)

Logging the first ten test images to TensorBoard:
import torchvision
from torch.utils.tensorboard import SummaryWriter

# convert every image to a tensor
dataset_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor()
])
train_set = torchvision.datasets.CIFAR10(root="./data_set_train", train=True, transform=dataset_transform, download=True)
test_set = torchvision.datasets.CIFAR10(root="./data_set_test", train=False, transform=dataset_transform, download=True)
# inspect a converted sample
print(test_set[0])
writer = SummaryWriter("p10")
for i in range(10):
    img, target = test_set[i]
    writer.add_image("test_set", img, i)  # one tag, one step per image
writer.close()
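Using DataLoader
import torchvision
from torch.utils.data import DataLoader

# the prepared test set
test_data = torchvision.datasets.CIFAR10("./data_set_test", train=False, transform=torchvision.transforms.ToTensor())
# batches of 4, reshuffled every epoch, loaded in the main process, last incomplete batch kept
test_loader = DataLoader(dataset=test_data, batch_size=4, shuffle=True, num_workers=0, drop_last=False)
# the first image in the test set and its target
img, target = test_data[0]
print(img.shape)   # torch.Size([3, 32, 32])
print(target)
for data in test_loader:
    imgs, targets = data
    print(imgs.shape)  # torch.Size([4, 3, 32, 32])
    print(targets)     # the 4 class indices of the batch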

Viewing the batches with SummaryWriter
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# the prepared test set
test_data = torchvision.datasets.CIFAR10("./data_set_test", train=False, transform=torchvision.transforms.ToTensor())
# arguments: dataset, samples per batch, whether to shuffle,
# number of worker subprocesses (0 = load in the main process), whether to drop the last incomplete batch
test_loader = DataLoader(dataset=test_data, batch_size=4, shuffle=True, num_workers=0, drop_last=False)
img, target = test_data[0]
print(img.shape)
print(target)
writer = SummaryWriter("dataloader")
step = 0
for data in test_loader:
    imgs, targets = data
    print(imgs.shape)
    print(targets)
    writer.add_images("test_data", imgs, step)  # add_images takes a whole batch shaped (N, C, H, W)
    step += 1
writer.close()
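The loop above makes a single pass (one epoch) over the test set. With shuffle=True, DataLoader draws a new random order at the start of every pass, while each pass still visits every sample exactly once. A minimal sketch on a 10-element TensorDataset standing in for CIFAR10; the seeded generator is only there to make reruns reproducible. When logging several passes with add_images, giving each epoch its own tag keeps them apart in TensorBoard.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Ten samples whose only field is their own index, so the order is easy to read
data = TensorDataset(torch.arange(10))
g = torch.Generator().manual_seed(0)  # seeded only so that reruns give the same shuffles
loader = DataLoader(data, batch_size=4, shuffle=True, drop_last=False, generator=g)

epoch_orders = []
for epoch in range(2):               # one full pass over the loader = one epoch
    order = []
    for (batch,) in loader:          # each batch is a list holding one stacked tensor
        order.extend(batch.tolist())
    epoch_orders.append(order)
    print(f"epoch {epoch}:", order)  # a permutation of 0..9, drawn anew for each epoch
```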