GAN系列之普通GAN

基础GAN详解

1.原理解析

在深度学习的大家庭中,可以将所有的模型粗略地分成两个派系,即生成式模型以及判别式模型,所谓的生成式模型,即是你输入一个sample,模型就会生成另外一个东西,比如最近大火的ChatGPT,Stable Diffusion,以及今天我们要讲的GAN等,都是生成式模型。
而所谓的判别式模型,就是你输入一个sample,模型会给你一个预测概率,表示该sample是某个label的可能性。

那么今天的重点内容,就是讲解GAN的原理。
简而言之,GAN是由两个部分组成的,一个部分叫做生成器,专门生成假的数据,而另外一个是判别器,专门来识别你给的数据是来自真实样本还是生成器的假样本。

他们两者相互博弈,直到达到所谓的纳什均衡

注: 什么是纳什均衡?
定义:博弈的所有参与人都为了满足自己的个人利益而选择牺牲集体利益而导致的全体参与人都吃亏的均衡状态。
在GAN中的意思就是,判别器既不能发现生成器生成的假数据是真数据还是假数据,而生成器也不能奈何判别器。

2. 算法流程

讲解完了原理,相信很多人还是一头雾水,那么我们再来过一遍算法流程。
请明确一点:生成器和判别器本质上是深度神经网络,所以不用关注它的模型架构,后面会有涉及,现在只需要知道它是神经网络

  1. 首先随机采样噪声z,输入到生成器G
  2. 生成器获得采样的噪声,然后就会生成一张图片
  3. 生成器为了愚弄判别器,要求判别器判断该张生成出的图片与真实图像的误差越大越好。越大,那么我就能够越成功愚弄判别器了。
  4. 判别器判断该张图片是真图片还是假图片,如果是真图片,就输出1,否则输出0
  5. 判别器为了不被愚弄,要判别生成器生成出的图片是假图片,真实的样本输出是真图片,让真实的误差越小越好。
  6. 直到判别器无法判别生成器生成的图片是真图片还是假图片为止,就达到了所谓的纳什均衡

公式如下:
GAN的误差函数

3. 代码解释

首先给出的是生成器的代码和判别器的代码:

#生成器代码
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(opt.latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh()
        )

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), *img_shape)
        return img

而判别器的代码是:

  #判别器
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        validity = self.model(img_flat)

        return validity

可以看到,两者都是最普通的神经网络结构,然后我们给出Loss定义,这里使用的是BCELoss。可以回看算法流程第三、第五步!

adversarial_loss = torch.nn.BCELoss()

然后我们来声明一下两个优化器,分别是生成器优化器,负责更新参数,让生成的图片越来越像真实的图像(即让真实图像和生成图像的误差越大越好)

optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

而判别器优化器,负责判别一张图像是生成的图像还是真实的图像,所以需要让判别器判别的结果和真实的标签误差越小越好。

optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

好了,接下来就到了重头戏了,训练过程!

for epoch in range(opt.n_epochs):
    for i, (imgs, _) in enumerate(dataloader):

        # Adversarial ground truths
        # 真实的数据标签,我们从真实的dataloader中拿数据,那么肯定都是真实的数据啦!
        # [[1],[1],...,[1]]
        valid = Variable(Tensor(imgs.size(0), 1).fill_(1.0), requires_grad=False)  #label
        # 那么这个是假标签,都是0
        # [[0],[0],...,[0]]
        fake = Variable(Tensor(imgs.size(0), 1).fill_(0.0), requires_grad=False)

        # Configure input
        # 我们获得一个真的图像
        real_imgs = Variable(imgs.type(Tensor))

        # -----------------
        #  Train Generator
        # -----------------
		# 梯度置空
        optimizer_G.zero_grad()

        # Sample noise as generator input(随机采样)
        z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))

        # Generate a batch of images
        # 生成器根据采样结果生成一张图像
        gen_imgs = generator(z)

        # Loss measures generator's ability to fool the discriminator
        # 我们为了愚弄判别器,要让判别器的结果和0越大越好,也就是说不能让判别器发现我们嘛
        # 因此反过来,就是要和1越小越好嘛
        g_loss = adversarial_loss(discriminator(gen_imgs), valid)
		
		# 梯度更新
        g_loss.backward()
        optimizer_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------

        optimizer_D.zero_grad()

        # Measure discriminator's ability to classify real from generated samples
        # 这个很简单,判别器就是要发现哪些是真图像,哪些是假图像啦
        # 所以,要让真图像和1的loss越小越好,让假图像和0的loss越小越好
        real_loss = adversarial_loss(discriminator(real_imgs), valid)
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2

        d_loss.backward()
        optimizer_D.step()

The End