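The AlphaFold3 functions gdt, gdt_ts, and gdt_ha compute the Global Distance Test (GDT) score, a standard measure of protein structure prediction accuracy. GDT quantifies how similar a predicted structure (p1) is to the true structure (p2): for each distance cutoff it takes the fraction of residues whose predicted position lies within that cutoff of the true position, then averages those fractions over all cutoffs.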
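To make the definition concrete, here is a minimal plain-Python sketch of the same arithmetic on a toy example. It is not the AlphaFold3 code: the helper names `gdt_fraction` and `gdt_score` and the coordinates are hypothetical, the residue mask is omitted, and the two structures are assumed to be already superimposed (the functions in this post also compare coordinates directly, without an alignment step). The cutoffs 1, 2, 4, and 8 Å are the ones gdt_ts passes to gdt.

```python
import math


def gdt_fraction(pred, true, cutoff):
    """Fraction of residues whose predicted position is within `cutoff` of the true one."""
    hits = sum(1 for a, b in zip(pred, true) if math.dist(a, b) <= cutoff)
    return hits / len(pred)


def gdt_score(pred, true, cutoffs):
    """Average the per-cutoff fractions, mirroring the mean over cutoffs in gdt."""
    return sum(gdt_fraction(pred, true, c) for c in cutoffs) / len(cutoffs)


# Hypothetical toy data: four residues, assumed already superimposed.
true = [(0.0, 0.0, 0.0), (3.8, 0.0, 0.0), (7.6, 0.0, 0.0), (11.4, 0.0, 0.0)]
pred = [(0.5, 0.0, 0.0), (3.8, 1.5, 0.0), (7.6, 0.0, 3.0), (11.4, 9.0, 0.0)]
# Per-residue deviations are 0.5, 1.5, 3.0 and 9.0.

score = gdt_score(pred, true, [1.0, 2.0, 4.0, 8.0])
print(score)  # fractions 1/4, 2/4, 3/4, 3/4 average to 0.5625
```

The last residue deviates by 9.0, which exceeds even the loosest cutoff, so it never counts as a hit and the score stays well below 1.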
Source code:
import torch


def gdt(p1, p2, mask, cutoffs):
    """Calculate the Global Distance Test (GDT) score for protein structures.

    Args:
        p1 (torch.Tensor): Coordinates of the first structure [..., N, 3].
        p2 (torch.Tensor): Coordinates of the second structure [..., N, 3].
        mask (torch.Tensor): Mask for valid residues [..., N].
        cutoffs (list): List of distance cutoffs for GDT calculation.

    Returns:
        torch.Tensor: GDT score [...].
    """
    # Ensure inputs are float
    p1 = p1.float()
    p2 = p2.float()
    mask = mask.float()

    # Calculate number of valid residues per batch
    n = torch.sum(mask, dim=-1)

    # Calculate the distance between corresponding residues of p1 and p2
    distances = torch.sqrt(torch.sum((p1 - p2)**2, dim=-1))

    scores = []
    for c in cutoffs:
        # Calculate score for each cutoff, accounting for the mask
        score = torch.sum((distances <= c).float() * mask, dim=-1) / (n + 1e-8)
        scores.append(score)

    # Stack scores and average across cutoffs
    scores = torch.stack(scores, dim=-1)
    return torch.mean(scores, dim=-1)


def gdt_ts(p1, p2, mask):
    """Calculate the Global Distance Test Total Score (GDT_TS).

    Args:
        p1 (torch.Tensor): Coordinates of the first structure [..., N, 3].
        p2 (torch.Tensor): Coordinates of the second structure [..., N, 3].
        mask (torch.Tensor): Mask for valid residues [..., N].

    Returns:
        torch.Tensor: GDT_TS score [...].
    """
    return gdt(p1, p2, mask, [1., 2., 4., 8.])


def gdt_ha(p1, p2, mask):
    """Calculate the Global Distance Test High Accuracy (GDT_HA) score.

    Args:
        p1 (torch.Tensor): Coo