YOLOX中decode 特征点解码过程可视化
该代码是特征宽高为20*20,batch_size=4,num_classes = 20进行解码可视化的过程。
import numpy as np
import matplotlib.pyplot as pltdef decode_for_vision(output):bs, hw = np.shape(output)[0], np.shape(output)[1:3]# hw[0] * hw[1] ------- 20,20output = np.reshape(output, [bs, hw[0] * hw[1], -1])#print(output)#output ------(4, 400, 23)grid_x, grid_y = np.meshgrid(np.arange(hw[1]), np.arange(hw[0]))#print(grid_x)grid = np.reshape(np.stack((grid_x, grid_y), 2), (1, -1, 2))#grid ---------(1, 400, 2)#print(grid)box_xy = (output[..., :2] + grid)#box_xy.shape (4, 400, 2)#output[..., :2] (4, 400, 2)#grid (1, 400, 2)box_wh = np.exp(output[..., 2:4])#output[..., 2:4].shape (4, 400, 2)#box_wh (4, 400, 2)fig = plt.figure()ax = fig.add_subplot(121)plt.ylim(-2.22, hw[0] + 2.22)plt.xlim(-2.22, hw[1] + 2.22)plt.scatter(grid_x, grid_y)plt.scatter(0, 0, c='black')plt.scatter(1, 0, c='black')plt.scatter(2, 0, c='black')plt.scatter(box_xy[0, 0, 0], box_xy[0, 0, 1], c='r')plt.scatter(box_xy[0, 1, 0], box_xy[0, 1, 1], c='g')plt.scatter(box_xy[0, 2, 0], box_xy[0, 2, 1], c='b')plt.gca().invert_yaxis()pre_left = box_xy[..., 0] - box_wh[..., 0] / 2pre_top = box_xy[..., 1] - box_wh[..., 1] / 2rect1 = plt.Rectangle([pre_left[0, 0], pre_top[0, 0]], box_wh[0, 0, 0], box_wh[0, 0, 1], color="r", fill=False)rect2 = plt.Rectangle([pre_left[0, 1], pre_top[0, 1]], box_wh[0, 1, 0], box_wh[0, 1, 1], color="r", fill=False)rect3 = plt.Rectangle([pre_left[0, 2], pre_top[0, 2]], box_wh[0, 2, 0], box_wh[0, 2, 1], color="r", fill=False)ax.add_patch(rect1)ax.add_patch(rect2)ax.add_patch(rect3)plt.show()if __name__ == '__main__':batch_size = 4num_classes = 20feat = np.concatenate([np.random.uniform(-1, 1, [batch_size, 20, 20, 1]),np.random.uniform(1, 3, [batch_size, 20, 20, 2]),np.random.uniform(0, 1, [batch_size, 20, 20, num_classes])],axis=-1)# print(feat.shape)# s= np.random.uniform(-1, 1, [batch_size, 20, 20, 2])# s1 = np.random.uniform(1, 3, [batch_size, 20, 20, 2])# s2 = np.random.uniform(0, 1, [batch_size, 20, 20, num_classes])# print(s2.shape)decode_for_vision(feat)#grid_x
# [[ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19]
# [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19]
# [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19]
# [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19]
# [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19]
# [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19]
# [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19]
# [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19]
# [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19]
# [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19]
# [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19]
# [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19]
# [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19]
# [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19]
# [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19]
# [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19]
# [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19]
# [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19]
# [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19]
# [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19]]#grid_y
# [[ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19]
# [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19]
# [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19]
# [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19]
# [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19]
# [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19]
# [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19]
# [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19]
# [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19]
# [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19]
# [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19]
# [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19]
# [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19]
# [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19]
# [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19]
# [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19]
# [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19]
# [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19]
# [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19]
# [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19]]
如下图所示
1. box_xy = (output[…, :2] + grid)
output 是一个形状为 (batch_size, height * width, 23) 的数组,其中 23 是通道的数量。每个位置包含了23个值,这些值通常包括:
- 预测的边界框的坐标(中心点的 x 和 y 坐标)
- 预测的边界框的宽度和高度
- 每个类的置信度分数
在 output[…, :2] 中,output 的形状为 (batch_size, height * width, 23),… 表示选取所有的前面的维度,而 :2 表示选择最后一维的前两个值。这意味着我们在提取预测的边界框中心点的 x 和 y 坐标(通常是第一个和第二个值)。
2. grid = np.reshape(np.stack((grid_x, grid_y), 2), (1, -1, 2))
在这个部分,grid_x 和 grid_y 是通过 np.meshgrid 创建的,它们的形状为 (height, width),表示网格中的每个位置的 x 和 y 坐标。
np.stack((grid_x, grid_y), 2)
np.stack 函数将 grid_x 和 grid_y 在新的维度(这里是第2个维度)上堆叠起来,因此生成的数组形状为 (height, width, 2),其中 2 表示堆叠的两个数组(grid_x 和 grid_y)。
np.reshape(…, (1, -1, 2))
np.reshape 函数将堆叠后的数组重新调整形状为 (1, -1, 2),具体如下:
- 1 表示批量维度
- -1 表示自动计算这一维度的大小,使总元素数保持不变
- 2 表示每个位置的两个坐标(x 和 y)
这意味着我们将原始形状为 (height, width, 2) 的数组变换为形状为 (1, height * width, 2) 的数组。这是为了方便后续操作,使网格坐标与输出的形状匹配。
3. box_wh = np.exp(output[…, 2:4])
在这个代码中,output 的形状是 (batch_size, height * width, 23),其中 23 是每个预测位置上的特征数。特征的具体内容通常包括:
- 预测的边界框的中心坐标 x 和 y(2 个值)。
- 预测的边界框的宽度和高度(2 个值)。
- 每个类的置信度分数(剩余的 19 个值,如果有 20 个类)。
因此,output[…, 2:4] 的意思是提取预测的边界框的宽度和高度。output[…, 2:4] 返回的是形状为 (batch_size, height * width, 2) 的数组,其中 2 代表宽度和高度两个值。
为什么使用 np.exp
模型在预测边界框的宽度和高度时,通常会预测其对数值。这是因为宽度和高度的值范围很大,直接预测这些值可能会使模型训练变得困难。因此,模型实际预测的是宽度和高度的对数值,这样可以将其转换回原始值:
box_wh = np.exp(output[..., 2:4])
具体过程
1.模型预测:
- 模型预测的是边界框宽度和高度的对数值。
2.指数转换: - 使用 np.exp 将对数值转换回实际的宽度和高度。
3.得到实际宽度和高度: - 转换后的值代表预测的边界框的实际宽度和高度。
output[…, :2] 提取预测的边界框中心点的坐标。
output[…, 2:4] 提取预测的边界框的宽度和高度的对数值。
np.exp(output[…, 2:4]) 将对数值转换为实际的宽度和高度。
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
详细解释:
- 导入必要的库
import numpy as np
import matplotlib.pyplot as plt
#导入 NumPy 用于数值计算,导入 Matplotlib 用于可视化。
- 解码并可视化检测器输出
def decode_for_vision(output):bs, hw = np.shape(output)[0], np.shape(output)[1:3]output = np.reshape(output, [bs, hw[0] * hw[1], -1])grid_x, grid_y = np.meshgrid(np.arange(hw[1]), np.arange(hw[0]))grid = np.reshape(np.stack((grid_x, grid_y), 2), (1, -1, 2))box_xy = (output[..., :2] + grid)box_wh = np.exp(output[..., 2:4])
- 获取批量大小 (bs) 和网格尺寸 (hw)。
- 将 output 重塑为 (batch_size, height * width, -1) 的形状。
- 使用 np.meshgrid 创建网格坐标 (grid_x, grid_y)。
- 堆叠并重塑网格坐标以匹配 output 的形状。
- 计算预测的边界框中心点 (box_xy) 和宽高 (box_wh)。
- 可视化网格和边界框
fig = plt.figure()ax = fig.add_subplot(121)plt.ylim(-2.22, hw[0] + 2.22)plt.xlim(-2.22, hw[1] + 2.22)plt.scatter(grid_x, grid_y)plt.scatter(0, 0, c='black')plt.scatter(1, 0, c='black')plt.scatter(2, 0, c='black')plt.scatter(box_xy[0, 0, 0], box_xy[0, 0, 1], c='r')plt.scatter(box_xy[0, 1, 0], box_xy[0, 1, 1], c='g')plt.scatter(box_xy[0, 2, 0], box_xy[0, 2, 1], c='b')plt.gca().invert_yaxis()
- 创建图像并设置坐标轴范围。
- 绘制网格坐标点。
- 绘制三个边界框中心点。
- 可视化边界框矩形
pre_left = box_xy[..., 0] - box_wh[..., 0] / 2pre_top = box_xy[..., 1] - box_wh[..., 1] / 2rect1 = plt.Rectangle([pre_left[0, 0], pre_top[0, 0]], box_wh[0, 0, 0], box_wh[0, 0, 1], color="r", fill=False)rect2 = plt.Rectangle([pre_left[0, 1], pre_top[0, 1]], box_wh[0, 1, 0], box_wh[0, 1, 1], color="r", fill=False)rect3 = plt.Rectangle([pre_left[0, 2], pre_top[0, 2]], box_wh[0, 2, 0], box_wh[0, 2, 1], color="r", fill=False)ax.add_patch(rect1)ax.add_patch(rect2)ax.add_patch(rect3)plt.show()
- 计算边界框左上角的坐标 (pre_left, pre_top)。
- 创建矩形边界框并添加到图像中。
- 显示图像。
- 主程序入口
if __name__ == '__main__':batch_size = 4num_classes = 20feat = np.concatenate([np.random.uniform(-1, 1, [batch_size, 20, 20, 1]),np.random.uniform(1, 3, [batch_size, 20, 20, 2]),np.random.uniform(0, 1, [batch_size, 20, 20, num_classes])],axis=-1)decode_for_vision(feat)
- 设置批量大小和类别数量。
- 生成随机伪数据来模拟检测器输出。
- 调用 decode_for_vision 函数解码并可视化这些数据。
总结
这个代码的主要目的是解码检测器输出数据并可视化网格和边界框,帮助理解检测器预测的边界框位置和大小。