第一步:LYT-Net介绍
本文介绍了LYT-Net,即轻量级YUV Transformer 网络,作为一种新的低光图像增强方法。所提出的架构与传统的基于Retinex的模型不同,它利用YUV颜色空间对亮度(Y)和色度(U和V)的自然分离,简化了在图像中分离光和颜色信息的复杂任务。通过利用 Transformer 捕捉长距离依赖关系的优势,LYT-Net在保持降低模型复杂性的同时,确保了对图像的全面上下文理解。通过采用一种新颖的混合损失函数,LYT-Net在低光图像增强数据集上取得了最先进的结果,同时其体积比其他方法小得多。
LYT-Net采用了YUV色彩空间,这对LLIE来说尤其有利,因为它能将亮度(Y)和色度(U和V)明确分离。通过使用这个色彩空间,作者可以专门针对能在低光条件下提高图像可见性和细节的增强,而不会对颜色信息产生不利影响。由于人眼对亮度的变化更为敏感,因此专注于Y通道可以带来更自然、感知上更吸引人的增强效果。
作者的工作主要贡献可以概括为:
LYT-Net,一个轻量级模型,采用YUV颜色空间进行针对性增强。它在去噪后的亮度层和色度层上使用多头自注意力机制,旨在在处理过程的最后阶段实现更好的融合。
设计了一个混合损失函数,它在模型的高效训练中扮演了关键角色,并对模型的增强能力有显著贡献。
通过定量和定性的实验,LYT-Net在LOL数据集上与现有技术水平(SOTA)方法相比,已显示出强大的性能。
第二步:LYT-Net网络结构
作者展示了LYT-Net的整体架构。如图所示,该模型首先包括一个YUV分解部分,用于将色度与亮度分离,之后是若干卷积层及可分离的模块,如多头自注意力(MHSA)块、多阶段挤压与激活融合(MSEF)块和通道去噪(CWD)块。作者采用双路径方法,将色度和亮度视为独立实体,以帮助模型更好地理解光照调整与损坏恢复之间的差异。
该模型接收RGB格式的输入图像并将其转换为YUV。每个通道都通过一系列卷积层、池化操作以及MHSA机制单独增强。亮度通道经过卷积和池化提取特征,之后通过MHSA模块进行增强。色度通道(U和V)通过CWD块处理,以降低噪声同时保留细节。增强后的色度通道被重新组合并通过MSEF块处理。最终,色度与亮度被连接起来,并通过最后一组卷积层生成输出,得到高质量的增强图像。
第三步:模型代码展示
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init


class LayerNormalization(nn.Module):
    """Apply ``nn.LayerNorm`` over the channel axis of a (B, C, H, W) tensor.

    Args:
        dim: Number of channels ``C``; passed to ``nn.LayerNorm``.
    """

    def __init__(self, dim: int):
        super().__init__()
        self.norm = nn.LayerNorm(dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` and return a tensor with the same (B, C, H, W) layout."""
        # nn.LayerNorm normalizes the trailing dimension, so move channels
        # last: (B, C, H, W) -> (B, H, W, C).
        x = x.permute(0, 2, 3, 1)
        x = self.norm(x)
        # Restore the channel-first layout.
        return x.permute(0, 3, 1, 2)


class SEBlock(nn.Module):
    """Squeeze-and-excitation style channel gate.

    The input is globally average-pooled per channel, passed through a
    two-layer bottleneck (ReLU, then tanh) and the resulting per-channel
    factors multiply the input.

    Args:
        input_channels: Number of channels of the input feature map.
        reduction_ratio: Divisor used for the hidden width of the bottleneck
            (``input_channels // reduction_ratio``).
    """

    def __init__(self, input_channels: int, reduction_ratio: int = 16):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Linear(input_channels, input_channels // reduction_ratio)
        self.fc2 = nn.Linear(input_channels // reduction_ratio, input_channels)
        self._init_weights()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` scaled channel-wise by the learned gate."""
        batch_size, num_channels, _, _ = x.size()
        # Squeeze: (B, C, H, W) -> (B, C).
        y = self.pool(x).reshape(batch_size, num_channels)
        y = F.relu(self.fc1(y))
        # The gate uses tanh, so factors lie in (-1, 1).
        y = torch.tanh(self.fc2(y))
        y = y.reshape(batch_size, num_channels, 1, 1)
        return x * y

    def _init_weights(self) -> None:
        """Kaiming-uniform weights and zero biases for both linear layers."""
        init.kaiming_uniform_(self.fc1.weight, a=0, mode='fan_in', nonlinearity='relu')
        init.kaiming_uniform_(self.fc2.weight, a=0, mode='fan_in', nonlinearity='relu')
        init.constant_(self.fc1.bias, 0)
        init.constant_(self.fc2.bias, 0)


class MSEFBlock(nn.Module):
    """Fusion block: layer norm, then depthwise conv multiplied by an SE gate.

    The product of the two branches is added back to the block input
    (residual connection).

    Args:
        filters: Number of input/output channels.
    """

    def __init__(self, filters: int):
        super().__init__()
        self.layer_norm = LayerNormalization(filters)
        # groups=filters makes this a depthwise convolution.
        self.depthwise_conv = nn.Conv2d(filters, filters, kernel_size=3, padding=1, groups=filters)
        self.se_attn = SEBlock(filters)
        self._init_weights()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``depthwise(norm(x)) * se(norm(x)) + x``."""
        x_norm = self.layer_norm(x)
        x1 = self.depthwise_conv(x_norm)
        x2 = self.se_attn(x_norm)
        x_fused = x1 * x2
        x_out = x_fused + x
        return x_out

    def _init_weights(self) -> None:
        """Kaiming-uniform weight and zero bias for the depthwise convolution."""
        init.kaiming_uniform_(self.depthwise_conv.weight, a=0, mode='fan_in', nonlinearity='relu')
        init.constant_(self.depthwise_conv.bias, 0)


class MultiHeadSelfAttention(nn.Module):
    """Scaled dot-product multi-head self-attention over a 2-D feature map.

    Args:
        embed_size: Embedding width; the forward pass reshapes the input so
            that its last dimension feeds ``nn.Linear(embed_size, ...)``.
        num_heads: Number of attention heads; must divide ``embed_size``.
    """

    def __init__(self, embed_size: int, num_heads: int):
        super().__init__()
        self.embed_size = embed_size
        self.num_heads = num_heads
        assert embed_size % num_heads == 0, "embed_size must be divisible by num_heads"
        self.head_dim = embed_size // num_heads
        self.query_dense = nn.Linear(embed_size, embed_size)
        self.key_dense = nn.Linear(embed_size, embed_size)
        self.value_dense = nn.Linear(embed_size, embed_size)
        self.combine_heads = nn.Linear(embed_size, embed_size)
        self._init_weights()

    def split_heads(self, x: torch.Tensor, batch_size: int) -> torch.Tensor:
        """Reshape (B, N, embed) into (B, heads, N, head_dim)."""
        x = x.reshape(batch_size, -1, self.num_heads, self.head_dim)
        return x.permute(0, 2, 1, 3)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Attend over all H*W positions and return a (B, embed, H, W) tensor."""
        batch_size, _, height, width = x.size()
        # NOTE(review): the (B, C, H, W) input is reshaped to (B, H*W, -1) in
        # row-major order without first permuting channels last, so a "token"
        # is a run of the flattened buffer rather than one pixel's channel
        # vector. Left unchanged because altering it would change the output
        # of already-trained weights -- confirm whether this is intended.
        x = x.reshape(batch_size, height * width, -1)

        query = self.split_heads(self.query_dense(x), batch_size)
        key = self.split_heads(self.key_dense(x), batch_size)
        value = self.split_heads(self.value_dense(x), batch_size)

        # Softmax over keys of the scaled query-key similarity.
        scores = torch.matmul(query, key.transpose(-2, -1)) / (self.head_dim ** 0.5)
        attention_weights = F.softmax(scores, dim=-1)
        attention = torch.matmul(attention_weights, value)

        # Merge heads back: (B, heads, N, head_dim) -> (B, N, embed).
        attention = attention.permute(0, 2, 1, 3).contiguous().reshape(batch_size, -1, self.embed_size)
        output = self.combine_heads(attention)
        return output.reshape(batch_size, height, width, self.embed_size).permute(0, 3, 1, 2)

    def _init_weights(self) -> None:
        """Xavier-uniform weights and zero biases for all four projections."""
        init.xavier_uniform_(self.query_dense.weight)
        init.xavier_uniform_(self.key_dense.weight)
        init.xavier_uniform_(self.value_dense.weight)
        init.xavier_uniform_(self.combine_heads.weight)
        init.constant_(self.query_dense.bias, 0)
        init.constant_(self.key_dense.bias, 0)
        init.constant_(self.value_dense.bias, 0)
        init.constant_(self.combine_heads.bias, 0)


class Denoiser(nn.Module):
    """Single-channel encoder/decoder with an attention bottleneck.

    Three stride-2 convolutions downsample the input, a
    ``MultiHeadSelfAttention`` block processes the coarsest features, and
    nearest-neighbour upsampling with additive skip connections restores the
    resolution. Because the skips are plain additions, the spatial size of
    the input must be divisible by 8.

    Args:
        num_filters: Channel width of the internal feature maps; must be
            divisible by 4 (the bottleneck uses 4 attention heads).
        kernel_size: Convolution kernel size. Use an odd value so that
            "same" padding (``kernel_size // 2``) keeps shapes aligned.
        activation: Name of a function in ``torch.nn.functional`` applied
            after each encoder convolution.
    """

    def __init__(self, num_filters: int, kernel_size: int = 3, activation: str = 'relu'):
        super().__init__()
        # "Same" padding derived from the kernel size. The previous hard-coded
        # padding=1 was only correct for kernel_size=3 and broke the skip
        # connections for any other kernel.
        padding = kernel_size // 2
        self.conv1 = nn.Conv2d(1, num_filters, kernel_size=kernel_size, padding=padding)
        self.conv2 = nn.Conv2d(num_filters, num_filters, kernel_size=kernel_size, stride=2, padding=padding)
        self.conv3 = nn.Conv2d(num_filters, num_filters, kernel_size=kernel_size, stride=2, padding=padding)
        self.conv4 = nn.Conv2d(num_filters, num_filters, kernel_size=kernel_size, stride=2, padding=padding)
        self.bottleneck = MultiHeadSelfAttention(embed_size=num_filters, num_heads=4)
        self.up2 = nn.Upsample(scale_factor=2, mode='nearest')
        self.up3 = nn.Upsample(scale_factor=2, mode='nearest')
        self.up4 = nn.Upsample(scale_factor=2, mode='nearest')
        self.output_layer = nn.Conv2d(1, 1, kernel_size=kernel_size, padding=padding)
        self.res_layer = nn.Conv2d(num_filters, 1, kernel_size=kernel_size, padding=padding)
        self.activation = getattr(F, activation)
        self._init_weights()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return a tanh-bounded single-channel map with the input's spatial size."""
        # Encoder: full, 1/2, 1/4 and 1/8 resolution features.
        x1 = self.activation(self.conv1(x))
        x2 = self.activation(self.conv2(x1))
        x3 = self.activation(self.conv3(x2))
        x4 = self.activation(self.conv4(x3))

        # Decoder: upsample and add the matching encoder features.
        x = self.bottleneck(x4)
        x = self.up4(x)
        x = self.up3(x + x3)
        x = self.up2(x + x2)
        x = x + x1

        x = self.res_layer(x)
        # NOTE(review): `x + x` doubles the residual-layer output rather than
        # adding the block input; kept as-is to preserve existing behaviour.
        return torch.tanh(self.output_layer(x + x))

    def _init_weights(self) -> None:
        """Kaiming-uniform weights and zero biases for every convolution."""
        for layer in [self.conv1, self.conv2, self.conv3, self.conv4, self.output_layer, self.res_layer]:
            init.kaiming_uniform_(layer.weight, a=0, mode='fan_in', nonlinearity='relu')
            if layer.bias is not None:
                init.constant_(layer.bias, 0)


class LYT(nn.Module):
    """LYT-Net: dual-path luminance/chrominance image enhancement network.

    The RGB input is converted to a luminance channel and two chroma
    channels. Each chroma channel is refined by its own ``Denoiser``; the
    luminance features are enhanced with pooled self-attention; the chroma
    features are fused through an ``MSEFBlock`` and finally recombined with
    the luminance features into a 3-channel, sigmoid-bounded output.

    The input height and width must be divisible by 8 (8x max-pool followed
    by 8x upsampling, and three stride-2 stages inside each denoiser).

    Args:
        filters: Channel width of the feature maps. The denoisers use
            ``filters // 2`` channels.
    """

    def __init__(self, filters: int = 32):
        super().__init__()
        self.process_y = self._create_processing_layers(filters)
        self.process_cb = self._create_processing_layers(filters)
        self.process_cr = self._create_processing_layers(filters)

        self.denoiser_cb = Denoiser(filters // 2)
        self.denoiser_cr = Denoiser(filters // 2)

        self.lum_pool = nn.MaxPool2d(8)
        self.lum_mhsa = MultiHeadSelfAttention(embed_size=filters, num_heads=4)
        self.lum_up = nn.Upsample(scale_factor=8, mode='nearest')
        self.lum_conv = nn.Conv2d(filters, filters, kernel_size=1, padding=0)
        self.ref_conv = nn.Conv2d(filters * 2, filters, kernel_size=1, padding=0)
        self.msef = MSEFBlock(filters)
        self.recombine = nn.Conv2d(filters * 2, filters, kernel_size=3, padding=1)
        self.final_adjustments = nn.Conv2d(filters, 3, kernel_size=3, padding=1)
        self._init_weights()

    def _create_processing_layers(self, filters: int) -> nn.Sequential:
        """Build the 1 -> ``filters`` channel 3x3 conv + ReLU stem for one channel."""
        return nn.Sequential(
            nn.Conv2d(1, filters, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

    def _rgb_to_ycbcr(self, image: torch.Tensor) -> torch.Tensor:
        """Convert a (B, 3, H, W) RGB batch to stacked luminance/chroma channels.

        The two chroma channels are shifted by +0.5.
        Presumably the input is scaled to [0, 1] -- TODO confirm with callers.
        """
        r, g, b = image[:, 0, :, :], image[:, 1, :, :], image[:, 2, :, :]
        y = 0.299 * r + 0.587 * g + 0.114 * b
        u = -0.14713 * r - 0.28886 * g + 0.436 * b + 0.5
        v = 0.615 * r - 0.51499 * g - 0.10001 * b + 0.5
        yuv = torch.stack((y, u, v), dim=1)
        return yuv

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Enhance a (B, 3, H, W) batch; returns a (B, 3, H, W) tensor in (0, 1)."""
        ycbcr = self._rgb_to_ycbcr(inputs)
        y, cb, cr = torch.split(ycbcr, 1, dim=1)

        # Chroma path: residual denoising of each chroma channel.
        cb = self.denoiser_cb(cb) + cb
        cr = self.denoiser_cr(cr) + cr

        y_processed = self.process_y(y)
        cb_processed = self.process_cb(cb)
        cr_processed = self.process_cr(cr)

        ref = torch.cat([cb_processed, cr_processed], dim=1)

        # Luminance path: attention at 1/8 resolution, added back as a residual.
        lum = y_processed
        lum_1 = self.lum_pool(lum)
        lum_1 = self.lum_mhsa(lum_1)
        lum_1 = self.lum_up(lum_1)
        lum = lum + lum_1

        # Fuse chroma features, lightly guided by the luminance features.
        ref = self.ref_conv(ref)
        shortcut = ref
        ref = ref + 0.2 * self.lum_conv(lum)
        ref = self.msef(ref)
        ref = ref + shortcut

        recombined = self.recombine(torch.cat([ref, lum], dim=1))
        output = self.final_adjustments(recombined)
        return torch.sigmoid(output)

    def _init_weights(self) -> None:
        """Kaiming-uniform init for the Conv2d/Linear layers that are direct children.

        Nested sub-modules are not visited here (``children()`` is not
        recursive).
        """
        for module in self.children():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                init.kaiming_uniform_(module.weight, a=0, mode='fan_in', nonlinearity='relu')
                if module.bias is not None:
                    init.constant_(module.bias, 0)
第四步:运行
第五步:整个工程的内容
项目完整文件下载请见演示与介绍视频的简介处给出:➷➷➷
PyTorch框架——基于深度学习LYT-Net神经网络AI低光图像增强系统源码_哔哩哔哩_bilibili