Pytorch学习笔记DCGAN极简入门教程

作者:xz1308579340 时间:2022-05-28 17:29:02 

1.图片分类网络

这是一个二分类网络,可以是alxnet ,vgg,resnet任何一个,负责对图片进行二分类,区分图片是真实图片还是生成的图片

2.图片生成网络

输入是一个随机噪声,输出是一张图片,使用的是反卷积层

相信学过深度学习的都能写出这两个网络,当然如果你写不出来,没关系,有人替你写好了

首先是图片分类网络:

简单来说就是cnn+relu+sogmid,可以换成任何一个分类网络,比如bgg,resnet等


class Discriminator(nn.Module):
   def __init__(self, ngpu):
       super(Discriminator, self).__init__()
       self.ngpu = ngpu
       self.main = nn.Sequential(
           # input is (nc) x 64 x 64
           nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
           nn.LeakyReLU(0.2, inplace=True),
           # state size. (ndf) x 32 x 32
           nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
           nn.BatchNorm2d(ndf * 2),
           nn.LeakyReLU(0.2, inplace=True),
           # state size. (ndf*2) x 16 x 16
           nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
           nn.BatchNorm2d(ndf * 4),
           nn.LeakyReLU(0.2, inplace=True),
           # state size. (ndf*4) x 8 x 8
           nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
           nn.BatchNorm2d(ndf * 8),
           nn.LeakyReLU(0.2, inplace=True),
           # state size. (ndf*8) x 4 x 4
           nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
           nn.Sigmoid()
       )
   def forward(self, input):
       return self.main(input)

重点是生成网络

代码如下,其实就是反卷积+bn+relu


class Generator(nn.Module):
   def __init__(self, ngpu):
       super(Generator, self).__init__()
       self.ngpu = ngpu
       self.main = nn.Sequential(
           # input is Z, going into a convolution
           nn.ConvTranspose2d( nz, ngf * 8, 4, 1, 0, bias=False),
           nn.BatchNorm2d(ngf * 8),
           nn.ReLU(True),
           # state size. (ngf*8) x 4 x 4
           nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
           nn.BatchNorm2d(ngf * 4),
           nn.ReLU(True),
           # state size. (ngf*4) x 8 x 8
           nn.ConvTranspose2d( ngf * 4, ngf * 2, 4, 2, 1, bias=False),
           nn.BatchNorm2d(ngf * 2),
           nn.ReLU(True),
           # state size. (ngf*2) x 16 x 16
           nn.ConvTranspose2d( ngf * 2, ngf, 4, 2, 1, bias=False),
           nn.BatchNorm2d(ngf),
           nn.ReLU(True),
           # state size. (ngf) x 32 x 32
           nn.ConvTranspose2d( ngf, nc, 4, 2, 1, bias=False),
           nn.Tanh()
           # state size. (nc) x 64 x 64
       )
   def forward(self, input):
       return self.main(input)

讲道理,以上两个网络都挺简单。

真正的重点到了,怎么训练

每一个step分为三个步骤:

  • 训练二分类网络
       1.输入真实图片,经过二分类,希望判定为真实图片,更新二分类网络
       2.输入噪声,进过生成网络,生成一张图片,输入二分类网络,希望判定为虚假图片,更新二分类网络

  • 训练生成网络
       3.输入噪声,进过生成网络,生成一张图片,输入二分类网络,希望判定为真实图片,更新生成网络

不多说直接上代码


for epoch in range(num_epochs):
   # For each batch in the dataloader
   for i, data in enumerate(dataloader, 0):
       ############################
       # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
       ###########################
       ## Train with all-real batch
       netD.zero_grad()
       # Format batch
       real_cpu = data[0].to(device)
       b_size = real_cpu.size(0)
       label = torch.full((b_size,), real_label, device=device)
       # Forward pass real batch through D
       output = netD(real_cpu).view(-1)
       # Calculate loss on all-real batch
       errD_real = criterion(output, label)
       # Calculate gradients for D in backward pass
       errD_real.backward()
       D_x = output.mean().item()
       ## Train with all-fake batch
       # Generate batch of latent vectors
       noise = torch.randn(b_size, nz, 1, 1, device=device)
       # Generate fake image batch with G
       fake = netG(noise)
       label.fill_(fake_label)
       # Classify all fake batch with D
       output = netD(fake.detach()).view(-1)
       # Calculate D's loss on the all-fake batch
       errD_fake = criterion(output, label)
       # Calculate the gradients for this batch
       errD_fake.backward()
       D_G_z1 = output.mean().item()
       # Add the gradients from the all-real and all-fake batches
       errD = errD_real + errD_fake
       # Update D
       optimizerD.step()
       ############################
       # (2) Update G network: maximize log(D(G(z)))
       ###########################
       netG.zero_grad()
       label.fill_(real_label)  # fake labels are real for generator cost
       # Since we just updated D, perform another forward pass of all-fake batch through D
       output = netD(fake).view(-1)
       # Calculate G's loss based on this output
       errG = criterion(output, label)
       # Calculate gradients for G
       errG.backward()
       D_G_z2 = output.mean().item()
       # Update G
       optimizerG.step()
       # Output training stats
       if i % 50 == 0:
           print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                 % (epoch, num_epochs, i, len(dataloader),
                    errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
       # Save Losses for plotting later
       G_losses.append(errG.item())
       D_losses.append(errD.item())
       # Check how the generator is doing by saving G's output on fixed_noise
       if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
           with torch.no_grad():
               fake = netG(fixed_noise).detach().cpu()
           img_list.append(vutils.make_grid(fake, padding=2, normalize=True))
       iters += 1

来源:https://blog.csdn.net/xz1308579340/article/details/105883090

标签:Pytorch,DCGAN
0
投稿

猜你喜欢

  • Python3逻辑运算符与成员运算符

    2021-03-29 18:59:14
  • 解决Dreamweaver不支持中文文件名方法

    2008-01-09 12:52:00
  • python实现rsa加密实例详解

    2021-08-24 03:32:51
  • ASP长文章分页代码实例

    2007-10-02 17:04:00
  • PyCharm创建Django项目的简单步骤记录

    2023-08-28 11:03:37
  • FSO组件之文件操作(中)

    2010-05-03 11:05:00
  • Mysql 自动增加设定基值的语句

    2024-01-21 09:17:18
  • 把pandas转换int型为str型的方法

    2022-02-16 15:45:03
  • mysql 5.7.14 安装配置简单教程

    2024-01-13 04:41:48
  • Python API自动化框架总结

    2022-08-25 15:37:44
  • Go学习笔记之Zap日志的使用

    2023-09-19 01:21:36
  • python 爬取京东指定商品评论并进行情感分析

    2021-03-02 19:56:53
  • 数据库Oracle数据的异地的自动备份

    2010-07-27 13:28:00
  • Python实现爬虫设置代理IP和伪装成浏览器的方法分享

    2021-05-26 19:42:29
  • Flask项目的部署的实现步骤

    2023-08-11 17:59:58
  • Python数据类型之List列表实例详解

    2021-01-15 17:06:45
  • 各个系统下的Python解释器相关安装方法

    2022-10-09 00:24:34
  • 一篇文章教你掌握python数据类型的底层实现

    2023-06-01 03:41:14
  • 如何把IP表存到SQL数据库里去?

    2009-11-02 20:21:00
  • python docx的超链接网址和链接文本操作

    2021-06-05 22:10:55
  • asp之家 网络编程 m.aspxhome.com