Pytorch实现WGAN用于动漫头像生成

作者:不佛 时间:2023-07-24 22:31:26 

WGAN与GAN的不同

  • 去除sigmoid

  • 使用具有动量的优化方法,比如使用RMSProp

  • 要对Discriminator的权重做修整限制以确保lipschitz连续约

WGAN实战卷积生成动漫头像 


import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image
import os
from anime_face_generator.dataset import ImageDataset

batch_size = 32
num_epoch = 100
z_dimension = 100
dir_path = './wgan_img'

# 创建文件夹
if not os.path.exists(dir_path):
 os.mkdir(dir_path)

def to_img(x):
 """因为我们在生成器里面用了tanh"""
 out = 0.5 * (x + 1)
 return out

dataset = ImageDataset()
dataloader = DataLoader(dataset, batch_size=32, shuffle=False)

class Generator(nn.Module):
 def __init__(self):
   super().__init__()

self.gen = nn.Sequential(
     # 输入是一个nz维度的噪声,我们可以认为它是一个1*1*nz的feature map
     nn.ConvTranspose2d(100, 512, 4, 1, 0, bias=False),
     nn.BatchNorm2d(512),
     nn.ReLU(True),
     # 上一步的输出形状:(512) x 4 x 4
     nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
     nn.BatchNorm2d(256),
     nn.ReLU(True),
     # 上一步的输出形状: (256) x 8 x 8
     nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
     nn.BatchNorm2d(128),
     nn.ReLU(True),
     # 上一步的输出形状: (256) x 16 x 16
     nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
     nn.BatchNorm2d(64),
     nn.ReLU(True),
     # 上一步的输出形状:(256) x 32 x 32
     nn.ConvTranspose2d(64, 3, 5, 3, 1, bias=False),
     nn.Tanh() # 输出范围 -1~1 故而采用Tanh
     # nn.Sigmoid()
     # 输出形状:3 x 96 x 96
   )

def forward(self, x):
   x = self.gen(x)
   return x

def weight_init(m):
   # weight_initialization: important for wgan
   class_name = m.__class__.__name__
   if class_name.find('Conv') != -1:
     m.weight.data.normal_(0, 0.02)
   elif class_name.find('Norm') != -1:
     m.weight.data.normal_(1.0, 0.02)

class Discriminator(nn.Module):
 def __init__(self):
   super().__init__()
   self.dis = nn.Sequential(
     nn.Conv2d(3, 64, 5, 3, 1, bias=False),
     nn.LeakyReLU(0.2, inplace=True),
     # 输出 (64) x 32 x 32

nn.Conv2d(64, 128, 4, 2, 1, bias=False),
     nn.BatchNorm2d(128),
     nn.LeakyReLU(0.2, inplace=True),
     # 输出 (128) x 16 x 16

nn.Conv2d(128, 256, 4, 2, 1, bias=False),
     nn.BatchNorm2d(256),
     nn.LeakyReLU(0.2, inplace=True),
     # 输出 (256) x 8 x 8

nn.Conv2d(256, 512, 4, 2, 1, bias=False),
     nn.BatchNorm2d(512),
     nn.LeakyReLU(0.2, inplace=True),
     # 输出 (512) x 4 x 4

nn.Conv2d(512, 1, 4, 1, 0, bias=False),
     nn.Flatten(),
     # nn.Sigmoid() # 输出一个数(概率)
   )

def forward(self, x):
   x = self.dis(x)
   return x

def weight_init(m):
   # weight_initialization: important for wgan
   class_name = m.__class__.__name__
   if class_name.find('Conv') != -1:
     m.weight.data.normal_(0, 0.02)
   elif class_name.find('Norm') != -1:
     m.weight.data.normal_(1.0, 0.02)

def save(model, filename="model.pt", out_dir="out/"):
 if model is not None:
   if not os.path.exists(out_dir):
     os.mkdir(out_dir)
   torch.save({'model': model.state_dict()}, out_dir + filename)
 else:
   print("[ERROR]:Please build a model!!!")

import QuickModelBuilder as builder

if __name__ == '__main__':
 one = torch.FloatTensor([1]).cuda()
 mone = -1 * one

is_print = True
 # 创建对象
 D = Discriminator()
 G = Generator()
 D.weight_init()
 G.weight_init()

if torch.cuda.is_available():
   D = D.cuda()
   G = G.cuda()

lr = 2e-4
 d_optimizer = torch.optim.RMSprop(D.parameters(), lr=lr, )
 g_optimizer = torch.optim.RMSprop(G.parameters(), lr=lr, )
 d_scheduler = torch.optim.lr_scheduler.ExponentialLR(d_optimizer, gamma=0.99)
 g_scheduler = torch.optim.lr_scheduler.ExponentialLR(g_optimizer, gamma=0.99)

fake_img = None

# ##########################进入训练##判别器的判断过程#####################
 for epoch in range(num_epoch): # 进行多个epoch的训练
   pbar = builder.MyTqdm(epoch=epoch, maxval=len(dataloader))
   for i, img in enumerate(dataloader):
     num_img = img.size(0)
     real_img = img.cuda() # 将tensor变成Variable放入计算图中
     # 这里的优化器是D的优化器
     for param in D.parameters():
       param.requires_grad = True
     # ########判别器训练train#####################
     # 分为两部分:1、真的图像判别为真;2、假的图像判别为假

# 计算真实图片的损失
     d_optimizer.zero_grad() # 在反向传播之前,先将梯度归0
     real_out = D(real_img) # 将真实图片放入判别器中
     d_loss_real = real_out.mean(0).view(1)
     d_loss_real.backward(one)

# 计算生成图片的损失
     z = torch.randn(num_img, z_dimension).cuda() # 随机生成一些噪声
     z = z.reshape(num_img, z_dimension, 1, 1)
     fake_img = G(z).detach() # 随机噪声放入生成网络中,生成一张假的图片。 # 避免梯度传到G,因为G不用更新, detach分离
     fake_out = D(fake_img) # 判别器判断假的图片,
     d_loss_fake = fake_out.mean(0).view(1)
     d_loss_fake.backward(mone)

d_loss = d_loss_fake - d_loss_real
     d_optimizer.step() # 更新参数

# 每次更新判别器的参数之后把它们的绝对值截断到不超过一个固定常数c=0.01
     for parm in D.parameters():
       parm.data.clamp_(-0.01, 0.01)

# ==================训练生成器============================
     # ###############################生成网络的训练###############################
     for param in D.parameters():
       param.requires_grad = False

# 这里的优化器是G的优化器,所以不需要冻结D的梯度,因为不是D的优化器,不会更新D
     g_optimizer.zero_grad() # 梯度归0

z = torch.randn(num_img, z_dimension).cuda()
     z = z.reshape(num_img, z_dimension, 1, 1)
     fake_img = G(z) # 随机噪声输入到生成器中,得到一副假的图片
     output = D(fake_img) # 经过判别器得到的结果
     # g_loss = criterion(output, real_label) # 得到的假的图片与真实的图片的label的loss
     g_loss = torch.mean(output).view(1)
     # bp and optimize
     g_loss.backward(one) # 进行反向传播
     g_optimizer.step() # .step()一般用在反向传播后面,用于更新生成网络的参数

# 打印中间的损失
     pbar.set_right_info(d_loss=d_loss.data.item(),
               g_loss=g_loss.data.item(),
               real_scores=real_out.data.mean().item(),
               fake_scores=fake_out.data.mean().item(),
               )
     pbar.update()
     try:
       fake_images = to_img(fake_img.cpu())
       save_image(fake_images, dir_path + '/fake_images-{}.png'.format(epoch + 1))
     except:
       pass
     if is_print:
       is_print = False
       real_images = to_img(real_img.cpu())
       save_image(real_images, dir_path + '/real_images.png')
   pbar.finish()
   d_scheduler.step()
   g_scheduler.step()
   save(D, "wgan_D.pt")
   save(G, "wgan_G.pt")

来源:https://blog.csdn.net/bu_fo/article/details/109808354

标签:Pytorch,WGAN,动漫头像
0
投稿

猜你喜欢

  • 详解PHP如何更好的利用PHPstorm的自动提示

    2024-05-22 10:05:30
  • 深入解析Python的Tornado框架中内置的模板引擎

    2022-08-03 12:43:52
  • 详解Laravel服务容器的优势

    2023-10-31 03:36:04
  • python使用KNN算法识别手写数字

    2022-02-20 10:48:23
  • 你知道吗实现炫酷可视化只要1行python代码

    2022-06-10 13:36:16
  • Python实现图片查找轮廓、多边形拟合、最小外接矩形代码

    2021-03-27 05:34:56
  • django的model操作汇整详解

    2022-05-16 03:59:46
  • Django分页查询并返回jsons数据(中文乱码解决方法)

    2022-12-02 22:44:20
  • Python之关于类变量的两种赋值区别详解

    2021-09-08 08:05:26
  • MySQL拼接字符串函数GROUP_CONCAT详解

    2024-01-27 18:21:56
  • PHP无限分类代码,支持数组格式化、直接输出菜单两种方式

    2024-05-13 09:24:51
  • 浅析JS原始值和引用值问题

    2024-04-28 09:33:17
  • PHP生成饼图的示例代码

    2023-05-25 10:24:09
  • 认识Javascript数组

    2009-08-27 15:26:00
  • 解决el-tree节点过滤不显示下级的问题

    2024-05-29 22:24:03
  • Django框架静态文件处理、中间件、上传文件操作实例详解

    2021-03-07 14:31:04
  • php通过隐藏表单控件获取到前两个页面的url

    2023-11-16 04:00:08
  • Go语言sync.Cond基本使用及原理示例详解

    2023-06-28 07:09:01
  • pytorch Dataset,DataLoader产生自定义的训练数据案例

    2022-12-05 06:41:23
  • Python调用工具包实现发送邮件服务

    2023-08-30 02:25:16
  • asp之家 网络编程 m.aspxhome.com