1. 计算Ca距离的函数
def _dist(self, X, mask, eps=1E-6):mask_2D = torch.unsqueeze(mask,1) * torch.unsqueeze(mask,2)dX = torch.unsqueeze(X,1) - torch.unsqueeze(X,2)D = mask_2D * torch.sqrt(torch.sum(dX**2, 3) + eps)D_max, _ = torch.max(D, -1, keepdim=True)D_adjust = D + (1. - mask_2D) * D_maxsampled_top_k = self.top_kD_neighbors, E_idx = torch.topk(D_adjust, np.minimum(self.top_k, X.shape[1]), dim=-1, largest=False)return D_neighbors, E_idx
_dist
函数主要用于计算一批数据中每个残基的邻居距离以及相应的邻居索引。输入的 X
是每个残基的 Ca 原子坐标,维度为 (B, L, 3)
,mask
的维度为 (B, L)
,用于标记有效的残基。下面解释每个张量的维度变化过程:
输入
X
的维度是(B, L, 3)
,其中:B
是 batch size。L
是残基的数量。3
是 Ca 原子的三维坐标。
mask
的维度是(B, L)
,表示每个残基的有效性(0 或 1)。
计算步骤与维度变化
-
mask_2D = torch.unsqueeze(mask, 1) * torch.unsqueeze(mask, 2)
mask
通过unsqueeze
变成(B, 1, L)
和(B, L, 1)
两个张量。- 这两个张量相乘后得到
mask_2D
,维度为(B, L, L)