2014年,蒙特利尔大学的Ian Goodfellow和他的同事们发表了一篇令人惊叹的论文,向全世界介绍了GAN(Generative Adversarial Networks)。
通过创新地将计算图谱和博弈论组合,他们展示,如果有足够的建模能力,两个相互对抗的模型将能够通过普通的反向传播进行共同训练。
自2014年Ian Goodfellow提出了GAN以来,业界对GAN的研究可谓如火如荼。各种GAN的变体不断涌现,下图是GAN相关论文的发表情况:
大牛Yann LeCun甚至评价GAN为 “Adversarial Training is the coolest thing since sliced bread”。
GAN的基本结构

GAN的两个Model分别扮演不同的角色(也就是对抗)。
给定一些真实的数据集R,G是
生成器(Generator)
,试图创建看起来像真实数据的假数据,而D是
判别器(Discriminator)
,从真实的集合或G中获取数据并标记差异。

Goodfellow举了一个生动的例子来说明GAN:G就像一群伪造者,试图将真实的绘画与他们的仿制的画输出相匹配,而D则是试图分辨出来的侦探团队。(除非在这种情况下,伪造者G永远不会看到原始数据,只有鉴别器D可以看到和判断——G就像是盲人伪造者。)
在理想的情况下,D和G都会随着时间的推移而变得更好,直到G基本上能够制造出让D区分不了的数据,G本质上成为真正物品的“主伪造者”,而D则不知所措,“无法区分这两种分布”。
在实践中,Goodfellow所展示的是G将能够对原始数据集执行一种无监督学习形式,找到某种方式以非常低维的方式表示该数据。
正如Yann LeCun所说:“无监督学习是真正人工智能的‘蛋糕’。绝大多数人类和动物的学习方式是非监督学习。如果智能是个蛋糕,非监督学习才是蛋糕主体,监督学习只能说是蛋糕上的糖霜奶油,而强化学习只是蛋糕上点缀的樱桃。现在我们知道如何制作“糖霜奶油”和上面的“樱桃”, 但并不知道如何制作蛋糕主体。我们必须先解决关于非监督学习的问题,才能开始考虑如何做出一个真正的AI。这还仅仅是我们所知的难题之一。更何况那些我们未知的难题呢?”
这种强大的技术似乎需要一吨的代码才能实现?NoNoNo,使用PyTorch,我们实际上可以在50行代码中创建一个非常简单的GAN。下面就是技术细节了,感兴趣的同学继续往下读。
我们实际上只有5个部分需要考虑:
R:原始的真实数据集
I:作为随机噪声放入到Generator里
G:尝试复制/模仿原始数据集的Generator(生成器)
D:试图将R的输出分开的Discriminator(鉴别器)
Loop:实际的“训练”循环,我们需要做的就是两部分:教G来欺骗D和教D辨别G.
1.)R:在我们的例子中,我们假设我们的R就是 Bell Curve。Bell Curve这个函数采用均值和标准差,并返回一个函数,该函数从高斯分布中提供随机产生n个数。在我们的示例代码中,我们将设置平均值(mu)为4.0,标准差(sigma)为1.25。
2.)I:生成器的输入也是随机的,但是为了使我们的代码稍微复杂一点,让我们使用uniform distribution(均匀分布)而不是normal distribution(正态分布)。这意味着我们的模型G不能简单地将输入移位/缩放到复制R,而是必须以非线性方式重塑数据。
3.)G:生成器是标准的feedforward graph - 两个hidden layers,三个linear mapping。我们正在使用tanh激活函数。G将从I获得均匀分布的数据样本,并以某种方式模拟来自R的正态分布样本,并且本身不会看到R(防止抄袭)。
4.)D:discriminator的代码与G的代码非常相似;前向传播包含两个隐藏层。这里的激活函数是sigmoid, 它将从R或G获取样本,并将输出0到1之间的单个标量,辨别为“假”与“真”。
5.)最后,训练循环在两种模式之间交替:首先在真实数据和假数据上训练D,校准label(将其视为警察学院);然后训练G用错误的label愚弄D,这就像是善恶之间的斗争。
即使你之前没有见过PyTorch,你也可能知道上图代码的结构。在第一个(绿色)部分,我们通过D同时训练两种类型的数据,并对D的猜测和实际的标签应用一个可微的标准。我们用'backward()'来计算导数(如果你还不明白为什么需要用导数优化机器学习算法,欢迎报名我们机器学习课程),然后用d_optimizer.step()来更新的D里面的参数。G被使用但并不在这里训练。
然后在最后(红色)部分,我们对G做同样的事情 - 注意我们也通过D运行G的输出(我们实际上是给伪造者一个侦探,辅助锻炼伪装者的伪装能力)但是我们没有在这一步优化或更改D。我们不希望侦探D学习错误的标签。因此,我们只调用g_optimizer.step()。
上面就是所有的过程。当然还有一些其他代码,但GAN主要就是这5个部分。
通过这大约50行的Pytorch代码,你已经搞定GAN的大致思路了,对比与那复杂的paper,我们是不是把一个很复杂的model很清晰地讲给了你~
今天的机器学习小课堂就到这里了,如果你对机器学习感兴趣,快来pick TFT机器学习小组课程。我们的硅谷专业研究团队,用最简单,最清晰的语言带你搞懂机器学习。
文章链接: https://medium.com/@devnag/generative-adversarial-networks-gans-in-50-lines-of-code-pytorch-e81b79659e3f
继续阅读
阅读原文