
Implementing text2draw with the CLIP model

Source: https://blog.csdn.net/baidu_39211350/article/details/141718372

Reference papers
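CLIPDraw: Exploring Text-to-Drawing Synthesis through Language-Image Encoders
StyleCLIPDraw: Coupling Content and Style in Text-to-Drawing Translation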

Practice

Code with data augmentation

import math
import collections
import CLIP_.clip as clip
import torch
import torch.nn as nn
from torchvision import models, transforms
import numpy as np
import webp
from PIL import Image
import skimage
import torchvision
import pydiffvg
import os
import torch.nn.functional as F


class GeometrymatchLoss(torch.nn.Module):
    def __init__(self, device, reference_images_path):
        super(GeometrymatchLoss, self).__init__()
        self.device = device
        self.model, clip_preprocess = clip.load('ViT-B/32', self.device, jit=False)
        self.model.eval()
        # clip normalisation (resize + normalize transforms only)
        self.preprocess = transforms.Compose([clip_preprocess.transforms[0], clip_preprocess.transforms[-1]])
        # note: this assignment shadows the method of the same name once __init__ has run
        self.reference_images_feature = self.reference_images_feature(reference_images_path)
        self.reference_images_feature = self.reference_images_feature / self.reference_images_feature.norm(dim=-1, keepdim=True)
        self.text = clip.tokenize(["A picture of triangle"]).to(device)
        self.text_features = self.model.encode_text(self.text)
        # self.text_features = self.text_features / self.text_features.norm(dim=-1, keepdim=True)
        print("text_features.requires_grad:", self.text_features.requires_grad)
        self.text_features = self.text_features.detach()
        self.shape_groups = [pydiffvg.ShapeGroup(shape_ids=torch.tensor([0]),
                                                 fill_color=torch.tensor([0.0, 0.0, 0.0, 1.0]),
                                                 stroke_color=torch.tensor([0.0, 0.0, 0.0, 1.0]))]
        # Image augmentation transformation
        self.augment_trans = transforms.Compose([
            transforms.RandomPerspective(fill=1, p=1, distortion_scale=0.5),
            transforms.RandomResizedCrop(224, scale=(0.7, 0.9)),
        ])

    def forward(self, t, canvas_width, canvas_height, shapes):
        scene_args = pydiffvg.RenderFunction.serialize_scene(canvas_width, canvas_height, shapes, self.shape_groups)
        # Render the image
        render = pydiffvg.RenderFunction.apply
        target = render(canvas_width, canvas_height, 2, 2, 0, None, *scene_args)
        if target.shape[-1] == 4:
            target = self.compose_image_with_white_background(target)
        if t % 100 == 0:
            pydiffvg.imwrite(target.cpu(), f'learn/log_augs/output_{t}.png', gamma=2.2)
        # targets_ = self.preprocess(target.permute(2, 0, 1).unsqueeze(0)).to(self.device)
        img = target.unsqueeze(0)
        img = img.permute(0, 3, 1, 2)
        loss = 0
        NUM_AUGS = 4
        img_augs = []
        for n in range(NUM_AUGS):
            img_augs.append(self.augment_trans(img))
        im_batch = torch.cat(img_augs)
        image_features = self.model.encode_image(im_batch)
        # logit_scale = self.model.logit_scale.exp()
        # Sum the negative cosine similarity over all augmented views
        for n in range(NUM_AUGS):
            loss -= torch.cosine_similarity(self.text_features, image_features[n:n + 1], dim=1)
        return loss

    def compose_image_with_white_background(self, img: torch.tensor) -> torch.tensor:
        if img.shape[-1] == 3:  # return img if it is already rgb
            return img
        # Compose img with white background
        alpha = img[:, :, 3:4]
        img = alpha * img[:, :, :3] + (1 - alpha) * torch.ones(img.shape[0], img.shape[1], 3, device=self.device)
        return img

    def read_png_image_from_path(self, path_to_png_image: str) -> torch.tensor:
        numpy_image = skimage.io.imread(path_to_png_image)
        normalized_tensor_image = torch.from_numpy(numpy_image).to(torch.float32) / 255.0
        resizer = torchvision.transforms.Resize((224, 224))
        resized_image = resizer(normalized_tensor_image.permute(2, 0, 1)).permute(1, 2, 0)
        return resized_image

    def reference_images_feature(self, reference_images_path):
        reference_images_num = len(os.listdir(reference_images_path))
        reference_images_feature = []
        for i in range(reference_images_num):
            i_reference_image = self.read_png_image_from_path(os.path.join(reference_images_path, str(i) + ".png"))
            if i_reference_image.shape[-1] == 4:
                i_reference_image = self.compose_image_with_white_background(i_reference_image)
            # targets_ = self.preprocess(i_reference_image.permute(2, 0, 1).unsqueeze(0)).to(self.device)
            i_reference_image_features = self.model.encode_image(
                i_reference_image.permute(2, 0, 1).unsqueeze(0).to(self.device)).detach()
            reference_images_feature.append(i_reference_image_features)
        return torch.cat(reference_images_feature)


def read_png_image_from_path(path_to_png_image: str) -> torch.tensor:
    if path_to_png_image.endswith('.webp'):
        numpy_image = np.array(webp.load_image(path_to_png_image))
    else:
        numpy_image = skimage.io.imread(path_to_png_image)
    normalized_tensor_image = torch.from_numpy(numpy_image).to(torch.float32) / 255.0
    resizer = torchvision.transforms.Resize((224, 224))
    resized_image = resizer(normalized_tensor_image.permute(2, 0, 1)).permute(1, 2, 0)
    return resized_image


if __name__ == '__main__':
    torch.autograd.set_detect_anomaly(True)
    from tqdm import tqdm

    def get_bezier_circle(radius: float = 80,
                          segments: int = 4,
                          bias: np.array = np.asarray([100., 100.])):
        # Control points of a closed cubic Bezier path approximating a circle
        deg = torch.arange(0, segments * 3 + 1) * 2 * np.pi / (segments * 3 + 1)
        points = torch.stack((torch.cos(deg), torch.sin(deg))).T
        points = points * radius + torch.tensor(bias).unsqueeze(dim=0)
        points = points.type(torch.FloatTensor).contiguous()
        return points

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    matchLoss = GeometrymatchLoss(device, "reference_images/")
    # print(matchLoss.reference_images_feature.shape)
    # img1 = read_png_image_from_path('learn/output.png')
    canvas_width, canvas_height = 224, 224
    num_segments = 4
    points1 = get_bezier_circle()
    path = pydiffvg.Path(num_control_points=torch.tensor(num_segments * [2] + [0], dtype=torch.int32),
                         points=points1, stroke_width=torch.tensor(2.0), is_closed=True)
    shapes = [path]
    path.points.requires_grad = True
    print(id(path.points))
    print(id(points1))
    points_vars = []
    points_vars.append(path.points)
    points_optim = torch.optim.Adam(points_vars, lr=1)
    pbar = tqdm(range(100000))
    print(points1)
    for t in pbar:
        points_optim.zero_grad()
        match_loss = matchLoss(t, 224, 224, shapes)
        match_loss.backward()
        # print(path.points.grad)
        points_optim.step()
        pbar.set_postfix({"match_loss": f"{match_loss.item()}"})
        # print(points_vars[0])

Result after 1000 iterations:
[figure]

Code without image augmentation

import math
import collections
import CLIP_.clip as clip
import torch
import torch.nn as nn
from torchvision import models, transforms
import numpy as np
import webp
from PIL import Image
import skimage
import torchvision
import pydiffvg
import os
import torch.nn.functional as F


class GeometrymatchLoss(torch.nn.Module):
    def __init__(self, device, reference_images_path):
        super(GeometrymatchLoss, self).__init__()
        self.device = device
        self.model, clip_preprocess = clip.load('ViT-B/32', self.device, jit=False)
        self.model.eval()
        # clip normalisation (resize + normalize transforms only)
        self.preprocess = transforms.Compose([clip_preprocess.transforms[0], clip_preprocess.transforms[-1]])
        # self.preprocess = transforms.Compose([clip_preprocess.transforms[-1]])  # clip normalisation
        self.reference_images_feature = self.reference_images_feature(reference_images_path)
        self.reference_images_feature = self.reference_images_feature / self.reference_images_feature.norm(dim=-1, keepdim=True)
        self.text = clip.tokenize(["A picture of triangle"]).to(device)
        # self.text = clip.tokenize(["A picture of rectangle", "A picture of triangle", "A picture of circle",
        #                            "A picture of pentagon", "A picture of five-pointed star"]).to(device)
        self.text_features = self.model.encode_text(self.text)
        self.text_features = self.text_features / self.text_features.norm(dim=-1, keepdim=True)
        print("text_features.requires_grad:", self.text_features.requires_grad)
        self.text_features = self.text_features.detach()
        self.shape_groups = [pydiffvg.ShapeGroup(shape_ids=torch.tensor([0]),
                                                 fill_color=torch.tensor([0.0, 0.0, 0.0, 1.0]),
                                                 stroke_color=torch.tensor([0.0, 0.0, 0.0, 1.0]))]
        # Image augmentation transformation (defined, but not used by the loss in this variant)
        self.augment_trans = transforms.Compose([
            transforms.RandomPerspective(fill=1, p=1, distortion_scale=0.5),
            transforms.RandomResizedCrop(224, scale=(0.7, 0.9)),
        ])

    def forward(self, t, canvas_width, canvas_height, shapes):
        scene_args = pydiffvg.RenderFunction.serialize_scene(canvas_width, canvas_height, shapes, self.shape_groups)
        # Render the image
        render = pydiffvg.RenderFunction.apply
        target = render(canvas_width, canvas_height, 2, 2, 0, None, *scene_args)
        if target.shape[-1] == 4:
            target = self.compose_image_with_white_background(target)
        if t % 100 == 0:
            pydiffvg.imwrite(target.cpu(), f'learn/log/output_{t}.png', gamma=2.2)
        # targets_ = self.preprocess(target.permute(2, 0, 1).unsqueeze(0)).to(self.device)
        img = target.unsqueeze(0)
        img = img.permute(0, 3, 1, 2)
        loss = 0
        # The augmented batch is still assembled here, but the loss below only
        # looks at the single un-augmented render.
        NUM_AUGS = 4
        img_augs = []
        for n in range(NUM_AUGS):
            img_augs.append(self.augment_trans(img))
        im_batch = torch.cat(img_augs)
        image_features = self.model.encode_image(img)
        self.targets_features: torch.tensor = image_features[0]
        self.targets_features = self.targets_features / self.targets_features.norm(dim=-1, keepdim=True)
        loss -= torch.cosine_similarity(self.text_features, self.targets_features, dim=1)
        return loss

    def compose_image_with_white_background(self, img: torch.tensor) -> torch.tensor:
        if img.shape[-1] == 3:  # return img if it is already rgb
            return img
        # Compose img with white background
        alpha = img[:, :, 3:4]
        img = alpha * img[:, :, :3] + (1 - alpha) * torch.ones(img.shape[0], img.shape[1], 3, device=self.device)
        return img

    def read_png_image_from_path(self, path_to_png_image: str) -> torch.tensor:
        numpy_image = skimage.io.imread(path_to_png_image)
        normalized_tensor_image = torch.from_numpy(numpy_image).to(torch.float32) / 255.0
        resizer = torchvision.transforms.Resize((224, 224))
        resized_image = resizer(normalized_tensor_image.permute(2, 0, 1)).permute(1, 2, 0)
        return resized_image

    def reference_images_feature(self, reference_images_path):
        reference_images_num = len(os.listdir(reference_images_path))
        reference_images_feature = []
        for i in range(reference_images_num):
            i_reference_image = self.read_png_image_from_path(os.path.join(reference_images_path, str(i) + ".png"))
            if i_reference_image.shape[-1] == 4:
                i_reference_image = self.compose_image_with_white_background(i_reference_image)
            # targets_ = self.preprocess(i_reference_image.permute(2, 0, 1).unsqueeze(0)).to(self.device)
            i_reference_image_features = self.model.encode_image(
                i_reference_image.permute(2, 0, 1).unsqueeze(0).to(self.device)).detach()
            reference_images_feature.append(i_reference_image_features)
        return torch.cat(reference_images_feature)


def read_png_image_from_path(path_to_png_image: str) -> torch.tensor:
    if path_to_png_image.endswith('.webp'):
        numpy_image = np.array(webp.load_image(path_to_png_image))
    else:
        numpy_image = skimage.io.imread(path_to_png_image)
    normalized_tensor_image = torch.from_numpy(numpy_image).to(torch.float32) / 255.0
    resizer = torchvision.transforms.Resize((224, 224))
    resized_image = resizer(normalized_tensor_image.permute(2, 0, 1)).permute(1, 2, 0)
    return resized_image


if __name__ == '__main__':
    torch.autograd.set_detect_anomaly(True)
    from tqdm import tqdm

    def get_bezier_circle(radius: float = 80,
                          segments: int = 4,
                          bias: np.array = np.asarray([100., 100.])):
        # Control points of a closed cubic Bezier path approximating a circle
        deg = torch.arange(0, segments * 3 + 1) * 2 * np.pi / (segments * 3 + 1)
        points = torch.stack((torch.cos(deg), torch.sin(deg))).T
        points = points * radius + torch.tensor(bias).unsqueeze(dim=0)
        points = points.type(torch.FloatTensor).contiguous()
        return points

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    matchLoss = GeometrymatchLoss(device, "reference_images/")
    # print(matchLoss.reference_images_feature.shape)
    # img1 = read_png_image_from_path('learn/output.png')
    canvas_width, canvas_height = 224, 224
    num_segments = 4
    points1 = get_bezier_circle()
    path = pydiffvg.Path(num_control_points=torch.tensor(num_segments * [2] + [0], dtype=torch.int32),
                         points=points1, stroke_width=torch.tensor(2.0), is_closed=True)
    shapes = [path]
    path.points.requires_grad = True
    print(id(path.points))
    print(id(points1))
    points_vars = []
    points_vars.append(path.points)
    points_optim = torch.optim.Adam(points_vars, lr=1)
    pbar = tqdm(range(100000))
    print(points1)
    for t in pbar:
        points_optim.zero_grad()
        match_loss = matchLoss(t, 224, 224, shapes)
        match_loss.backward()
        # print(path.points.grad)
        points_optim.step()
        pbar.set_postfix({"match_loss": f"{match_loss.item()}"})
        # print(points_vars[0])

Result after 1000 iterations:
[figure]
Result after 2000 iterations:
[figure]
Result after 4000 iterations:
[figure]
Result after 8000 iterations:
[figure]

Analysis: why results are poor without image augmentation

Explanation from the paper CLIPDraw: Exploring Text-to-Drawing Synthesis through Language-Image Encoders

[figure: excerpt from the CLIPDraw paper]

Explanation from the paper StyleCLIPDraw: Coupling Content and Style in Text-to-Drawing Translation

[figure: excerpt from the StyleCLIPDraw paper]

My understanding

In CLIP's embedding space, many different images can match a single text, and to a human eye some of them are fundamentally unrelated to that text; without image augmentation, the optimization is therefore likely to converge to such a local optimum. Augmenting the image before computing the loss fixes this: under perspective and similar transformations, a genuinely related image keeps essentially the same similarity to the text no matter how it is transformed, while an unrelated image generally loses similarity after transformation. Averaging the loss over augmented views thus keeps unrelated images from skewing the result. A minimal sketch of this idea follows.
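The sketch below factors the augmentation-averaged loss out of the scripts above. The names clip_model, text_features, and image are illustrative stand-ins for self.model, self.text_features, and the rendered canvas; they are not part of the original code.

# Minimal sketch: score one render against the text as an average over
# random augmented views. `clip_model`, `text_features`, and `image` are
# hypothetical stand-ins for the objects built in the scripts above.
import torch
from torchvision import transforms

augment = transforms.Compose([
    transforms.RandomPerspective(fill=1, p=1, distortion_scale=0.5),
    transforms.RandomResizedCrop(224, scale=(0.7, 0.9)),
])

def augmented_clip_loss(clip_model, text_features, image, num_augs=4):
    # image: (1, 3, 224, 224) render; encode several augmented views in one batch
    views = torch.cat([augment(image) for _ in range(num_augs)])
    image_features = clip_model.encode_image(views)
    # A related image stays similar to the text under every view; an
    # unrelated one usually loses similarity, so its averaged loss worsens.
    loss = 0
    for n in range(num_augs):
        loss = loss - torch.cosine_similarity(text_features, image_features[n:n + 1], dim=1)
    return loss / num_augs

Dividing by num_augs only rescales the gradient; the script above keeps the raw sum, which is equivalent up to the learning rate.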
