Converting Keras CNN image-classification code to PyTorch: ideas and implementation
2022-06-22 07:08:00 【zorchp】
tags: Python DL
Preface
A few days ago I converted a piece of code: a Python implementation of a convolutional neural network for an image-classification problem. The original was written against TensorFlow's Keras interface, and the task was to translate it into PyTorch. Since the two APIs are fairly close, the conversion is not difficult, but a few details deserve attention. I record them here for reference.
Importing library functions
First, let's look at how the two popular deep-learning frameworks differ in their imports, which requires a brief look at how each is organized. For brevity, TF below refers to TensorFlow 2.x with Keras, and Torch refers to PyTorch.
Model building
First, model construction. In TF, a model can be assembled conveniently with the Sequential API, which is imported as follows:
from tensorflow.keras.models import Sequential
In Torch, a model can of course also be built with Sequential (although the official documentation generally recommends the object-oriented approach of subclassing nn.Module). The import is:
from torch.nn import Sequential
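As a side note, the object-oriented style the official docs recommend looks roughly like this. This is a minimal sketch with made-up layer sizes, not code from the original article:

```python
import torch
from torch import nn

class SmallCNN(nn.Module):
    """Minimal CNN written in the recommended subclassing style."""
    def __init__(self, num_classes=4):
        super().__init__()
        self.conv = nn.Conv2d(3, 6, kernel_size=5)
        self.pool = nn.MaxPool2d(2)
        self.fc = nn.Linear(6 * 14 * 14, num_classes)

    def forward(self, x):
        x = self.pool(torch.tanh(self.conv(x)))
        return self.fc(x.flatten(start_dim=1))

model = SmallCNN()
out = model(torch.randn(2, 3, 32, 32))  # batch of two 32x32 RGB images
print(out.shape)  # torch.Size([2, 4])
```

The layer objects created in __init__ hold the learnable parameters, while forward defines how data flows through them.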
Speaking of model building, we should mention the layers commonly used in convolutional neural networks: convolution, max pooling, and fully connected layers (plus a softmax output). All of these are readily available in both frameworks. In TF:
from tensorflow.keras.layers import Conv2D, MaxPooling2D
from tensorflow.keras.layers import Activation, Dropout, Flatten, Dense
And in Torch:
from torch.nn import Conv2d, MaxPool2d, Tanh
from torch.nn import Flatten, Linear, CrossEntropyLoss
from torch.optim import SGD
As you can see, the two differ only slightly: TF passes activation functions as layer parameters, while Torch exposes them as separate modules.
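To make that difference concrete, here is a small sketch (names and sizes are illustrative): in Keras you would write Dense(4, activation='tanh'), whereas in Torch the activation is an explicit module such as nn.Tanh(), which computes the same function as torch.tanh:

```python
import torch
from torch import nn

# Torch style: the activation is its own layer in the pipeline
layer = nn.Sequential(nn.Linear(8, 4), nn.Tanh())

x = torch.randn(2, 8)
y = layer(x)

# nn.Tanh() is just the module form of torch.tanh
linear = layer[0]
assert torch.allclose(y, torch.tanh(linear(x)))
print(y.shape)  # torch.Size([2, 4])
```

Having activations as modules is what lets Torch's Sequential chain them like any other layer.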
Data loading
Finally, the data-loading section. In TF, image preprocessing and reading can be handled conveniently with:
from tensorflow.keras import backend
from tensorflow.keras.preprocessing.image import ImageDataGenerator
In Torch, the corresponding imports are:
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
Differences in the data loading / preprocessing APIs
For data loading, I still find Keras more convenient. Torch takes a modular approach: you first instantiate transform and dataset classes, then use those objects to process the images. First, the TF code for reading the image data:
# Import data
if backend.image_data_format() == 'channels_first':
    input_shape = (3, img_width, img_height)
else:
    input_shape = (img_width, img_height, 3)

# Training-set image augmentation
train_datagen = ImageDataGenerator(
    rescale=1. / 255,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True)

# Test-set preprocessing (rescaling only)
test_datagen = ImageDataGenerator(rescale=1. / 255)

train_generator = train_datagen.flow_from_directory(
    train_data_dir,
    target_size=(img_width, img_height),
    batch_size=batch_size,
    class_mode='categorical')  # multi-class

validation_generator = test_datagen.flow_from_directory(
    validation_data_dir,
    target_size=(img_width, img_height),
    batch_size=batch_size,
    class_mode='categorical')  # multi-class
Next, the Torch version:
# Import data
input_shape = (img_width, img_height, 3)

# Training-set image augmentation
train_datagen = transforms.Compose([
    transforms.ToTensor(),
    transforms.RandomHorizontalFlip(),
    transforms.Resize((img_width, img_height))
])

# Test-set preprocessing (rescaling only)
test_datagen = transforms.Compose([  # apply these operations to each loaded image
    transforms.ToTensor(),  # scales pixels to [0, 1], like Keras' rescale=1/255
    transforms.Resize((img_width, img_height))
])

train_generator = datasets.ImageFolder(train_data_dir,
                                       transform=train_datagen)
validation_generator = datasets.ImageFolder(validation_data_dir,
                                            transform=test_datagen)

train_loader = DataLoader(train_generator,
                          batch_size=batch_size,
                          shuffle=True)
test_loader = DataLoader(validation_generator,
                         batch_size=batch_size,
                         shuffle=False)
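ImageFolder expects an on-disk directory with one subfolder per class, so as a quick self-contained check of the DataLoader mechanics we can substitute a synthetic TensorDataset (the 3x32x32 shapes here are illustrative, not from the article):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Fake "images": 10 RGB tensors of 32x32, with integer class labels 0..3
images = torch.randn(10, 3, 32, 32)
labels = torch.randint(0, 4, (10,))
dataset = TensorDataset(images, labels)

loader = DataLoader(dataset, batch_size=4, shuffle=True)

# Each iteration yields a (batch_images, batch_labels) pair,
# exactly what ImageFolder + DataLoader produce for real image folders
batch_images, batch_labels = next(iter(loader))
print(batch_images.shape)  # torch.Size([4, 3, 32, 32])
print(batch_labels.shape)  # torch.Size([4])
```

This (batch, labels) pair is what the training loop below unpacks on every iteration.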
Differences in the model-building API
Now for the most important part: the model-building API calls. In TF, you can create a CNN classifier simply by calling model.add repeatedly, paying attention to how the tensor dimensions flow between layers. The code is concise and intuitive:
# Create the model
model = Sequential()
model.add(Conv2D(filters=6,
                 kernel_size=(5, 5),
                 padding='valid',
                 input_shape=input_shape,
                 activation='tanh'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Conv2D(filters=16,
                 kernel_size=(5, 5),
                 padding='valid',
                 activation='tanh'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Flatten())
model.add(Dense(120, activation='tanh'))
model.add(Dense(84, activation='tanh'))
model.add(Dense(4, activation='softmax'))

# Compile the model
model.compile(loss='categorical_crossentropy',
              optimizer='sgd',
              metrics=['accuracy'])
In Torch there is a similar approach, except that no compile step is needed; instead, the loss and optimizer are created as separate objects:
# Create the model
model = Sequential(
    Conv2d(in_channels=3,
           out_channels=6,
           kernel_size=(5, 5),
           padding='valid'),
    Tanh(),  # in Torch the activation is a separate layer
    MaxPool2d(kernel_size=(2, 2)),
    Conv2d(in_channels=6,
           out_channels=16,
           kernel_size=(5, 5),
           padding='valid'),
    Tanh(),
    MaxPool2d(kernel_size=(2, 2)),
    Flatten(),
    Linear(400, 120),
    Tanh(),
    Linear(120, 84),
    Tanh(),
    Linear(84, 4)  # no softmax: CrossEntropyLoss applies it internally
)

# Loss function: cross entropy
criterion = CrossEntropyLoss()
# Optimizer: stochastic gradient descent
optimizer = SGD(model.parameters(), lr=0.001)
There are still some API differences here, for instance in how the fully connected layer is written and parameterized, and small differences in the convolution layers. As before, pay close attention to the tensor dimensions.
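In particular, the 400 in Linear(400, 120) is the flattened feature size coming out of the second pooling layer. Assuming 32x32 input images (the classic LeNet size; the article does not state img_width/img_height), the arithmetic is 32 -> 28 -> 14 -> 10 -> 5, giving 16 * 5 * 5 = 400, which a dummy tensor passed through the convolutional part confirms:

```python
import torch
from torch.nn import Sequential, Conv2d, MaxPool2d, Flatten

# Convolutional front end of the model above (padding=0 is 'valid')
features = Sequential(
    Conv2d(3, 6, kernel_size=5, padding=0),   # 32x32 -> 28x28
    MaxPool2d(kernel_size=2),                 # 28x28 -> 14x14
    Conv2d(6, 16, kernel_size=5, padding=0),  # 14x14 -> 10x10
    MaxPool2d(kernel_size=2),                 # 10x10 -> 5x5
    Flatten(),
)

out = features(torch.randn(1, 3, 32, 32))
print(out.shape)  # torch.Size([1, 400]) -> matches Linear(400, 120)
```

This trick of probing with a dummy tensor is a handy way to find the right in_features for any input size.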
Differences in the model-training API
In TF, thanks to Keras' powerful and concise API, training the model is also very simple:
# Train the model
# (in recent TF versions, model.fit accepts generators directly
# and fit_generator is deprecated)
history = model.fit_generator(
    train_generator,
    steps_per_epoch=nb_train_samples // batch_size,
    epochs=epochs,
    validation_data=validation_generator,
    validation_steps=nb_validation_samples // batch_size)
In Torch, however, you build the training loop step by step, which is a little more verbose:
n_total_steps = len(train_loader)
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i + 1) % 5 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], '
                  f'Step [{i+1}/{n_total_steps}], '
                  f'Loss: {loss.item():.4f}')

torch.save(model.state_dict(), './ckpt')
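For completeness, the saved state_dict can later be restored for evaluation, something Keras handles via model.save / load_model. A minimal sketch (the './ckpt' path comes from the code above; the tiny model here is just for illustration):

```python
import torch
from torch.nn import Sequential, Linear

model = Sequential(Linear(4, 2))
torch.save(model.state_dict(), './ckpt')  # same call as in the training code

# Rebuild an identical architecture, then load the weights into it
restored = Sequential(Linear(4, 2))
restored.load_state_dict(torch.load('./ckpt'))

# The restored model now reproduces the original's outputs
x = torch.randn(3, 4)
print(torch.allclose(model(x), restored(x)))  # True
```

Note that state_dict saves only the weights, so you must reconstruct the same architecture before loading.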
Summary
Make good use of search engines: both frameworks describe their APIs in detail in the official documentation.