您的位置:首页 > 科技 > IT业 > 太原市疫情最新消息今天_短链接生成器app_网络技术培训_汉中网站seo

太原市疫情最新消息今天_短链接生成器app_网络技术培训_汉中网站seo

2025/3/17 23:38:59 来源:https://blog.csdn.net/max500600/article/details/146302731  浏览:    关键词:太原市疫情最新消息今天_短链接生成器app_网络技术培训_汉中网站seo
太原市疫情最新消息今天_短链接生成器app_网络技术培训_汉中网站seo

以下是一个使用 PyTorch 实现 Conditional DCGAN(条件深度卷积生成对抗网络)进行图像到图像转换的示例代码。该代码包含训练和可视化部分,假设输入为图片和 4 个工艺参数,根据这些输入生成相应的图片。

1. 导入必要的库

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
import numpy as np
import matplotlib.pyplot as plt# 检查是否有可用的 GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

2. 定义数据集类

class ImagePairDataset(Dataset):def __init__(self, image_pairs, params):self.image_pairs = image_pairsself.params = paramsdef __len__(self):return len(self.image_pairs)def __getitem__(self, idx):input_image, target_image = self.image_pairs[idx]param = self.params[idx]return input_image, target_image, param

3. 定义生成器和判别器

# 生成器
class Generator(nn.Module):def __init__(self, z_dim=4, img_channels=3):super(Generator, self).__init__()self.gen = nn.Sequential(# 输入: [batch_size, z_dim + 4, 1, 1]self._block(z_dim + 4, 1024, 4, 1, 0),  # [batch_size, 1024, 4, 4]self._block(1024, 512, 4, 2, 1),  # [batch_size, 512, 8, 8]self._block(512, 256, 4, 2, 1),  # [batch_size, 256, 16, 16]self._block(256, 128, 4, 2, 1),  # [batch_size, 128, 32, 32]nn.ConvTranspose2d(128, img_channels, kernel_size=4, stride=2, padding=1),nn.Tanh())def _block(self, in_channels, out_channels, kernel_size, stride, padding):return nn.Sequential(nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),nn.BatchNorm2d(out_channels),nn.ReLU(True))def forward(self, z, params):params = params.view(params.size(0), 4, 1, 1)x = torch.cat([z, params], dim=1)return self.gen(x)# 判别器
class Discriminator(nn.Module):def __init__(self, img_channels=3):super(Discriminator, self).__init__()self.disc = nn.Sequential(# 输入: [batch_size, img_channels + 4, 64, 64]nn.Conv2d(img_channels + 4, 64, kernel_size=4, stride=2, padding=1),nn.LeakyReLU(0.2),self._block(64, 128, 4, 2, 1),  # [batch_size, 128, 16, 16]self._block(128, 256, 4, 2, 1),  # [batch_size, 256, 8, 8]self._block(256, 512, 4, 2, 1),  # [batch_size, 512, 4, 4]nn.Conv2d(512, 1, kernel_size=4, stride=2, padding=0),nn.Sigmoid())def _block(self, in_channels, out_channels, kernel_size, stride, padding):return nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),nn.BatchNorm2d(out_channels),nn.LeakyReLU(0.2))def forward(self, img, params):params = params.view(params.size(0), 4, 1, 1).repeat(1, 1, img.size(2), img.size(3))x = torch.cat([img, params], dim=1)return self.disc(x)

4. 训练代码

def train_conditional_dcgan(image_pairs, params, batch_size=32, epochs=10, lr=0.0002, z_dim=4):dataset = ImagePairDataset(image_pairs, params)dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)gen = Generator(z_dim).to(device)disc = Discriminator().to(device)criterion = nn.BCELoss()opt_gen = optim.Adam(gen.parameters(), lr=lr, betas=(0.5, 0.999))opt_disc = optim.Adam(disc.parameters(), lr=lr, betas=(0.5, 0.999))for epoch in range(epochs):for i, (input_images, target_images, param) in enumerate(dataloader):input_images = input_images.to(device)target_images = target_images.to(device)param = param.to(device)# 训练判别器opt_disc.zero_grad()real_labels = torch.ones((target_images.size(0), 1, 1, 1)).to(device)fake_labels = torch.zeros((target_images.size(0), 1, 1, 1)).to(device)# 计算判别器对真实图像的损失real_output = disc(target_images, param)d_real_loss = criterion(real_output, real_labels)# 生成假图像z = torch.randn(target_images.size(0), z_dim, 1, 1).to(device)fake_images = gen(z, param)# 计算判别器对假图像的损失fake_output = disc(fake_images.detach(), param)d_fake_loss = criterion(fake_output, fake_labels)# 总判别器损失d_loss = d_real_loss + d_fake_lossd_loss.backward()opt_disc.step()# 训练生成器opt_gen.zero_grad()output = disc(fake_images, param)g_loss = criterion(output, real_labels)g_loss.backward()opt_gen.step()print(f'Epoch [{epoch+1}/{epochs}] D_loss: {d_loss.item():.4f} G_loss: {g_loss.item():.4f}')return gen

5. 可视化代码

def visualize_generated_images(gen, input_images, params, z_dim=4):input_images = input_images.to(device)params = params.to(device)z = torch.randn(input_images.size(0), z_dim, 1, 1).to(device)fake_images = gen(z, params).cpu().detach()fig, axes = plt.subplots(1, input_images.size(0), figsize=(15, 3))for i in range(input_images.size(0)):img = fake_images[i].permute(1, 2, 0).numpy()img = (img + 1) / 2  # 从 [-1, 1] 转换到 [0, 1]axes[i].imshow(img)axes[i].axis('off')plt.show()

6. 示例使用

# 假设 image_pairs 是一个包含图像对的列表,params 是一个包含 4 个工艺参数的列表
image_pairs = []  # 这里需要替换为实际的图像对数据
params = []  # 这里需要替换为实际的工艺参数数据# 训练模型
gen = train_conditional_dcgan(image_pairs, params)# 可视化生成的图像
test_input_images, test_target_images, test_params = image_pairs[:5], image_pairs[:5], params[:5]
test_input_images = torch.stack([torch.tensor(img) for img in test_input_images]).float()
test_params = torch.tensor(test_params).float()
visualize_generated_images(gen, test_input_images, test_params)

代码说明

  1. 数据集类ImagePairDataset 用于加载图像对和工艺参数。
  2. 生成器和判别器GeneratorDiscriminator 分别定义了生成器和判别器的网络结构。
  3. 训练代码train_conditional_dcgan 函数用于训练 Conditional DCGAN 模型。
  4. 可视化代码visualize_generated_images 函数用于可视化生成的图像。
  5. 示例使用:最后部分展示了如何使用上述函数进行训练和可视化。

请注意,你需要将 image_pairsparams 替换为实际的数据集。此外,代码中的超参数(如 batch_sizeepochslr 等)可以根据实际情况进行调整。

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com