import argparse
import os

import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.utils import save_image

os.makedirs("images", exist_ok=True)

parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
parser.add_argument("--n_classes", type=int, default=10, help="number of classes for dataset")
parser.add_argument("--img_size", type=int, default=28, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=400, help="interval between image sampling")
opt = parser.parse_args()
print(opt)

img_shape = (opt.channels, opt.img_size, opt.img_size)

cuda = True if torch.cuda.is_available() else False


class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        self.label_emb = nn.Embedding(opt.n_classes, opt.n_classes)

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(opt.latent_dim + opt.n_classes, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh()
        )

    def forward(self, noise, labels):
        # Concatenate label embedding and noise to produce the generator input
        gen_input = torch.cat((self.label_emb(labels), noise), -1)  # cat (64, 10) and (64, 100) -> (64, 110)
        img = self.model(gen_input)
        img = img.view(img.size(0), *img_shape)
        return img


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        self.label_embedding = nn.Embedding(opt.n_classes, opt.n_classes)

        self.model = nn.Sequential(
            nn.Linear(opt.n_classes + int(np.prod(img_shape)), 512),  # 784 + 10
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 512),
            nn.Dropout(0.4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 512),
            nn.Dropout(0.4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 1),
            nn.Sigmoid(),
        )

    def forward(self, img, labels):
        # Concatenate label embedding and flattened image to produce the discriminator input
        d_in = torch.cat((img.view(img.size(0), -1), self.label_embedding(labels)), -1)
        validity = self.model(d_in)
        return validity


# Loss functions
adversarial_loss = torch.nn.MSELoss()

# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()

if cuda:
    generator.cuda()
    discriminator.cuda()
    adversarial_loss.cuda()

# Configure data loader
os.makedirs("../data/mnist", exist_ok=True)
dataloader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "../data/mnist",
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
        ),
    ),
    batch_size=opt.batch_size,
    shuffle=True,
)

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

FloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if cuda else torch.LongTensor


def sample_image(n_row, batches_done):
    """Saves a grid of generated digits ranging from 0 to n_classes"""
    # Sample noise
    z = Variable(FloatTensor(np.random.normal(0, 1, (n_row ** 2, opt.latent_dim))))
    # Get labels ranging from 0 to n_classes for n rows
    labels = np.array([num for _ in range(n_row) for num in range(n_row)])
    labels = Variable(LongTensor(labels))
    gen_imgs = generator(z, labels)
    save_image(gen_imgs.data, "images/%d.png" % batches_done, nrow=n_row, normalize=True)


# ----------
# Training
# ----------

for epoch in range(opt.n_epochs):
    for i, (imgs, labels) in enumerate(dataloader):

        batch_size = imgs.shape[0]

        # Adversarial ground truths
        valid = Variable(FloatTensor(batch_size, 1).fill_(1.0), requires_grad=False)
        fake = Variable(FloatTensor(batch_size, 1).fill_(0.0), requires_grad=False)

        # Configure input
        real_imgs = Variable(imgs.type(FloatTensor))
        labels = Variable(labels.type(LongTensor))

        # -----------------
        #  Train Generator
        # -----------------

        optimizer_G.zero_grad()

        # Sample noise and labels as generator input
        z = Variable(FloatTensor(np.random.normal(0, 1, (batch_size, opt.latent_dim))))
        gen_labels = Variable(LongTensor(np.random.randint(0, opt.n_classes, batch_size)))  # 64 random labels in [0, 9]

        # Generate a batch of images
        gen_imgs = generator(z, gen_labels)

        # Loss measures generator's ability to fool the discriminator
        validity = discriminator(gen_imgs, gen_labels)
        g_loss = adversarial_loss(validity, valid)

        g_loss.backward()
        optimizer_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------

        optimizer_D.zero_grad()

        # Loss for real images
        validity_real = discriminator(real_imgs, labels)
        d_real_loss = adversarial_loss(validity_real, valid)

        # Loss for fake images
        validity_fake = discriminator(gen_imgs.detach(), gen_labels)
        d_fake_loss = adversarial_loss(validity_fake, fake)

        # Total discriminator loss
        d_loss = (d_real_loss + d_fake_loss) / 2

        d_loss.backward()
        optimizer_D.step()

        print(
            "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
            % (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item())
        )

        batches_done = epoch * len(dataloader) + i
        if batches_done % opt.sample_interval == 0:
            sample_image(n_row=10, batches_done=batches_done)
1. nn.Embedding() maps discrete data to continuous vectors: discrete integers such as 0, 1, 2, … are each represented by a vector, e.g. one row of the matrix below.
torch.tensor([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9], [1.0, 1.1, 1.2], [1.3, 1.4, 1.5]])
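As a concrete illustration, here is a minimal sketch whose lookup table is exactly the example tensor above, loaded with nn.Embedding.from_pretrained:

import torch
import torch.nn as nn

# Build an embedding layer whose weight is the 5 x 3 example matrix above
weight = torch.tensor([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9],
                       [1.0, 1.1, 1.2], [1.3, 1.4, 1.5]])
embedding = nn.Embedding.from_pretrained(weight)

# Each discrete index picks out the corresponding row
print(embedding(torch.tensor([0, 2])))  # tensor([[0.1, 0.2, 0.3], [0.7, 0.8, 0.9]])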
nn.Embedding(num_embeddings=10, embedding_dim=3) essentially creates a random 10 x 3 matrix, which you can think of as a lookup table of shape (10, 3): it maps each of the integers 0, 1, …, 9 to one of 10 vectors. You can inspect these vectors through embedding.weight, and you can also fetch individual rows by index. The snippet below returns rows 2 and 5 of the embedding.weight matrix:
embedding = nn.Embedding(num_embeddings=10, embedding_dim=3)
input_ids = torch.tensor([2, 5])  # select indices 2 and 5
output = embedding(input_ids)     # rows 2 and 5 of embedding.weight, shape (2, 3)
print(output)
The first parameter, num_embeddings, is the number of categories: later we may look up batch_size vectors at a time, but those vectors always come from the same 10 classes. The second parameter, embedding_dim, is how many numbers each vector holds (its dimensionality).
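A quick sketch of that batch behavior, assuming the same 10-class, 3-dimensional table:

import torch
import torch.nn as nn

embedding = nn.Embedding(num_embeddings=10, embedding_dim=3)

# A batch of 64 labels; every lookup still comes from the same 10 rows
labels = torch.randint(0, 10, (64,))
vectors = embedding(labels)
print(vectors.shape)  # torch.Size([64, 3])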
2. The generator takes two inputs:
z = Variable(FloatTensor(np.random.normal(0, 1, (batch_size, opt.latent_dim))))
gen_labels = Variable(LongTensor(np.random.randint(0, opt.n_classes, batch_size)))
z.shape = (64, 100) (random noise: 64 vectors, each 100-dimensional)
gen_labels.shape = (64,) (64 class indices, each value in 0-9)

Example output from one batch:

shape of z = torch.Size([64, 100])
shape of real_imgs = torch.Size([64, 1, 28, 28])
z = tensor([[-0.2021, -0.6528, -0.6111,  ..., -0.5988,  0.0187,  0.8311],
        [ 0.2402,  1.1745,  0.4431,  ..., -0.1055, -0.1356, -0.5389],
        [-0.8425, -1.3124,  0.9545,  ...,  0.8020, -0.1754, -0.5615],
        ...,
        [ 0.2027, -0.8791, -0.9138,  ...,  1.0122, -1.0658,  1.1842],
        [ 0.5115, -0.1609,  0.0903,  ...,  1.3818,  1.7254,  0.6183],
        [ 1.4386,  0.0568, -0.8814,  ...,  0.8862,  0.3396,  0.8465]],
       device='cuda:0')
gen_labels = tensor([3, 9, 8, 8, 3, 6, 0, 9, 4, 2, 2, 2, 5, 5, 0, 8, 1, 0, 9, 3, 8, 1, 7, 7,
        8, 5, 8, 8, 4, 2, 5, 5, 0, 1, 3, 2, 3, 3, 5, 2, 9, 7, 7, 9, 3, 4, 9, 2,
        3, 3, 2, 2, 8, 3, 7, 5, 8, 3, 0, 2, 1, 4, 8, 1], device='cuda:0')
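A dump like this can be reproduced with a couple of temporary prints (not part of the original script), placed inside the training loop right after z and gen_labels are sampled:

# Temporary debug prints; drop them into the training loop after z and
# gen_labels are created, and remove them once the shapes look right
print("shape of z =", z.shape)
print("shape of real_imgs =", real_imgs.shape)
print("z =", z)
print("gen_labels =", gen_labels)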
Next:
self.label_emb = nn.Embedding(opt.n_classes, opt.n_classes)
Here nn.Embedding(10, 10) is a lookup table:
- Input: class indices, labels.shape = (64,)
- Output: class embedding vectors, shape = (64, 10)
gen_input = torch.cat((self.label_emb(labels), noise), -1)  # -1: concatenate along the last dimension
Here 64 is the batch_size: 64 samples arrive, and each noise vector is concatenated with one embedding vector. That vector is random at first and has no intrinsic meaning, but it is the vector stored at index 3, so it serves as the hint that tells the generator to produce a 3.
An example (label_emb is shown one-hot here for readability):
self.label_emb(labels) = [
    [0, 0, 0, 1, 0, 0, 0, 0, 0, 0],  # class 3
    [0, 0, 0, 0, 0, 0, 0, 1, 0, 0],  # class 7
    ...
]
z = [
    [ 0.1, -0.5, ...,  0.3],  # noise vector for image 1
    [-0.2,  0.7, ..., -0.1],  # noise vector for image 2
    ...
]
gen_input = torch.cat((self.label_emb(labels), z), -1) = [
    [0, 0, 0, 1, 0, 0, 0, 0, 0, 0,  0.1, -0.5, ...,  0.3],  # 10 + 100 = 110
    [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, -0.2,  0.7, ..., -0.1],  # 10 + 100 = 110
    ...
]  # shape: (64, 110)
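Putting the pieces together, here is a standalone sketch of just this input construction, with a real (randomly initialized) embedding instead of the one-hot stand-in:

import torch
import torch.nn as nn

n_classes, latent_dim, batch_size = 10, 100, 64
label_emb = nn.Embedding(n_classes, n_classes)

labels = torch.randint(0, n_classes, (batch_size,))  # (64,)
z = torch.randn(batch_size, latent_dim)              # (64, 100)

gen_input = torch.cat((label_emb(labels), z), -1)    # (64, 10) + (64, 100) along the last dim
print(gen_input.shape)  # torch.Size([64, 110])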
From here it is fed into the generator exactly as in a plain GAN:
img = self.model(gen_input)
3. The discriminator works the same way. Note that these lines form the discriminator's input; labels has the same shape as before, (64,):
real_imgs = Variable(imgs.type(FloatTensor))  # (64, 1, 28, 28)
labels = Variable(labels.type(LongTensor))    # (64,)
d_in = torch.cat((img.view(img.size(0), -1), self.label_embedding(labels)), -1)  # (64, 784) + (64, 10) -> (64, 794)
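The same shape bookkeeping, as a standalone sketch of the discriminator's input:

import torch
import torch.nn as nn

n_classes, batch_size = 10, 64
label_embedding = nn.Embedding(n_classes, n_classes)

imgs = torch.randn(batch_size, 1, 28, 28)  # stand-in for a batch of MNIST images
labels = torch.randint(0, n_classes, (batch_size,))

flat = imgs.view(imgs.size(0), -1)                     # (64, 784)
d_in = torch.cat((flat, label_embedding(labels)), -1)  # (64, 784) + (64, 10)
print(d_in.shape)  # torch.Size([64, 794])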
CGAN is a clear improvement over the original GAN.
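Conditioning also gives you control at inference time: you can ask the trained generator for a specific digit. A minimal sketch, assuming the script above has run (so generator, opt, FloatTensor and LongTensor exist) and that you saved the weights to a hypothetical generator.pth:

import numpy as np
import torch
from torchvision.utils import save_image

generator.load_state_dict(torch.load("generator.pth"))  # hypothetical checkpoint path
generator.eval()

with torch.no_grad():
    z = FloatTensor(np.random.normal(0, 1, (8, opt.latent_dim)))
    labels = LongTensor([3] * 8)  # ask for eight images of the digit 3
    digits = generator(z, labels)

save_image(digits.data, "images/digit_3.png", nrow=8, normalize=True)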