

2025/3/16 3:36:38  Source: https://blog.csdn.net/m0_61864577/article/details/146232687

Compare the errors of all operators between two training runs to locate the PyTorch operators with consistency problems.

    • I. Background
    • II. Technical Approach
      • 1. Core idea
      • 2. Key techniques
    • III. Code

I. Background

In distributed training, the following phenomena were observed:

  1. Under identical hyperparameter configurations, the loss curves of repeated training runs differ significantly (fluctuation > 5%)
  2. Convergence stability is affected by the number of training nodes: the more nodes, the larger the differences
  3. A gradient check found no anomalies, initially ruling out problems with the model structure
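The gradient check mentioned in point 3 can be reproduced with `torch.autograd.gradcheck`, which compares analytic gradients against finite differences. A minimal sketch (the function `f` is illustrative, not a component of the article's model):

```python
import torch

def f(x):
    # Toy differentiable function standing in for a model component
    return (x * x).sum()

# gradcheck uses finite differences, so double precision is required
x = torch.randn(3, dtype=torch.double, requires_grad=True)
ok = torch.autograd.gradcheck(f, (x,))
print(ok)
```

`gradcheck` returns `True` when the analytic and numeric gradients agree (and raises otherwise), which is why a passing check only rules out structural bugs, not operator-level numerical drift.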

II. Technical Approach

1. Core idea

Use operator-level numerical consistency verification to locate the native PyTorch operators that cause inconsistent results in multi-node training. Key technical path:

  1. Operator interception: capture all ATen operator calls via the __torch_dispatch__ mechanism
  2. Dual-mode verification:
    • Baseline mode: the first run saves statistical features of each operator's inputs and outputs
    • Verification mode: subsequent runs check numerical consistency against the baseline in real time
  3. Difference localization: when a statistical feature deviates, print the full call stack
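Step 1 can be sketched with a minimal `TorchDispatchMode` subclass that records every ATen call made while it is active (`OpLogger` is an illustrative name, not the article's class):

```python
import torch
from torch.utils._python_dispatch import TorchDispatchMode

class OpLogger(TorchDispatchMode):
    """Record every ATen operator dispatched while the mode is active."""
    def __init__(self):
        super().__init__()
        self.ops = []

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        self.ops.append(str(func))            # e.g. "aten.add.Tensor"
        return func(*args, **(kwargs or {}))  # run the real kernel

logger = OpLogger()
with logger:
    c = torch.ones(2) + torch.ones(2)

print(logger.ops)
```

Because the mode sits below autograd in the dispatcher, even factory ops like `aten.ones` are captured, which is exactly why the full tool needs a blacklist for non-compute operators.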

2. Key techniques

  • Operator interception: subclass TorchDispatchMode and override the dispatch logic
  • Feature extraction: compute the tensor mean (excluding non-numeric factors such as shape and dtype)
  • Difference detection: tolerance comparison with torch.allclose (default atol=1e-4)
  • Result persistence: serialize the baseline data to disk per rank
  • Blacklist mechanism: filter non-compute operators such as empty_like to reduce false positives
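The feature-extraction and difference-detection points combine into one simple check. A sketch of just that comparison logic, using an illustrative in-memory dict instead of the article's per-rank pickle files:

```python
import torch

baseline = {}  # illustrative in-memory store; the full tool persists per-rank pickles

def check(name, tensor, atol=1e-4):
    mean = tensor.float().mean().item()     # the statistical feature
    if name not in baseline:
        baseline[name] = mean               # baseline mode: record the feature
        return True
    # verification mode: tolerance compare against the stored feature
    return torch.allclose(torch.tensor(mean),
                          torch.tensor(baseline[name]), atol=atol)

t = torch.ones(4)
assert check("aten.mm-1", t)            # first call records the baseline
assert check("aten.mm-1", t + 1e-6)     # drift within atol passes
assert not check("aten.mm-1", t + 1.0)  # large drift is flagged
```

Reducing each tensor to its mean keeps the baseline files small, at the cost of missing differences that leave the mean unchanged; a stricter variant could also store min/max or a checksum.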

III. Code

import torch
from torch.utils._python_dispatch import TorchDispatchMode
from dataclasses import dataclass
from typing import Any
from datetime import datetime
import numpy as np
import torch.nn as nn
import time
import os
import pickle
import inspect
import traceback

@dataclass
class _ProfilerState:
    cls: Any
    object: Any = None


def is_tensor(val):
    return isinstance(val, (torch.Tensor, nn.Parameter))


def do_compare(rank, name, tensor):
    # Timestamp with millisecond resolution for log correlation
    timestamp = time.time()
    seconds = int(timestamp)
    millis = int((timestamp - seconds) * 1000)
    now = f"{time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(seconds))}.{millis:03d}"
    if torch.isnan(tensor).any():
        print(f"has_nan {now} {rank} {name} {tensor.shape} {tensor.dtype} {tensor.device}")
        return False
    # Statistical feature: mean of the tensor
    current_mean = torch.mean(tensor.float()).cpu().item()
    cache_file = f"/logs/rank_{rank}.pkl"
    os.makedirs(os.path.dirname(cache_file), exist_ok=True)
    stored_data = {}
    if os.path.exists(cache_file):
        try:
            with open(cache_file, "rb") as f:
                stored_data = pickle.load(f)
        except Exception:
            print("load failed:", cache_file)
            traceback.print_exc()
    if name in stored_data:
        # Verification mode: compare against the stored baseline
        if torch.allclose(torch.tensor(current_mean),
                          torch.tensor(stored_data[name]), atol=1e-4):
            return True
        print("------------------------------In-----------------------------------")
        print(f"MisMatch {now} {rank} {name} {tensor.shape} {tensor.dtype} "
              f"{tensor.device} {current_mean} {stored_data[name]} "
              f"min:{torch.min(tensor)} max:{torch.max(tensor)}")
        return False
    # Baseline mode: record the feature on first sight
    stored_data[name] = current_mean
    with open(cache_file, "wb") as f:
        pickle.dump(stored_data, f)
    return True


index_counter = 0


def compare_tensor(name, tensor):
    global index_counter
    index_counter += 1
    rank = torch.distributed.get_rank()
    if is_tensor(tensor):
        if not do_compare(rank, f"{name}-{index_counter}", tensor):
            return False
    elif isinstance(tensor, (tuple, list)):
        for idx, t in enumerate(tensor):
            if is_tensor(t):
                if not do_compare(rank, f"{name}-{index_counter}-{idx}", t):
                    return False
    return True


class TorchDumpDispatchMode(TorchDispatchMode):
    def __init__(self, parent):
        super().__init__()
        self.parent = parent

    def is_allow_dump(self, name):
        # Skip non-compute operators to reduce false positives
        black_list = ["empty", "like", "zero", "detach", "has", "view", "copy",
                      "arange", "fill", "ones", "lift_fresh", "alias",
                      "scalar_tensor", "clone", "stack", "slice", "source",
                      "barrier", "select", "random", "unsqueeze", "expand",
                      "normal_"]
        for i in black_list:
            if name.find(i) >= 0:
                return False
        return True

    def print_stack(self, tag):
        # Dump the call stack when a mismatch is detected
        for i, frame_info in enumerate(reversed(inspect.stack())):
            print(f"{i}:{frame_info.filename}:{frame_info.lineno}")
        print(f"------------------------------{tag}-----------------------------------")

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        op_name = f"{func}"
        enable_dump = self.is_allow_dump(op_name)
        if kwargs is None:
            kwargs = {}
        if enable_dump:
            torch.cuda.synchronize()
            if not compare_tensor(f"{op_name}-[input]", args):
                self.print_stack("Out")
        ret = func(*args, **kwargs)
        if enable_dump:
            torch.cuda.synchronize()
            if not compare_tensor(f"{op_name}-[output0]", ret):
                self.print_stack("Out0")
            # args may have been mutated in place; check them again
            if not compare_tensor(f"{op_name}-[output1]", args):
                self.print_stack("Out1")
        return ret


class TorchDebugDumper:
    _CURRENT_Dumper = None

    def __init__(self):
        self.p = _ProfilerState(TorchDumpDispatchMode)

    def __enter__(self):
        assert TorchDebugDumper._CURRENT_Dumper is None
        TorchDebugDumper._CURRENT_Dumper = self
        if self.p.object is None:
            o = self.p.cls(self)
            o.__enter__()
            self.p.object = o
        else:
            self.p.object.step()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        TorchDebugDumper._CURRENT_Dumper = None
        if self.p.object is not None:
            self.p.object.__exit__(exc_type, exc_val, exc_tb)
            del self.p.object


def main():
    # Megatron-LM style entry point; the providers are defined elsewhere
    pretrain(
        train_valid_test_datasets_provider,
        model_provider,
        forward_step,
        extra_args_provider=llama_argument_handler,
        args_defaults={"tokenizer_type": "GPT2BPETokenizer"},
    )


if __name__ == "__main__":
    with TorchDebugDumper():
        main()
