泽兴芝士网

一站式 IT 编程学习资源平台

AI算法:生成对抗网络(GAN)原理与实现

一、算法思想

神经网络通常用来进行预测任务,比如给定一张图片预测所属类别,给定一组数据预测未来数值。能不能用其生成图片、文本或者语音。我们的目标:收集大量数据用来训练模型,从而生成与训练数据类似的新样本。GAN的核心思想:采用对抗机制从大量数据中训练模型,学习真实数据的分布,使训练后的模型能够产生真实数据分布中的样本,即能够生成之前不存在,却很真实的样本。

二、算法推导

1、GAN算法原理

我们的目标:给定一组向量生成与真实数据相似的图像。如图1所示,将固定长度的随机向量输入生成网络G1,输出生成图像,通过观察发现生成图像与真实存在差异。从数学上描述:生成图像与真实图像的分布不同。

由图1所示,生成网络生成的图像数据分布与真实数据分布不符。因此,G1网络需优化,通过训练调整参数得到图2所示G2网络,可发现通过调整后的G2网络所生成的图像与真实图像相似,即网络拟合了真实数据的分布。

什么是图像的分布?如图3,对于 32×32×3 大小的生成图像(RGB),所产生的像素组合为 。即生成图像总共有 种类别(像素值为0-255)。我们所看到的任何一张 32×32×3 大小的图像均来自上述组合。我们能够很容易识别“猫”与“狗”是因为两者在像素分布上有差异,也就是如果我们统计大量“猫”的图像以及“狗”的图像会发现他们在像素上呈现不同的分布。

如下图,橙色的椭圆为图像“狗”的真实数据空间,绿色的椭圆为“狗”生成数据空间。如果网络没有训练好,生成的数据分布与真实的数据分布存在差异。GAN训练目的:通过生成网络产生的数据分布逼近真实数据分布。理想状态下图中的绿色椭圆与橙色椭圆重合。

2、网络结构与训练

怎样使网络产生的数据分布逼近真实数据分布?GAN基于生成器与判别器采用对抗机制训练模型,使生成数据匹配真实数据。对抗机制是如何实现的?以真假钞票鉴别为例。

起初,“罪犯”生成了一批“假钞”,但由于生成的质量太低,一下被警方识破。生成器能力不够。

后续,“罪犯”通过观察真钞,不断学习,生成了新的钞票。同时警察也利用真钞与假钞不断学习,提升能力。但由于“罪犯”生成的钞票质量依旧偏低,再次被识破。

经过多轮学习,双方能力不断提升,最终生成的钞票,警方难辨真伪。

即双方在对抗学习中不断提升自身能力,达到的结果的是生成器生成的数据,判别器难辨真伪。

如图4,GAN网络结构由两部分构成:生成器与判别器

生成器:学习生成合理的数据。对于图像生成来说是给定一个向量,生成一张图片。其生成的数据作为判别器的负样本。

输入: n维随机向量;输出m×m×3大小的图像

判别器:判别输入是生成数据还是真实数据。网络输出越接近于0,生成数据可能性越大;反之,真实数据可能性越大。

输入:m×m×3大小的图像;输出:0-1,预测为真实图像概率。

生成器网络结构:生成器可为普通的卷积神经网络,其输入为一个随机的n维向量,输出为m×m×3大小的图片,如图5所示,典型的DCGAN-生成器为五层结构,分别为1层reshape以及4层卷积。输入为 100×1 ,输出为 64×64×3 。

判别器网络结构:判别器可为普通的卷积神经网络,其输入为m×m×3大小图像,输出为0-1概率值。如图6所示,典型的DCGAN-判别器为3层结构,2层卷积层以及1层全连接层。

判别器损失函数

其中,y_real=1,y_gen=0。

该损失为二分类交叉熵损失函数,前半部分,D(x)真实值越接近1,损失函数越小; 后半部分1-D(G(z))生成值越接近0,损失函数越小。判别器优化过程为最小化 J。

生成器损失函数:

生成器的优化过程为最小化 J,即最大化 D(G(z)),最大化判别器对生成数据的预测值。

3、算法实现

以手写数字生成为例,实现GAN网络。

import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import make_grid
from tqdm import tqdm

# General config
batch_size = 64

# Generator config
sample_size = 100  # Random value sample size
g_alpha = 0.01  # LeakyReLU alpha
g_lr = 1.0e-3  # Learning rate (higher than previous version)

# Discriminator config
d_alpha = 0.01  # LeakyReLU alpha
d_lr = 1.0e-4  # Learning rate

# DataLoader for MNIST
transform = transforms.ToTensor()
dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, drop_last=True)


# Reshape helper
class Reshape(nn.Module):
    def __init__(self, *shape):
        super().__init__()
        self.shape = shape

    def forward(self, x):
        return x.reshape(-1, *self.shape)


# Generator network
class Generator(nn.Module):
    def __init__(self, sample_size: int, alpha: float):
        super().__init__()

        # sample_size => 784
        self.fc = nn.Sequential(
            nn.Linear(sample_size, 784),
            nn.BatchNorm1d(784),
            nn.LeakyReLU(alpha))

        # 784 => 16 x 7 x 7
        self.reshape = Reshape(16, 7, 7)

        # 16 x 7 x 7 => 32 x 14 x 14
        self.conv1 = nn.Sequential(
            nn.ConvTranspose2d(16, 32,
                               kernel_size=5, stride=2, padding=2,
                               output_padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(alpha))

        # 32 x 14 x 14 => 1 x 28 x 28
        self.conv2 = nn.Sequential(
            nn.ConvTranspose2d(32, 1,
                               kernel_size=5, stride=2, padding=2,
                               output_padding=1, bias=False),
            nn.Sigmoid())

        # Random value sample size
        self.sample_size = sample_size

    def forward(self, batch_size: int):
        # Generate random input values
        z = torch.randn(batch_size, self.sample_size)

        # Use transposed convolutions
        x = self.fc(z)  # => 784
        x = self.reshape(x)  # => 16 x 7 x 7
        x = self.conv1(x)  # => 32 x 14 x 14
        x = self.conv2(x)  # => 1 x 28 x 28
        return x


# Discriminator network
class Discriminator(nn.Module):
    def __init__(self, alpha: float):
        super().__init__()

        # 1 x 28 x 28 => 32 x 14 x 14
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 32,
                      kernel_size=5, stride=2, padding=2, bias=False),
            nn.LeakyReLU(alpha))

        # 32 x 14 x 14 => 16 x 7 x 7
        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 16,
                      kernel_size=5, stride=2, padding=2, bias=False),
            nn.BatchNorm2d(16),
            nn.LeakyReLU(alpha))

        # 16 x 7 x 7 => 784
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(784, 784),
            nn.BatchNorm1d(784),
            nn.LeakyReLU(alpha),
            nn.Linear(784, 1))

    def forward(self, images: torch.Tensor, targets: torch.Tensor):
        # Extract image features using convolutions
        x = self.conv1(images)  # => 32 x 14 x 14
        x = self.conv2(x)  # => 16 x 7 x 7
        prediction = self.fc(x)  # => 1

        loss = F.binary_cross_entropy_with_logits(prediction, targets)
        return loss


# Save image grid
def save_image_grid(epoch: int, images: torch.Tensor, ncol: int):
    image_grid = make_grid(images, ncol)  # Images in a grid
    image_grid = image_grid.permute(1, 2, 0)  # Move channel last
    image_grid = image_grid.cpu().numpy()  # To Numpy

    plt.imshow(image_grid)
    plt.xticks([])
    plt.yticks([])
    plt.savefig(f'generated_{epoch:03d}.jpg')
    plt.close()


# Real and fake labels
real_targets = torch.ones(batch_size, 1)
fake_targets = torch.zeros(batch_size, 1)

# Generator and discriminator networks
generator = Generator(sample_size, g_alpha)
discriminator = Discriminator(d_alpha)

# Optimizers
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=d_lr)
g_optimizer = torch.optim.Adam(generator.parameters(), lr=g_lr)

# Training loop
for epoch in range(100):

    d_losses = []
    g_losses = []

    for images, labels in tqdm(dataloader):
        # ===============================
        # Discriminator training
        # ===============================

        # Loss with MNIST image inputs and real_targets as labels
        discriminator.train()
        d_loss = discriminator(images, real_targets)

        # Generate images in eval mode
        generator.eval()
        with torch.no_grad():
            generated_images = generator(batch_size)

        # Loss with generated image inputs and fake_targets as labels
        d_loss += discriminator(generated_images, fake_targets)

        # Optimizer updates the discriminator parameters
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        # ===============================
        # Generator Network Training
        # ===============================

        # Generate images in train mode
        generator.train()
        generated_images = generator(batch_size)

        # batchnorm is unstable in eval due to generated images
        # change drastically every epoch. We'll not use the eval here.
        # discriminator.eval()

        # Loss with generated image inputs and real_targets as labels
        g_loss = discriminator(generated_images, real_targets)

        # Optimizer updates the generator parameters
        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

        # Keep losses for logging
        d_losses.append(d_loss.item())
        g_losses.append(g_loss.item())

    # Print average losses
    print(epoch, np.mean(d_losses), np.mean(g_losses))

    # Save images
    save_image_grid(epoch, generator(batch_size), ncol=8)

第1轮生成值

第50轮生成值

第100轮生成值

随着训练轮数提升,GAN生成数据质量不断提升,与真实数据相似程度逐步提高。

控制面板
您好,欢迎到访网站!
  查看权限
网站分类
最新留言