博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
GAN——生成手写数字
阅读量:5253 次
发布时间:2019-06-14

本文共 5824 字,大约阅读时间需要 19 分钟。

 《Generative Adversarial Nets》是 GAN 系列的鼻祖。在这里通过 PyTorch 实现 GAN ,并且用于手写数字生成。

摘要: 我们提出了一个新的框架,通过对抗处理来评估生成模型。其中,我们同时训练两个 model :一个是生成模型 G,用于获取数据分布;另一个是判别模型 D,用来预测样本来自训练数据而不是生成模型 G 的概率。G 的训练过程是最大化 D 犯错的概率。这个框架对应于一个极小极大的二人游戏。在任意函数 G 和 D 的空间中,存在着一个唯一的解,G 恢复训练数据的分布而 D 一直等于1/2. 在 G 和 D 都由多层感知器定义的情况下,整个系统可以通过反向传播进行训练。  

 

import timeimport numpy as npimport torchimport torch.nn.functional as Ffrom torchvision import datasetsfrom torchvision import transformsimport torch.nn as nnfrom torch.utils.data import DataLoaderif torch.cuda.is_available():    torch.backends.cudnn.deterministic = True
要导入的包

 

########################### SETTINGS########################## Devicedevice = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")# Hyperparametersrandom_seed = 123generator_learning_rate = 0.001discriminator_learning_rate = 0.001num_epochs = 100batch_size = 128LATENT_DIM = 100IMG_SHAPE = (1, 28, 28)IMG_SIZE = 1for x in IMG_SHAPE:    IMG_SIZE *= x
设置超参数

 

########################### MNIST DATASET#########################train_dataset = datasets.MNIST(root='../data',                                train=True,                                transform=transforms.ToTensor(),                               download=True)test_dataset = datasets.MNIST(root='../data',                               train=False,                               transform=transforms.ToTensor())train_loader = DataLoader(dataset=train_dataset,                           batch_size=batch_size,                           shuffle=True)test_loader = DataLoader(dataset=test_dataset,                          batch_size=batch_size,                          shuffle=False)# Checking the datasetfor images, labels in train_loader:      print('Image batch dimensions:', images.shape)    print('Image label dimensions:', labels.shape)    break# 输出# Image batch dimensions: torch.Size([128, 1, 28, 28])# Image label dimensions: torch.Size([128])
加载MNIST数据集

 

################################ MODEL##############################class GAN(torch.nn.Module):        def __init__(self):        super(GAN, self).__init__()                self.generator = nn.Sequential(            nn.Linear(LATENT_DIM, 128),            nn.LeakyReLU(inplace=True),            nn.Dropout(p=0.5),            nn.Linear(128, IMG_SIZE),            nn.Tanh()        )                self.discriminator = nn.Sequential(            nn.Linear(IMG_SIZE, 128),            nn.LeakyReLU(inplace=True),            nn.Dropout(p=0.5),            nn.Linear(128, 1),            nn.Sigmoid()        )        def generator_forward(self, z):        img = self.generator(z)        return img        def discriminator_forward(self, img):        pred = model.discriminator(img)        return pred.view(-1)
GAN—Model

 

start_time = time.time()discr_costs = []gener_costs = []for epoch in range(num_epochs):    model = model.train()    for batch_idx, (features, targets) in enumerate(train_loader):                features = (features - 0.5) * 2.        features = features.view(-1, IMG_SIZE).to(device)        targets = targets.to(device)                # Adversarial ground truths        valid = torch.ones(targets.size(0)).float().to(device)        fake = torch.zeros(targets.size(0)).float().to(device)                ### FORWARD AND BACK PROP                # ---------------------        # Train Generator        # ---------------------                # make new images        z = torch.zeros((targets.size(0), LATENT_DIM)).uniform_(-1.0, 1.0).to(device)                # generate a batch of images        generated_features = model.generator_forward(z)                # Loss measures generators's ability to fool the discriminator        discr_pred = model.discriminator_forward(generated_features)               gener_loss = F.binary_cross_entropy(discr_pred, valid)                optim_gener.zero_grad()        gener_loss.backward()        optim_gener.step()                        # ---------------------        # Train Discriminator        # ---------------------                # Measure discriminator's ability to classify real from samples        discr_pred_real = model.discriminator_forward(features.view(-1, IMG_SIZE))        real_loss = F.binary_cross_entropy(discr_pred_real, valid)                discr_pred_fake = model.discriminator_forward(generated_features.detach())        fake_loss = F.binary_cross_entropy(discr_pred_fake, fake)                discr_loss = 0.5 * (real_loss + fake_loss)                optim_discr.zero_grad()        discr_loss.backward()        optim_discr.step()                discr_costs.append(discr_loss)        gener_costs.append(gener_loss)                ### LOGGING        if not batch_idx % 100:            print('Epoch: %03d/%03d | Batch %03d/%03d | Gen/Dis Loss: %.4f/%.4f'                 %(epoch+1, num_epochs, batch_idx, len(train_loader), gener_loss, discr_loss))            print('Time elapsed: %.2f min' % ((time.time() - start_time)/60))        print('Total Training Time: %.2f min' % ((time.time() - start_time)/60))
网络训练

 

画出 generator loss 和 discriminator loss 的变化图:

plt.plot(range(len(gener_costs)), gener_costs, label='generator loss')plt.plot(range(len(discr_costs)), discr_costs, label='discriminator loss')plt.legend()plt.savefig('./loss.jpg')plt.show()

利用以上训练的 Generator 生成一些仿手写数字图片:

########################### VISUALIZATION#########################model.eval()# Make new imagesz = torch.zeros((5, LATENT_DIM)).uniform_(-1.0, 1.0).to(device)generated_features = model.generator_forward(z)imgs = generated_features.view(-1, 28, 28)fig, axes = plt.subplots(nrows=1, ncols=5, figsize=(20, 2.5))for i, ax in enumerate(axes):    axes[i].imshow(imgs[i].detach().numpy(), cmap='binary')

再生成几次:

可以发现,以上生成的数字图片有些很清晰,但有些很模糊,不易辨认,但是结果已经让人很兴奋了~~

后续可以对GAN进行改进,从而生成质量更高的图片。

 

 

Reference

  [1]

  [2]  

 

 

转载于:https://www.cnblogs.com/xxxxxxxxx/p/11326956.html

你可能感兴趣的文章
ef codefirst VS里修改数据表结构后更新到数据库
查看>>
boost 同步定时器
查看>>
[ROS] Chinese MOOC || Chapter-4.4 Action
查看>>
简单的数据库操作
查看>>
iOS-解决iOS8及以上设置applicationIconBadgeNumber报错的问题
查看>>
亡灵序曲-The Dawn
查看>>
Redmine
查看>>
帧的最小长度 CSMA/CD
查看>>
xib文件加载后设置frame无效问题
查看>>
编程算法 - 左旋转字符串 代码(C)
查看>>
IOS解析XML
查看>>
Python3多线程爬取meizitu的图片
查看>>
树状数组及其他特别简单的扩展
查看>>
zookeeper适用场景:分布式锁实现
查看>>
110104_LC-Display(液晶显示屏)
查看>>
httpd_Vhosts文件的配置
查看>>
php学习笔记
查看>>
普通求素数和线性筛素数
查看>>
PHP截取中英文混合字符
查看>>
【洛谷P1816 忠诚】线段树
查看>>