首先该类实现, 使用timm ==0.6.11 版本;
Exponential Moving Average (EMA) for models in PyTorch.
目的:它旨在维护模型状态字典的移动平均值,包括参数和缓冲区。该技术通常用于训练方案,其中权重的平滑版本对于最佳性能至关重要。
class ModelEma:""" Model Exponential Moving Average (DEPRECATED)Keep a moving average of everything in the model state_dict (parameters and buffers).This version is deprecated, it does not work with scripted models. Will be removed eventually.This is intended to allow functionality like smoothed version of the weights is necessary for some training schemes to perform Google's hyper-params for training MNASNet, MobileNet-V3, EfficientNet, etc that useRMSprop with a short 2.4-3 epoch decay period and slow LR decay rate of .96-.99 requires EMAsmoothing of weights to match results. Pay attention to the decay constant you are usingrelative to your update count per epoch.To keep EMA from using GPU resources, set device='cpu'. This will save a bit of memory butdisable validation of the EMA weights. Validation will have to be done manually in a separateprocess, or after the training stops converging.This class is sensitive where it is initialized in the sequence of model init,GPU assignment and distributed training wrappers."""def __init__(self, model, decay=0.9999, device='', resume=''):# make a copy of the model for accumulating moving average a = deepcopy(a.eval()self.decay = decayself.device = device # perform ema on different device from model if setif a.to(device=a_has_module = a, 'module')if resume:self._load_checkpoint(resume)for p a.parameters():p.requires_grad_(False)def _load_checkpoint(self, checkpoint_path):checkpoint = torch.load(checkpoint_path, map_location='cpu')assert isinstance(checkpoint, dict)if 'state_dict_ema' in checkpoint:new_state_dict = OrderedDict()for k, v in checkpoint['state_dict_ema'].items():# ema model may have been wrapped by DataParallel, and need module a_has_module:name = 'module.' + k if not k.startswith('module') else kelse:name = knew_state_dict[name] = a.load_state_dict(new_state_dict)_logger.info("Loaded state_dict_ema")else:_logger.warning("Failed to find state_dict_ema, starting from loaded model weights")def update(self, model):# correct a mismatch in state dict keysneeds_module = hasattr(model, 'module') and a_has__grad():msd = model.state_dict()for k, ema_v a.state_dict().items():if needs_module:k = 'module.' + kmodel_v = msd[k].detach()if self.device:model_v = (device=self.device)py_(ema_v * self.decay + (1. - self.decay) * model_v)
Methods:方法:
__init__
:通过创建所提供模型的副本、设置衰减率和设备放置来初始化 EMA 模型。模型设置为评估模式,并且其梯度被禁用。
_load_checkpoint
:加载 EMA 模型的检查点。它处理由 DataParallel 包装器引起的状态字典命名约定中的潜在差异。
update
:
通过计算原始模型参数和当前 EMA 参数的加权平均值来更新 EMA 参数。
Features:特征:
import logging
from collections import OrderedDict
from copy import deepcopyimport torch
as nn_logger = Logger(__name__)class ModelEmaV2(nn.Module):""" Model Exponential Moving Average V2Keep a moving average of everything in the model state_dict (parameters and buffers).V2 of this module is simpler, it does not match params/buffers based on name but simplyiterates in order. It works with torchscript (JIT of full model).This is intended to allow functionality like smoothed version of the weights is necessary for some training schemes to perform Google's hyper-params for training MNASNet, MobileNet-V3, EfficientNet, etc that useRMSprop with a short 2.4-3 epoch decay period and slow LR decay rate of .96-.99 requires EMAsmoothing of weights to match results. Pay attention to the decay constant you are usingrelative to your update count per epoch.To keep EMA from using GPU resources, set device='cpu'. This will save a bit of memory butdisable validation of the EMA weights. Validation will have to be done manually in a separateprocess, or after the training stops converging.This class is sensitive where it is initialized in the sequence of model init,GPU assignment and distributed training wrappers."""def __init__(self, model, decay=0.9999, device=None):super(ModelEmaV2, self).__init__()# make a copy of the model for accumulating moving average dule = deepcopy(dule.eval()self.decay = decayself.device = device # perform ema on different device from model if setif self.device is not (device=device)def _update(self, model, update_fn):_grad():for ema_v, model_v in dule.state_dict().values(), model.state_dict().values()):if self.device is not None:model_v = (device=self.device)py_(update_fn(ema_v, model_v))def update(self, model): # 使用衰减率更新 EMA 参数self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m)def set(self, model): # 直接将 EMA 参数设置为与提供的模型参数相同。self._update(model, update_fn=lambda e, m: m)
EmaV2版本:与 ModelEma 类似,但实现更简单。它还维护模型状态字典的移动平均值,并设计为与 torchscript(完整模型的 JIT)配合使用。
Methods:方法:
__init__
:与 ModelEma 类似,但添加了对 super() 的调用来初始化 nn.Module 基类。
_update
:更新 EMA 参数的辅助函数,以自定义更新函数作为参数。
update
:使用衰减率更新 EMA 参数。
set
:直接将 EMA 参数设置为与提供的模型参数相同。
Features:特征:
设计复杂性: ModelEmaV2 更简单、更直接,避免了按名称匹配参数。
兼容性: ModelEmaV2 与 torchscript 兼容,与 ModelEma 不同。
.参数匹配: ModelEma 按名称匹配参数和缓冲区,而 ModelEmaV2 根据参数和顺序进行匹配。
版本控制和用例: ModelEma 已被弃用,并且对于较新的训练方案(尤其是需要脚本的训练方案)而言不太受欢迎。
这两个类本质上用于相同的目的,但采用不同的方法,使得 ModelEmaV2 更适合利用脚本的现代 PyTorch 工作流程。
与 ModelEma 相比,在训练过程中使用 ModelEmaV2 涉及的方法略有不同。以下是有关如何将 ModelEmaV2 合并到训练循环中的指南,以及有关衰减参数的作用和预训练权重的使用的说明。
要在训练过程中使用 ModelEma V2 ,您应该将其集成到现有的训练循环中。以下是有关如何执行此操作的分步指南:
由于v1版本被弃用, 所以这里介绍使用 V2 版本;
初始化:定义模型后,使用您的模型作为参数初始化 ModelEmaV2 。根据您的需求设置 decay 参数。
model = YourModel() # Replace with your model
ema = ModelEmaV2(model, decay=0.9999)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
(device)
(device)
这里需要注意到的是 ,需要在每个反向传播 更新之后,才回去更新EMA 模型;
for epoch in range(num_epochs):for batch in dataloader:inputs, targets = batchinputs, targets = (device), (_grad()outputs = model(inputs)loss = criterion(outputs, targets)loss.backward()optimizer.step()ema.update(model)
在获取EMA 更新的权重之后,
EMA 模型的参数权重, 真正使用他的地方是在 推理阶段, 即 training 之后的 evaluate 阶段;
dule.eval() # Set EMA model to evaluation mode
_grad():for batch in validation_dataloader:inputs, targets = batchinputs, targets = (device), (device)outputs = dule(inputs) # Use EMA model for predictions# Compute validation metrics
torch.save({'model_state_dict': model.state_dict(),'ema_state_dict': dule.state_dict(),# ... other states like optimizer, epoch, etc.
}, 'checkpoint.pth')
checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model_state_dict'])
dule.load_state_dict(checkpoint['ema_state_dict'])
# Load other states
ModelEmaV2 中的衰减参数起着至关重要的作用:
它确定移动平均线中当前模型参数相对于历史参数的权重。
衰减值的选择取决于您的训练动态和训练步骤总数。常见的做法是从高衰减开始,然后随着时间的推移逐渐减少。
较低的衰减值(远离 1):较低的衰减值导致 EMA 模型更加重视最近的模型更新。这使得 EMA 权重不太平滑,因为它们对模型参数的最新变化更加敏感。虽然这可以使 EMA 权重对数据的新趋势更加敏感,但也使它们更容易受到噪音和突然变化的影响。
总而言之,较高的衰减参数(接近 1)通过赋予历史数据更多权重来提高 EMA 模型权重的平滑度,从而导致权重更稳定但响应性较差。相反,较低的衰减值会降低平滑度,使权重对最近的变化更加敏感,但会牺牲稳定性。适当衰减值的选择取决于训练过程的具体要求和数据的性质。
使用 ModelEmaV2 时,在初始化 ModelEmaV2 之前将预训练的权重加载到原始模型中可能会很有帮助,特别是当您正在进行微调或有特定的起点时。
使用预先训练的权重:
使用 ModelEmaV2 时,在初始化 ModelEmaV2 之前将预训练的权重加载到原始模型中可能会很有帮助,特别是当您正在进行微调或有特定的起点时。
然后,EMA 模型将从这些权重的平滑版本开始,这可以导致更快的收敛和可能更好的最终性能,特别是在微调场景中。
但是,如果您从头开始训练,则使用没有预训练权重的模型初始化 ModelEmaV2 也可以。 EMA 模型将随着训练的进展进行调整。
总之, ModelEmaV2 用于维持模型权重的更平滑、更稳定的版本,这对于实现最佳性能至关重要,特别是在训练的后期阶段或微调场景中。衰减参数是控制应用平滑程度的关键。使用 ModelEmaV2 时,预训练权重可能很有用,但它们并不是绝对必要的,特别是在从头开始训练的场景中。
遇到的错误 RuntimeError: Only Tensors created explicitly by the user (graph leaves) support the deepcopy protocol at the moment
:
表示 ModelEmaV2 初始化中的 deepcopy 操作存在问题。当尝试在 PyTorch 中深度复制具有一定复杂性或特定类型的层或参数的模型时,通常会出现此问题。
检查不可复制的层或参数:PyTorch 模型中的某些自定义层或参数可能不支持深度复制。如果您的模型包含此类层,请考虑修改模型以仅使用深度复制兼容的层。
更新 PyTorch 版本:确保您使用的是最新版本的 PyTorch。有时,此类问题会在新版本中得到解决。
解决方法:自定义深度复制方法:此函数将手动将每个参数和缓冲区从原始模型复制到新模型。可以编写自定义函数来创建模型的副本,而不是使用 deepcopy 。即将原始的
__init__()
初始化过程中, dule 不使用 deepcopy()函数。
替换成如下方式拷贝:
def custom_deepcopy(model):model_copy = type(model)() # Create a new instance of the model's classmodel_copy.load_state_dict(model.state_dict()) # Copy parameters and buffersreturn a = ModelEmaV2(custom_deepcopy(self), decay=0.9999)
并且需要将原始 __init__()
初始化过程中, dule 不使用 deepcopy()函数,
def __init__(self, model, decay=0.9999, device=None):super(ModelEmaV2, self).__init__()# make a copy of the model for accumulating moving average of weights# dule = deepcopy(dule = dule.eval()self.decay = decayself.device = device # perform ema on different device from model if setif self.device is not (device=device)
本文发布于:2024-01-30 22:46:06,感谢您对本站的认可!
本文链接:https://www.4u4v.net/it/170662596923372.html
版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,我们将在24小时内删除。
留言与评论(共有 0 条评论) |