import torch
import torchvision
import torchvision.transforms as T
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np


def decodeSegMap(image, nc=21):
    """Convert a map of class indices into an RGB colour image.

    Each pixel whose value equals a class index ``0 <= l < nc`` is painted
    with the corresponding entry of a fixed 21-colour palette. Pixels
    holding any other value stay black.

    Args:
        image: Array-like of class indices (typically a 2-D ``H x W`` map).
            Anything ``np.asarray`` accepts works, including nested lists.
        nc: Number of classes to colour. Values larger than the palette
            size are capped at the palette size; values <= 0 colour nothing.

    Returns:
        A ``uint8`` array with the three colour planes stacked along
        axis 2, i.e. ``H x W x 3`` for a 2-D input.
    """
    label_colors = np.array(
        [
            (0, 0, 0), (128, 0, 0), (0, 128, 0), (128, 128, 0),
            (0, 0, 128), (128, 0, 128), (0, 128, 128), (128, 128, 128),
            (64, 0, 0), (192, 0, 0), (64, 128, 0), (192, 128, 0),
            (64, 0, 128), (192, 0, 128), (64, 128, 128), (192, 128, 128),
            (0, 64, 0), (128, 64, 0), (0, 192, 0), (128, 192, 0),
            (0, 64, 128),
        ],
        dtype=np.uint8,
    )

    # Without this, a plain list compared with `==` yields a scalar False
    # and every pixel would silently stay black.
    image = np.asarray(image)

    r = np.zeros(image.shape, dtype=np.uint8)
    g = np.zeros(image.shape, dtype=np.uint8)
    b = np.zeros(image.shape, dtype=np.uint8)

    # Slicing caps `nc` at the palette length (no IndexError for nc > 21);
    # max(..., 0) keeps a negative `nc` meaning "no classes" rather than
    # being interpreted as a from-the-end slice.
    for label, (red, green, blue) in enumerate(label_colors[:max(nc, 0)]):
        mask = image == label
        r[mask] = red
        g[mask] = green
        b[mask] = blue

    return np.stack([r, g, b], axis=2)


# Load the model
model = torchvision.models.segmentation.deeplabv3_resnet50(pretrained=True)
# Alternative: load the weights from a local checkpoint file instead.
# model.load_state_dict(torch.load('./deeplabv3_resnet50_coco-cd0a2569.pth'))
model.eval()  # evaluation mode; eval() returns the module itself

# Preprocessing: resize, centre-crop, convert to a tensor, then normalize
# each channel with the mean/std values given below.
transform = T.Compose(
    [
        T.Resize(256),
        T.CenterCrop(224),
        T.ToTensor(),
        T.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Load the image
img_pth = './dog.2.jpg'
# Force three channels: the Normalize step in `transform` uses 3-element
# mean/std, so a grayscale or RGBA file would otherwise fail there.
img = Image.open(img_pth).convert('RGB')

# Show the original picture without axis ticks.
plt.imshow(img)
plt.axis('off')
plt.show()

# Run the preprocessing pipeline and add a leading batch dimension,
# since the model is called on a batch of images.
img = transform(img).unsqueeze(0)
# Show the image after it has gone through `transform`.
# Take the first batch element and move the channel axis last, which is the
# layout matplotlib expects for colour images.
img_transform = np.transpose(img.detach().numpy()[0], (1, 2, 0))
# Normalized values fall outside [0, 1]; imshow would clip them anyway (and
# emit a warning), so clip explicitly to get the same picture quietly.
plt.imshow(np.clip(img_transform, 0, 1))
plt.show()

# Inference only: skip autograd bookkeeping to save memory and time.
with torch.no_grad():
    output = model(img)
print(f"输出结果的形状:{output['out'].shape}")
# Drop only the batch dimension (a bare squeeze() would also drop any other
# size-1 dimension), then pick the highest-scoring class for every pixel.
output = torch.argmax(output['out'].squeeze(0), dim=0).detach().cpu().numpy()
# Distinct class indices present in the prediction, as plain Python ints so
# the printed set reads the same regardless of the NumPy version.
result_class = set(np.unique(output).tolist())
print(result_class)

# Colour the class map and display it.
rgb = decodeSegMap(output)
img = Image.fromarray(rgb)
plt.axis('off')
plt.imshow(img)
plt.show()
# Original image:
# Result image: