
Comparison between PyTorch and PaddlePaddle: Taking the DCGAN Network Implementation as an Example

2022-07-23 12:02:00 KHB1698


This article compares PyTorch and PaddlePaddle, using the generation of handwritten digits with DCGAN as a running example.


I. Comparison between PyTorch and Paddle

PaddlePaddle 2.0 is very similar in style to PyTorch. With PaddlePaddle you can directly use some of the resources on Baidu AI Studio (including GPUs, pre-trained weights, and so on), and its documentation and community are in Chinese, which is friendly for Chinese speakers; PyTorch, on the other hand, has far more code and resources on GitHub, so combining the two works well. Below I have collected some corresponding functions in PaddlePaddle and PyTorch. Of course, the best practice is to know the correspondence first, then look up the detailed descriptions in the official PyTorch and PaddlePaddle documentation.
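For quick orientation, here is a partial mapping between the two APIs as used in the DCGAN code later in this article. It is a sketch based on the PyTorch and Paddle 2.x APIs, not an exhaustive table; check the official documentation for the exact signatures.

# A partial PyTorch -> PaddlePaddle 2.x correspondence (sketch, not exhaustive):
#   torch.nn.Module                      -> paddle.nn.Layer
#   torch.nn.Conv2d(..., bias=False)     -> paddle.nn.Conv2D(..., bias_attr=False)
#   torch.nn.ConvTranspose2d             -> paddle.nn.Conv2DTranspose
#   torch.nn.BatchNorm2d                 -> paddle.nn.BatchNorm2D
#   torch.optim.Adam(params, lr=..., betas=(b1, b2))
#                                        -> paddle.optimizer.Adam(parameters=..., learning_rate=..., beta1=b1, beta2=b2)
#   torch.randn([...]) / torch.full(...) -> paddle.randn([...]) / paddle.full(...)
#   torch.save / torch.load              -> paddle.save / paddle.load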

II. DCGAN principle analysis

1. What is a generative adversarial network?

A generative adversarial network (GAN) consists of a generator and a discriminator; the two models are trained simultaneously through an adversarial process.

The generator can be thought of as an "artist" or "creator": it learns to produce images that look real.

The discriminator can be thought of as an "art critic" or "reviewer": it learns to tell real images from fake ones.

During training, the generator becomes better and better at producing realistic images, while the discriminator becomes stronger at telling them apart.

When the discriminator can no longer distinguish real images from forged ones, the training process has reached equilibrium.
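For reference, this adversarial process is usually written as the minimax objective from the original GAN paper; the training-loop comments in the code later in this article ("maximize log(D(x)) + log(1 - D(G(z)))" and "maximize log(D(G(z)))") are the two halves of this objective:

$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\mathrm{data}}(x)}\big[\log D(x)\big] + \mathbb{E}_{z \sim p_z(z)}\big[\log\big(1 - D(G(z))\big)\big]$$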

2. DCGAN network architecture

DCGAN is mainly an improvement on the original GAN architecture: both the generator and the discriminator use a CNN architecture in place of the fully connected networks of the original GAN. The main improvements are as follows:

  1. The generator and discriminator of DCGAN discard the CNN pooling layers. The discriminator keeps the overall CNN structure, while the generator replaces ordinary convolution layers with fractionally-strided convolutions, i.e. transposed convolutions (see the shape sketch after this list).
  2. A Batch Normalization (BN) layer is used after each layer in both the discriminator and the generator, which helps with training problems caused by poor initialization, speeds up model training, and improves training stability.
  3. All fully connected layers are replaced with 1×1 convolution layers.
  4. In the generator, the output layer uses the Tanh activation function, and all other layers use ReLU.
  5. The LeakyReLU activation function is used in all layers of the discriminator to prevent sparse gradients.

With the above improvements, the DCGAN generator takes the structure implemented in the code below.
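To make the upsampling concrete, the following standalone sketch checks the first two shape transitions of the generator defined later. For a transposed convolution with no output padding, the output size is (in - 1) * stride - 2 * padding + kernel_size.

import torch
import torch.nn as nn

z = torch.randn(8, 100, 1, 1)  # a batch of 8 latent vectors

# [B, 100, 1, 1] -> [B, 256, 4, 4]: (1 - 1) * 1 - 2 * 0 + 4 = 4
up1 = nn.ConvTranspose2d(100, 256, kernel_size=4, stride=1, padding=0, bias=False)
print(up1(z).shape)  # torch.Size([8, 256, 4, 4])

# [B, 256, 4, 4] -> [B, 128, 8, 8]: (4 - 1) * 2 - 2 * 1 + 4 = 8
up2 = nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=False)
print(up2(up1(z)).shape)  # torch.Size([8, 128, 8, 8])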

III. DCGAN handwritten digit generation

IV. Implementing DCGAN with Paddle

This article only provides a link to the Paddle version (I wrote the PyTorch code below by hand against the corresponding Paddle version). I suggest viewing the PyTorch and Paddle versions side by side, starting from the imported packages and checking them one by one. In fact, the differences between PyTorch and Paddle are very small and the corresponding functions are very similar; I hope the comparison is instructive.
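Since the Paddle code is only linked rather than reproduced here, the following is a minimal sketch of what the same generator looks like with the Paddle 2.x API; note nn.Layer, Conv2DTranspose, BatchNorm2D and bias_attr in place of their PyTorch counterparts. Treat it as illustrative, not as the exact code from the linked notebook.

import paddle
import paddle.nn as nn

class Generator(nn.Layer):
    def __init__(self):
        super(Generator, self).__init__()
        self.gen = nn.Sequential(
            # input Z: [B, 100, 1, 1] -> [B, 256, 4, 4]
            nn.Conv2DTranspose(100, 256, 4, 1, 0, bias_attr=False),
            nn.BatchNorm2D(256),
            nn.ReLU(),
            # [B, 256, 4, 4] -> [B, 128, 8, 8]
            nn.Conv2DTranspose(256, 128, 4, 2, 1, bias_attr=False),
            nn.BatchNorm2D(128),
            nn.ReLU(),
            # [B, 128, 8, 8] -> [B, 64, 16, 16]
            nn.Conv2DTranspose(128, 64, 4, 2, 1, bias_attr=False),
            nn.BatchNorm2D(64),
            nn.ReLU(),
            # [B, 64, 16, 16] -> [B, 1, 32, 32]
            nn.Conv2DTranspose(64, 1, 4, 2, 1, bias_attr=False),
            nn.Tanh()
        )

    def forward(self, x):
        return self.gen(x)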

V. Implementing DCGAN with PyTorch

import os
import random
import torch 
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets
from torch.utils.data import DataLoader
from torchvision import transforms 
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation

# Define the dataset
dataset = datasets.MNIST(root='dataset/mnist/', train=True, download=True,
                            transform=transforms.Compose([
                                # resize -> (32, 32)
                                transforms.Resize((32, 32)),
                                # convert the PIL image to a tensor of shape (C, H, W), scaled to [0, 1]
                                transforms.ToTensor(),
                                # normalize from [0, 1] to [-1, 1]; ToTensor already rescales to [0, 1],
                                # so the mean/std here must be 0.5, not 127.5
                                transforms.Normalize([0.5], [0.5])
                            ]))

dataloader = DataLoader(dataset, shuffle=True, batch_size=32, num_workers=0)


# Check the shape of one batch of input images
data = next(iter(dataloader))
print(data[0].shape)  # expected: torch.Size([32, 1, 32, 32])

# Parameter initialization module (note: this differs from the Paddle version)
def weights_init(m):
    classname = m.__class__.__name__
    if hasattr(m, 'weight') and classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)  # nn.init.constant_() sets the bias to the constant 0
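For comparison, Paddle typically specifies per-layer initializers declaratively when the layer is constructed, rather than applying a function afterwards. A minimal sketch of the same normal(0, 0.02) initialization in Paddle 2.x:

# In Paddle, the initializer is passed when the layer is built (sketch):
import paddle
conv = paddle.nn.Conv2DTranspose(
    100, 256, 4, 1, 0,
    weight_attr=paddle.ParamAttr(
        initializer=paddle.nn.initializer.Normal(mean=0.0, std=0.02)),
    bias_attr=False)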
# Generator Code
class Generator(nn.Module):
    def __init__(self, ):
        super(Generator, self).__init__()
        self.gen = nn.Sequential(
            # input is Z, [B, 100, 1, 1] -> [B, 64 * 4, 4, 4]
            nn.ConvTranspose2d(100, 64 * 4, 4, 1, 0, bias=False),  # note: the layer name and bias argument differ from Paddle (Conv2DTranspose / bias_attr)
            nn.BatchNorm2d(64 * 4),
            nn.ReLU(True),
            # state size. [B, 64 * 4, 4, 4] -> [B, 64 * 2, 8, 8]
            nn.ConvTranspose2d(64 * 4, 64 * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64 * 2),
            nn.ReLU(True),
            # state size. [B, 64 * 2, 8, 8] -> [B, 64, 16, 16]
            nn.ConvTranspose2d( 64 * 2, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            # state size. [B, 64, 16, 16] -> [B, 1, 32, 32]
            nn.ConvTranspose2d( 64, 1, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, x):
        return self.gen(x)


netG = Generator()

netG.apply(weights_init)

# Print the model
print(netG)
class Discriminator(nn.Module):
    def __init__(self,):
        super(Discriminator, self).__init__()
        self.dis = nn.Sequential(

            # input [B, 1, 32, 32] -> [B, 64, 16, 16]
            nn.Conv2d(1, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2),

            # state size. [B, 64, 16, 16] -> [B, 128, 8, 8]
            nn.Conv2d(64, 64 * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64 * 2),
            nn.LeakyReLU(0.2),

            # state size. [B, 128, 8, 8] -> [B, 256, 4, 4]
            nn.Conv2d(64 * 2, 64 * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64 * 4),
            nn.LeakyReLU(0.2),

            # state size. [B, 256, 4, 4] -> [B, 1, 1, 1] -> [B, 1]
            nn.Conv2d(64 * 4, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.dis(x)

netD = Discriminator()
netD.apply(weights_init)
print(netD)
# Initialize BCELoss function
loss = nn.BCELoss()  # binary cross-entropy loss

# Create batch of latent vectors that we will use to visualize
# the progression of the generator
fixed_noise = torch.randn([32, 100, 1, 1])

# Establish convention for real and fake labels during training
real_label = 1.
fake_label = 0.

# Setup Adam optimizers for both G and D
optimizerD = optim.Adam(netD.parameters(), lr=0.0002, betas=(0.5,0.999))
optimizerG = optim.Adam(netG.parameters(), lr=0.0002,  betas=(0.5,0.999))
losses = [[], []]
#plt.ion()
now = 0
for pass_id in range(100):
    for batch_id, (data, target) in enumerate(dataloader):
        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################

        optimizerD.zero_grad()
        real_img = data
        bs_size = real_img.shape[0]
        label = torch.full((bs_size, 1, 1, 1), real_label)
        real_out = netD(real_img)
        errD_real = loss(real_out, label)
        errD_real.backward()

        noise = torch.randn([bs_size, 100, 1, 1])
        fake_img = netG(noise)
        label = torch.full((bs_size, 1, 1, 1), fake_label)
        fake_out = netD(fake_img.detach())
        errD_fake = loss(fake_out,label)
        errD_fake.backward()
        optimizerD.step()
        optimizerD.zero_grad()

        errD = errD_real + errD_fake
        losses[0].append(errD.detach().numpy())

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        optimizerG.zero_grad()
        noise = torch.randn([bs_size, 100, 1, 1])
        fake = netG(noise)
        label = torch.full((bs_size, 1, 1, 1), real_label)
        output = netD(fake)
        errG = loss(output,label)
        errG.backward()
        optimizerG.step()
        optimizerG.zero_grad()

        losses[1].append(errG.detach().numpy())


        ############################
        # visualize
        ###########################
        if batch_id % 100 == 0:
            generated_image = netG(noise).detach().numpy()
            imgs = []
            plt.figure(figsize=(15,15))
            try:
                for i in range(10):
                    image = generated_image[i].transpose((1, 2, 0))  # (C, H, W) -> (H, W, C)
                    image = np.where(image > 0, image, 0)            # clip negative values to 0
                    plt.subplot(10, 10, i + 1)
                    
                    plt.imshow(image[...,0], vmin=-1, vmax=1)
                    plt.axis('off')
                    plt.xticks([])
                    plt.yticks([])
                    plt.subplots_adjust(wspace=0.1, hspace=0.1)
                msg = 'Epoch ID={0} Batch ID={1} \n\n D-Loss={2} G-Loss={3}'.format(pass_id, batch_id, errD.detach().numpy(), errG.detach().numpy())
                print(msg)
                plt.suptitle(msg,fontsize=20)
                plt.draw()
                # plt.savefig('{}/{:04d}_{:04d}.png'.format('work', pass_id, batch_id), bbox_inches='tight')
                plt.pause(0.01)
            except IOError as e:
                print(e)
    torch.save(netG.state_dict(), "generator.pth")  # torch.save, not paddle.save, in the PyTorch version
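After training, the saved weights can be reloaded to sample new digits. A minimal sketch, reusing the Generator class and imports from above and assuming the generator.pth file produced by the loop:

# Reload the trained generator and sample a 4x4 grid of digits
netG = Generator()
netG.load_state_dict(torch.load("generator.pth"))
netG.eval()

with torch.no_grad():
    samples = netG(torch.randn(16, 100, 1, 1))  # [16, 1, 32, 32], values in [-1, 1]

plt.figure(figsize=(8, 8))
for i in range(16):
    plt.subplot(4, 4, i + 1)
    plt.imshow(samples[i, 0].numpy(), cmap='gray', vmin=-1, vmax=1)
    plt.axis('off')
plt.show()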