• 18834196600
如何在MNIST上构建条件龙采科技生成式对抗网络(CGAN)?
作者:龙采科技 / 2016-08-14 11:55 / 浏览次数:

图:pixabay

本教程将介绍如何在MNIST图像上构建和训练条件生成式对抗网络(CGAN)。

GAN如何进行工作的

一般来说,生成式对抗模型是同时训练两个模型的:一个是学习从未知分布中输出假样本的生成器,而另一个是学习区分真假样本的鉴别器。

CGAN是GAN的条件变体,其中生成器被指示生成具有特定特征的真实样本,而不是来自完全分布的通用样本。这样的条件可以是与本教程中的图像相关联的标签或者是更为详细的标签,如下图示例所示:

展开全文

图片来源:Scott Reed

初始设置

运行本教程需要以下软件包:

完整的演示是由GitHub提供的以下两个脚本组成的:

CGAN_mnist_setup.R:准备数据并定义模型结构

CGAN_train.R:执行训练操作

准备数据

我们需要的MNIST数据集可在Kaggle上获得。一旦我们将train.csv下载到数据/文件夹后,我们就可以将其导入到R中去。

自定义迭代器在iterators.R中定义,并由CGAN_mnist_setup.R导入。

生成器

生成器是一个从2个输入中创建新样本(MNIST图像)的网络:

•噪声矢量

•定义对象条件的标签(要生成哪个数字)

噪声矢量为Generator模型提供了构建块,它将学习如何将噪声结构化为样本。mx.symbol.Deconvolution操作符用于将初始输入从1x1形状向上采样到28x28图像。

用于生成假样本的标签上的信息是由附加到随机噪声的标签索引的独热编码(one-hot encoding)来提供的。对于MNIST来说,0-9索引因此被转换为长度为10的二进制向量。更复杂的应用将需要的是嵌入而不是简单的单向编码来编码条件。

鉴别器

鉴别器尝试区分生成器产生的假样本和从MNIST训练数据中抽取的真实样本。

在条件式GAN中,与样品相关联的标签也被提供给鉴别器。而在此次的演示中,这些信息将作为一个独热的编码标签,以便传播从而匹配图像的尺寸(10 - >28x28x10)。

训练逻辑

鉴别器的训练过程是最为明显的:损耗就是一个简单的二进制TRUE / FALSE响应,而且损耗可以传播回CNN网络。因此它可以理解为一个简单的二进制分类问题。

生成器损耗来自鉴别器损耗反向传播到其产生的输出。通过将生成器标签伪装成真实样本进入到鉴别器中,鉴别器反向传播损耗为生成器提供了如何最佳地调整其参数,从而欺骗鉴别器相信假样本是真实的信息。

这需要将梯度反向传播到鉴别器的输入数据中(而在普通前馈网络中通常忽略该输入梯度)。

上述训练步骤在CGAN_train.R脚本中执行。

监督训练

在训练期间,相机包(imager package)可以方便进行假样本的视觉质量评估。

以下是在不同训练阶段获得的样本。

从噪音开始:

慢慢地得到下面这个结果——迭代200:

根据需要生成指定的数字图像——迭代2400:

推理

一旦模型被训练,可以通过用固定标签而不是训练期间使用的随机生成的图像馈送到生成器来产生所需数字的合成图像。

在这里我们会产生假的“9”:

CGAN方法的进一步细节可以在Generative Adversarial Text to Image Synthesis论文中找到。

作者:Jeremie Desgagne-Bouchard

来源:返回搜狐,查看更多

【龙采业务】网站建设、网站设计、服务器空间租售、网站维护、网站托管、网站优化、百度推广、自媒体营销、微信公众号
如有意向---联系我们
热门栏目
热门资讯
热门标签

网站建设 网站托管 成功案例 新闻动态 关于我们 联系我们 服务器空间 加盟合作 网站优化

备案号:晋ICP备19000634号 

公司地址:山西省太原市南中环街清控创新基地 咨询QQ:1715793209 手机:18834196600 电话:0351-2371270