The Rotation class (rigid_utils module) is the core component AlphaFold3 uses for 3D rotations. It supports two rotation representations:
1️⃣ Rotation matrix (3x3)
2️⃣ Quaternion (a 4-element vector)
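To make the relationship between the two representations concrete, here is a minimal sketch that converts a quaternion into the equivalent 3x3 rotation matrix. It uses plain Python rather than torch so it runs without dependencies. The function name `quat_to_rot_mat` and the (w, x, y, z) component order are assumptions made for illustration, not taken from the class itself.

```python
import math


def quat_to_rot_mat(q):
    """Convert a quaternion to a 3x3 rotation matrix (nested lists).

    Assumes the component order (w, x, y, z); this ordering is an
    illustrative assumption. The input is normalized first, so it
    does not need to be a unit quaternion.
    """
    w, x, y, z = q
    n = math.sqrt(w * w + x * x + y * y + z * z)
    w, x, y, z = w / n, x / n, y / n, z / n
    return [
        [1 - 2 * (y * y + z * z), 2 * (x * y - w * z), 2 * (x * z + w * y)],
        [2 * (x * y + w * z), 1 - 2 * (x * x + z * z), 2 * (y * z - w * x)],
        [2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x * x + y * y)],
    ]


# A 90-degree rotation about the z axis, written as a quaternion.
half_angle = math.pi / 4
q_z90 = (math.cos(half_angle), 0.0, 0.0, math.sin(half_angle))
R_z90 = quat_to_rot_mat(q_z90)

# The identity quaternion maps to the identity matrix.
R_identity = quat_to_rot_mat((1.0, 0.0, 0.0, 0.0))
```

Up to floating-point error, `R_z90` maps the x axis onto the y axis, which is what a 90-degree rotation about z should do.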
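👉 Design goals:
- Allow a flexible choice between rotation matrices and quaternions
- Encapsulate common rotation operations (composition, inversion, application to points, and so on)
- Support indexing, concatenation, broadcasting, and similar operations, just like torch.Tensor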
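The operations named in the goals above (composition, inversion, applying a rotation to a point) can be sketched for the quaternion representation as follows. This is a dependency-free illustration in plain Python, not the class's own implementation. The helper names `quat_multiply`, `quat_invert`, and `quat_apply` are hypothetical, and the (w, x, y, z) component order is an assumption.

```python
import math


def quat_multiply(a, b):
    """Hamilton product a * b for quaternions in (w, x, y, z) order.

    As a rotation, the product applies b first and then a.
    """
    aw, ax, ay, az = a
    bw, bx, by, bz = b
    return (
        aw * bw - ax * bx - ay * by - az * bz,
        aw * bx + ax * bw + ay * bz - az * by,
        aw * by - ax * bz + ay * bw + az * bx,
        aw * bz + ax * by - ay * bx + az * bw,
    )


def quat_invert(q):
    """Inverse of a unit quaternion, which is simply its conjugate."""
    w, x, y, z = q
    return (w, -x, -y, -z)


def quat_apply(q, point):
    """Rotate a 3D point by the unit quaternion q via q * (0, p) * q^-1."""
    rotated = quat_multiply(quat_multiply(q, (0.0, *point)), quat_invert(q))
    return rotated[1:]


# A 90-degree rotation about the z axis.
half_angle = math.pi / 4
q_z90 = (math.cos(half_angle), 0.0, 0.0, math.sin(half_angle))

rotated_once = quat_apply(q_z90, (1.0, 0.0, 0.0))    # x axis -> y axis
q_z180 = quat_multiply(q_z90, q_z90)                 # compose two rotations
rotated_twice = quat_apply(q_z180, (1.0, 0.0, 0.0))  # x axis -> negative x axis
round_trip = quat_multiply(q_z90, quat_invert(q_z90))  # identity quaternion
```

A class that stores batches of rotations in a tensor would perform the same arithmetic elementwise over the leading dimensions, which is what makes tensor-style indexing and broadcasting a natural fit.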
Source code:
from __future__ import annotations

from typing import Any, Optional

import torch


class Rotation:
    """
    A 3D rotation. Depending on how the object is initialized, the
    rotation is represented by either a rotation matrix or a
    quaternion, though both formats are made available by helper functions.
    To simplify gradient computation, the underlying format of the
    rotation cannot be changed in-place. Like Rigid, the class is designed
    to mimic the behavior of a torch Tensor, almost as if each Rotation
    object were a tensor of rotations, in one format or another.
    """
    def __init__(self,
        rot_mats: Optional[torch.Tensor] = None,
        quats: Optional[torch.Tensor] = None,
        normalize_quats: bool = True,
    ):
        """
        Args:
            rot_mats:
                A [*, 3, 3] rotation matrix tensor. Mutually exclusive with
                quats
            quats:
                A [*, 4] quaternion. Mutually exclusive with rot_mats. If
                normalize_quats is not True, must be a unit quaternion
            normalize_quats:
                If quats is specified, whether to normalize quats
        """
        # Exactly one of the two representations must be provided
        if((rot_mats is None and quats is None) or
            (rot_mats is not None and quats is not None)):
            raise ValueError("Exactly one input argument must be specified")

        if((rot_mats is not None and rot_mats.shape[-2:] != (3, 3)) or
            (quats is not None and quats.shape[-1] != 4)):
            raise ValueError(
                "Incorrectly shaped rotation matrix or quaternion"
            )

        # Force full-precision
        if(quats is not None):
            quats = quats.to(dtype=torch.float32)
        if(rot_mats is not None):
            rot_mats = rot_mats.to(dtype=torch.float32)

        if(quats is not None and normalize_quats):
            quats = quats / torch.linalg.norm(quats, dim=-1, keepdim=True)

        self._rot_mats = rot_mats
        self._quats = quats

    @staticmethod
    def identity(
        shape,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        requires_grad: bool = True,
        fmt: str = "quat",
    ) -> Rotation:
        """
        Returns an identity Rotation.

        Args:
            shape:
                The "shape" of the resulting Rotation object. See documentation
                for the shape property
            dtype:
                The torch dtype for the rotation
            device:
                The torch device for the new rotation
            requires_grad:
                Whether the underlying tensors in the new rotation object
                should require gradient computation
            fmt:
                One of "quat" or "rot_mat". Determines the underlying format
                of the new object's rotation
        Returns:
            A new identity rotation
        """
        if(fmt == "rot_mat"):
            rot_mats = identity_rot_mats(
                shape, dtype, device, requires_grad,
            )
            return Rotation(rot_mats=rot_mats, quats=None)
        elif(fmt == "quat"):
            quats = identity_quats(shape, dtype, device, requires_grad)
            return Rotation(rot_mats=None, quats=quats, normalize_quats=False)
        else:
            raise ValueError(f"Invalid format: {fmt}")

    # Magic methods

    def __getitem__(self, index: Any) -> Rotation:
        """
        Allows torch-style indexing over the virtual shape of the rotation
        object. See documentation for the shape property.

        Args:
            index:
                A torch index. E.g. (1, 3, 2), or (slice(None,))
        Returns:
            The indexed rotation
        """
        if type(index) != tuple:
            index = (index,)

        # Append full slices so the trailing format dimensions stay intact
        if(self._rot_mats is not None):
            rot_mats = self._rot_mats[index + (slice(None), slice(None))]
            return Rotation(rot_mats=rot_mats)
        elif(self._quats is not None):
            quats = self._quats[index + (slice(None),)]
            return Rotation(quats=quats, normalize_quats=False)
        else:
            raise ValueError("Both rotations are None")

    def __mul__(self,
        right: torch.Tensor,
    ) -> Rotation:
        """
        Pointwise left multiplication of the rotation with a tensor. Can be
        used to e.g. mask the Ro