Comparison between PyTorch and PaddlePaddle: taking the implementation of a DCGAN network as an example
2022-07-23 12:02:00 【KHB1698】
This article compares PyTorch and PaddlePaddle, taking the generation of handwritten digits as its example.
References:
- DCGAN principle analysis and PyTorch implementation
- Detailed explanation of the DCGAN paper
- Converting between PaddlePaddle and PyTorch
I. Comparing PyTorch and Paddle
PaddlePaddle 2.0 and PyTorch are very similar in style. With PaddlePaddle you can directly use some of the resources on Baidu AI Studio (including GPUs, pre-trained weights, and so on), and its documentation and community are in Chinese, which is friendly for Chinese-speaking users; PyTorch, on the other hand, has far more code and resources on GitHub, so combining the two gives you the best of both. Below I have collected some corresponding functions between PaddlePaddle and PyTorch. Of course, the best approach is to know the correspondence and then look up the specific descriptions in the official PyTorch and PaddlePaddle documentation.
- Official comparison link: the PyTorch-PaddlePaddle API mapping table
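To give a concrete flavor of that correspondence, here are a few mappings used throughout this article (a partial, hand-collected list based on the official table; treat the official docs as authoritative, since APIs evolve):

# torch.nn.Module                    ->  paddle.nn.Layer
# torch.nn.Conv2d                    ->  paddle.nn.Conv2D
# torch.nn.ConvTranspose2d           ->  paddle.nn.Conv2DTranspose
# torch.nn.BatchNorm2d               ->  paddle.nn.BatchNorm2D
# torch.tensor(...)                  ->  paddle.to_tensor(...)
# torch.optim.Adam(params, lr=...)   ->  paddle.optimizer.Adam(learning_rate=..., parameters=...)
# bias=False (torch layer kwarg)     ->  bias_attr=False (paddle layer kwarg)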
II. DCGAN Principle Analysis
1. What is a Generative Adversarial Network?
A generative adversarial network (GAN) consists of a generator and a discriminator, two models trained simultaneously through an adversarial process.
The generator can be understood as an "artist" or "creator": it learns to create images that look real.
The discriminator can be understood as an "art critic" or "reviewer": it learns to distinguish real images from fake ones.
During training, the generator gets progressively better at producing realistic images, while the discriminator grows stronger at telling them apart.
The training process reaches equilibrium when the discriminator can no longer distinguish real images from forged ones.
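Formally, this adversarial game is the minimax objective from the original GAN paper (Goodfellow et al., 2014), with the discriminator D maximizing and the generator G minimizing

    V(D, G) = E[x ~ p_data] log D(x) + E[z ~ p_z] log(1 - D(G(z)))

The two update steps in the training code later in this article are stochastic approximations of exactly this objective.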
2. DCGAN Network Architecture
DCGAN mainly improves on the original GAN architecture: both the generator and the discriminator replace the original GAN's fully connected networks with CNN architectures. The main improvements are as follows (a small shape sketch follows this list):
- DCGAN's generator and discriminator discard the CNN pooling layers. The discriminator keeps the overall CNN structure, while the generator replaces ordinary convolution layers with fractionally-strided convolutions, i.e. transposed convolutions (ConvTranspose).
- A Batch Normalization (BN) layer is used after each layer in both the discriminator and the generator, which helps with training problems caused by poor initialization, speeds up model training, and improves training stability.
- All fully connected layers are replaced with 1×1 convolution layers.
- In the generator, the output layer uses the Tanh activation function and all other layers use ReLU.
- All layers of the discriminator use the LeakyReLU activation function, which prevents sparse gradients.
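As a minimal sketch (mine, not from the original article) of the shape arithmetic behind those fractionally-strided convolutions, assuming PyTorch's ConvTranspose2d:

import torch
import torch.nn as nn

# Output size of ConvTranspose2d (dilation=1, no output_padding):
#   out = (in - 1) * stride - 2 * padding + kernel_size
# With kernel=4, stride=2, padding=1 each layer exactly doubles the feature map:
#   (8 - 1) * 2 - 2 * 1 + 4 = 16
up = nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, bias=False)
x = torch.randn(1, 128, 8, 8)
print(up(x).shape)  # torch.Size([1, 64, 16, 16])

This is why the generator below can start from a 1×1 latent "image", expand to 4×4, and then double three times to reach 32×32.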
With the above improvements, the DCGAN generator structure is as follows:
[Figure: DCGAN generator architecture]
III. DCGAN Handwritten Digit Generation

IV. Implementing DCGAN in Paddle
This article only links to the Paddle version (the PyTorch code below is my hand-typed port of that Paddle version). I suggest viewing the PyTorch and Paddle versions side by side, starting from the imports and checking them one by one. The differences between PyTorch and Paddle are in fact very small and the corresponding functions are very similar; I hope the comparison proves instructive.
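For readers who do not want to open the link right away, here is a hedged sketch of what the Paddle generator looks like (written against the Paddle 2.x API; the linked notebook is the authoritative version). Note the renamed classes and the bias_attr keyword:

import paddle
import paddle.nn as nn

class PaddleGenerator(nn.Layer):  # paddle.nn.Layer plays the role of torch.nn.Module
    def __init__(self):
        super().__init__()
        self.gen = nn.Sequential(
            # Conv2DTranspose vs torch's ConvTranspose2d; bias_attr=False vs bias=False
            nn.Conv2DTranspose(100, 64 * 4, 4, 1, 0, bias_attr=False),
            nn.BatchNorm2D(64 * 4),
            nn.ReLU(),
            nn.Conv2DTranspose(64 * 4, 64 * 2, 4, 2, 1, bias_attr=False),
            nn.BatchNorm2D(64 * 2),
            nn.ReLU(),
            nn.Conv2DTranspose(64 * 2, 64, 4, 2, 1, bias_attr=False),
            nn.BatchNorm2D(64),
            nn.ReLU(),
            nn.Conv2DTranspose(64, 1, 4, 2, 1, bias_attr=False),
            nn.Tanh()
        )

    def forward(self, x):
        return self.gen(x)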
V. Implementing DCGAN in PyTorch
import os
import random
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets
from torch.utils.data import DataLoader
from torchvision import transforms
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
# Define the dataset
dataset = datasets.MNIST(root='dataset/mnist/', train=True, download=True,
                         transform=transforms.Compose([
                             # resize -> (32, 32)
                             transforms.Resize((32, 32)),
                             # Convert the PIL image to a tensor of shape (C, H, W) scaled to [0, 1]
                             transforms.ToTensor(),
                             # Normalize from [0, 1] to [-1, 1]. Note: a mean/std of 127.5 is only
                             # correct for raw 0-255 pixel values; after torchvision's ToTensor the
                             # right value is 0.5
                             transforms.Normalize([0.5], [0.5])
                         ]))
dataloader = DataLoader(dataset, shuffle=True, batch_size=32, num_workers=0)
# Take a look at the shape of the input images
for data in dataloader:
    break
print(data[0].shape)  # torch.Size([32, 1, 32, 32])
# Parameter initialization module; note that this differs from the Paddle version
def weights_init(m):
    classname = m.__class__.__name__
    if hasattr(m, 'weight') and classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)  # nn.init.constant_() sets the bias to the constant 0
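For comparison, Paddle usually declares initialization up front when a layer is constructed rather than applying a function afterwards. A hedged sketch, assuming the Paddle 2.x initializer API:

import paddle
import paddle.nn as pnn

# The same N(mean, std=0.02) scheme, attached per layer via weight_attr
conv_w = paddle.ParamAttr(initializer=paddle.nn.initializer.Normal(0.0, 0.02))
bn_w = paddle.ParamAttr(initializer=paddle.nn.initializer.Normal(1.0, 0.02))
layer = pnn.Conv2DTranspose(100, 64 * 4, 4, 1, 0, weight_attr=conv_w, bias_attr=False)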
# Generator code
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.gen = nn.Sequential(
            # input is Z, [B, 100, 1, 1] -> [B, 64 * 4, 4, 4]
            # Note: the layer name differs from Paddle's (ConvTranspose2d vs Conv2DTranspose),
            # as does the bias argument (bias vs bias_attr)
            nn.ConvTranspose2d(100, 64 * 4, 4, 1, 0, bias=False),
            nn.BatchNorm2d(64 * 4),
            nn.ReLU(True),
            # state size. [B, 64 * 4, 4, 4] -> [B, 64 * 2, 8, 8]
            nn.ConvTranspose2d(64 * 4, 64 * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64 * 2),
            nn.ReLU(True),
            # state size. [B, 64 * 2, 8, 8] -> [B, 64, 16, 16]
            nn.ConvTranspose2d(64 * 2, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            # state size. [B, 64, 16, 16] -> [B, 1, 32, 32]
            nn.ConvTranspose2d(64, 1, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, x):
        return self.gen(x)

netG = Generator()
netG.apply(weights_init)
# Print the model
print(netG)
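A quick sanity check (illustrative, not part of the original listing): a batch of latent vectors should come out as 32x32 single-channel images.

z = torch.randn(4, 100, 1, 1)
print(netG(z).shape)  # expected: torch.Size([4, 1, 32, 32])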
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.dis = nn.Sequential(
            # input [B, 1, 32, 32] -> [B, 64, 16, 16]
            nn.Conv2d(1, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2),
            # state size. [B, 64, 16, 16] -> [B, 128, 8, 8]
            nn.Conv2d(64, 64 * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64 * 2),
            nn.LeakyReLU(0.2),
            # state size. [B, 128, 8, 8] -> [B, 256, 4, 4]
            nn.Conv2d(64 * 2, 64 * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64 * 4),
            nn.LeakyReLU(0.2),
            # state size. [B, 256, 4, 4] -> [B, 1, 1, 1]
            nn.Conv2d(64 * 4, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.dis(x)

netD = Discriminator()
netD.apply(weights_init)
print(netD)
# Initialize the BCELoss function (binary cross-entropy loss)
loss = nn.BCELoss()
# Create batch of latent vectors that we will use to visualize
# the progression of the generator
fixed_noise = torch.randn([32, 100, 1, 1])
# Establish convention for real and fake labels during training
real_label = 1.
fake_label = 0.
# Setup Adam optimizers for both G and D
optimizerD = optim.Adam(netD.parameters(), lr=0.0002, betas=(0.5,0.999))
optimizerG = optim.Adam(netG.parameters(), lr=0.0002, betas=(0.5,0.999))
losses = [[], []]
# plt.ion()
for pass_id in range(100):
    for batch_id, (data, target) in enumerate(dataloader):
        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        optimizerD.zero_grad()
        real_img = data
        bs_size = real_img.shape[0]
        label = torch.full((bs_size, 1, 1, 1), real_label)
        real_out = netD(real_img)
        errD_real = loss(real_out, label)
        errD_real.backward()

        noise = torch.randn([bs_size, 100, 1, 1])
        fake_img = netG(noise)
        label = torch.full((bs_size, 1, 1, 1), fake_label)
        fake_out = netD(fake_img.detach())
        errD_fake = loss(fake_out, label)
        errD_fake.backward()
        optimizerD.step()
        optimizerD.zero_grad()

        errD = errD_real + errD_fake
        losses[0].append(errD.detach().numpy())

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        optimizerG.zero_grad()
        noise = torch.randn([bs_size, 100, 1, 1])
        fake = netG(noise)
        label = torch.full((bs_size, 1, 1, 1), real_label)
        output = netD(fake)
        errG = loss(output, label)
        errG.backward()
        optimizerG.step()
        optimizerG.zero_grad()
        losses[1].append(errG.detach().numpy())

        ############################
        # visualize
        ###########################
        if batch_id % 100 == 0:
            generated_image = netG(noise).detach().numpy()
            plt.figure(figsize=(15, 15))
            try:
                for i in range(10):
                    image = generated_image[i].transpose()   # (1, 32, 32) -> (32, 32, 1)
                    image = np.where(image > 0, image, 0)    # zero out negative pixel values
                    image = image.transpose((1, 0, 2))
                    plt.subplot(10, 10, i + 1)
                    plt.imshow(image[..., 0], vmin=-1, vmax=1)
                    plt.axis('off')
                    plt.xticks([])
                    plt.yticks([])
                    plt.subplots_adjust(wspace=0.1, hspace=0.1)
                msg = 'Epoch ID={0} Batch ID={1} \n\n D-Loss={2} G-Loss={3}'.format(
                    pass_id, batch_id, errD.detach().numpy(), errG.detach().numpy())
                print(msg)
                plt.suptitle(msg, fontsize=20)
                plt.draw()
                # plt.savefig('{}/{:04d}_{:04d}.png'.format('work', pass_id, batch_id), bbox_inches='tight')
                plt.pause(0.01)
            except IOError as e:
                print(e)

# Save the trained generator (torch.save here; the Paddle version uses paddle.save)
torch.save(netG.state_dict(), "generator.pth")
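To use the saved weights later, the usual PyTorch pattern applies (an illustrative sketch, not part of the original listing):

# Reload the trained generator and sample a fresh batch of digits
netG2 = Generator()
netG2.load_state_dict(torch.load("generator.pth"))
netG2.eval()
with torch.no_grad():
    samples = netG2(torch.randn(16, 100, 1, 1))  # [16, 1, 32, 32] images in [-1, 1]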