一、算法思想
神经网络通常用来进行预测任务,比如给定一张图片预测所属类别,给定一组数据预测未来数值。能不能用其生成图片、文本或者语音。我们的目标:收集大量数据用来训练模型,从而生成与训练数据类似的新样本。GAN的核心思想:采用对抗机制从大量数据中训练模型,学习真实数据的分布,使训练后的模型能够产生真实数据分布中的样本,即能够生成之前不存在,却很真实的样本。
二、算法推导
1、GAN算法原理
我们的目标:给定一组向量生成与真实数据相似的图像。如图1所示,将固定长度的随机向量输入生成网络G1,输出生成图像,通过观察发现生成图像与真实存在差异。从数学上描述:生成图像与真实图像的分布不同。
由图1所示,生成网络生成的图像数据分布与真实数据分布不符。因此,G1网络需优化,通过训练调整参数得到图2所示G2网络,可发现通过调整后的G2网络所生成的图像与真实图像相似,即网络拟合了真实数据的分布。
什么是图像的分布?如图3,对于 32×32×3 大小的生成图像(RGB),所产生的像素组合为 。即生成图像总共有 种类别(像素值为0-255)。我们所看到的任何一张 32×32×3 大小的图像均来自上述组合。我们能够很容易识别“猫”与“狗”是因为两者在像素分布上有差异,也就是如果我们统计大量“猫”的图像以及“狗”的图像会发现他们在像素上呈现不同的分布。
如下图,橙色的椭圆为图像“狗”的真实数据空间,绿色的椭圆为“狗”生成数据空间。如果网络没有训练好,生成的数据分布与真实的数据分布存在差异。GAN训练目的:通过生成网络产生的数据分布逼近真实数据分布。理想状态下图中的绿色椭圆与橙色椭圆重合。
2、网络结构与训练
怎样使网络产生的数据分布逼近真实数据分布?GAN基于生成器与判别器采用对抗机制训练模型,使生成数据匹配真实数据。对抗机制是如何实现的?以真假钞票鉴别为例。
起初,“罪犯”生成了一批“假钞”,但由于生成的质量太低,一下被警方识破。生成器能力不够。
后续,“罪犯”通过观察真钞,不断学习,生成了新的钞票。同时警察也利用真钞与假钞不断学习,提升能力。但由于“罪犯”生成的钞票质量依旧偏低,再次被识破。
经过多轮学习,双方能力不断提升,最终生成的钞票,警方难辨真伪。
即双方在对抗学习中不断提升自身能力,达到的结果的是生成器生成的数据,判别器难辨真伪。
如图4,GAN网络结构由两部分构成:生成器与判别器。
生成器:学习生成合理的数据。对于图像生成来说是给定一个向量,生成一张图片。其生成的数据作为判别器的负样本。
输入: n维随机向量;输出m×m×3大小的图像
判别器:判别输入是生成数据还是真实数据。网络输出越接近于0,生成数据可能性越大;反之,真实数据可能性越大。
输入:m×m×3大小的图像;输出:0-1,预测为真实图像概率。
生成器网络结构:生成器可为普通的卷积神经网络,其输入为一个随机的n维向量,输出为m×m×3大小的图片,如图5所示,典型的DCGAN-生成器为五层结构,分别为1层reshape以及4层卷积。输入为 100×1 ,输出为 64×64×3 。
判别器网络结构:判别器可为普通的卷积神经网络,其输入为m×m×3大小图像,输出为0-1概率值。如图6所示,典型的DCGAN-判别器为3层结构,2层卷积层以及1层全连接层。
判别器损失函数:
其中,y_real=1,y_gen=0。
该损失为二分类交叉熵损失函数,前半部分,D(x)真实值越接近1,损失函数越小; 后半部分1-D(G(z))生成值越接近0,损失函数越小。判别器优化过程为最小化 J。
生成器损失函数:
生成器的优化过程为最小化 J,即最大化 D(G(z)),最大化判别器对生成数据的预测值。
3、算法实现
以手写数字生成为例,实现GAN网络。
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import make_grid
from tqdm import tqdm
# General config
batch_size = 64
# Generator config
sample_size = 100 # Random value sample size
g_alpha = 0.01 # LeakyReLU alpha
g_lr = 1.0e-3 # Learning rate (higher than previous version)
# Discriminator config
d_alpha = 0.01 # LeakyReLU alpha
d_lr = 1.0e-4 # Learning rate
# DataLoader for MNIST
transform = transforms.ToTensor()
dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, drop_last=True)
# Reshape helper
class Reshape(nn.Module):
def __init__(self, *shape):
super().__init__()
self.shape = shape
def forward(self, x):
return x.reshape(-1, *self.shape)
# Generator network
class Generator(nn.Module):
def __init__(self, sample_size: int, alpha: float):
super().__init__()
# sample_size => 784
self.fc = nn.Sequential(
nn.Linear(sample_size, 784),
nn.BatchNorm1d(784),
nn.LeakyReLU(alpha))
# 784 => 16 x 7 x 7
self.reshape = Reshape(16, 7, 7)
# 16 x 7 x 7 => 32 x 14 x 14
self.conv1 = nn.Sequential(
nn.ConvTranspose2d(16, 32,
kernel_size=5, stride=2, padding=2,
output_padding=1, bias=False),
nn.BatchNorm2d(32),
nn.LeakyReLU(alpha))
# 32 x 14 x 14 => 1 x 28 x 28
self.conv2 = nn.Sequential(
nn.ConvTranspose2d(32, 1,
kernel_size=5, stride=2, padding=2,
output_padding=1, bias=False),
nn.Sigmoid())
# Random value sample size
self.sample_size = sample_size
def forward(self, batch_size: int):
# Generate random input values
z = torch.randn(batch_size, self.sample_size)
# Use transposed convolutions
x = self.fc(z) # => 784
x = self.reshape(x) # => 16 x 7 x 7
x = self.conv1(x) # => 32 x 14 x 14
x = self.conv2(x) # => 1 x 28 x 28
return x
# Discriminator network
class Discriminator(nn.Module):
def __init__(self, alpha: float):
super().__init__()
# 1 x 28 x 28 => 32 x 14 x 14
self.conv1 = nn.Sequential(
nn.Conv2d(1, 32,
kernel_size=5, stride=2, padding=2, bias=False),
nn.LeakyReLU(alpha))
# 32 x 14 x 14 => 16 x 7 x 7
self.conv2 = nn.Sequential(
nn.Conv2d(32, 16,
kernel_size=5, stride=2, padding=2, bias=False),
nn.BatchNorm2d(16),
nn.LeakyReLU(alpha))
# 16 x 7 x 7 => 784
self.fc = nn.Sequential(
nn.Flatten(),
nn.Linear(784, 784),
nn.BatchNorm1d(784),
nn.LeakyReLU(alpha),
nn.Linear(784, 1))
def forward(self, images: torch.Tensor, targets: torch.Tensor):
# Extract image features using convolutions
x = self.conv1(images) # => 32 x 14 x 14
x = self.conv2(x) # => 16 x 7 x 7
prediction = self.fc(x) # => 1
loss = F.binary_cross_entropy_with_logits(prediction, targets)
return loss
# Save image grid
def save_image_grid(epoch: int, images: torch.Tensor, ncol: int):
image_grid = make_grid(images, ncol) # Images in a grid
image_grid = image_grid.permute(1, 2, 0) # Move channel last
image_grid = image_grid.cpu().numpy() # To Numpy
plt.imshow(image_grid)
plt.xticks([])
plt.yticks([])
plt.savefig(f'generated_{epoch:03d}.jpg')
plt.close()
# Real and fake labels
real_targets = torch.ones(batch_size, 1)
fake_targets = torch.zeros(batch_size, 1)
# Generator and discriminator networks
generator = Generator(sample_size, g_alpha)
discriminator = Discriminator(d_alpha)
# Optimizers
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=d_lr)
g_optimizer = torch.optim.Adam(generator.parameters(), lr=g_lr)
# Training loop
for epoch in range(100):
d_losses = []
g_losses = []
for images, labels in tqdm(dataloader):
# ===============================
# Discriminator training
# ===============================
# Loss with MNIST image inputs and real_targets as labels
discriminator.train()
d_loss = discriminator(images, real_targets)
# Generate images in eval mode
generator.eval()
with torch.no_grad():
generated_images = generator(batch_size)
# Loss with generated image inputs and fake_targets as labels
d_loss += discriminator(generated_images, fake_targets)
# Optimizer updates the discriminator parameters
d_optimizer.zero_grad()
d_loss.backward()
d_optimizer.step()
# ===============================
# Generator Network Training
# ===============================
# Generate images in train mode
generator.train()
generated_images = generator(batch_size)
# batchnorm is unstable in eval due to generated images
# change drastically every epoch. We'll not use the eval here.
# discriminator.eval()
# Loss with generated image inputs and real_targets as labels
g_loss = discriminator(generated_images, real_targets)
# Optimizer updates the generator parameters
g_optimizer.zero_grad()
g_loss.backward()
g_optimizer.step()
# Keep losses for logging
d_losses.append(d_loss.item())
g_losses.append(g_loss.item())
# Print average losses
print(epoch, np.mean(d_losses), np.mean(g_losses))
# Save images
save_image_grid(epoch, generator(batch_size), ncol=8)
第1轮生成值
第50轮生成值
第100轮生成值
随着训练轮数提升,GAN生成数据质量不断提升,与真实数据相似程度逐步提高。