优化器_2
此 MNIST 示例使用了这些优化器。
这个文件定义了 Adam 的通用基类及其扩展。由于可重用性,基类有助于以最少的代码实现其他优化器。
我们还为 L2 权重衰减定义了一个特殊的类,这样我们就不必在每个优化器中实现它,并且可以在不更改优化器的情况下轻松扩展到其他权重衰减,例如 L1。
以下是关于 PyTorch 优化器的一些概念:
PyTorch 优化器将参数分组到名为组的集合中。每个组可以有自己的超参数,例如学习率。
在大多数情况下,只有一组。这是你使用初始化优化器的时候,
在初始化优化器时,可以定义多个参数组:
在这里,我们传递一个组列表。每个组都是一个字典,其参数位于键 “params” 下。您也可以指定任何超参数。如果未定义 hyper 参数,它们将默认为优化程序级别的默认值。
您可以使用访问(甚至更改)这些组及其超参数 。我遇到的大多数学习率计划实现都访问了这个并更改了 “lr”。
Optimizer 在字典中维护每个参数(张量)的状态(字典) 。这是优化器维护指数平均值之类的东西的地方。
62from typing import Dict, Tuple, Any
63
64import torch
65from torch import nn
66from torch.optim.optimizer import Optimizer
69class GenericAdaptiveOptimizer(Optimizer):
74 def __init__(self, params, defaults: Dict[str, Any], lr: float, betas: Tuple[float, float], eps: float):
检查超参数
86 if not 0.0 <= lr:
87 raise ValueError(f"Invalid learning rate: {lr}")
88 if not 0.0 <= eps:
89 raise ValueError(f"Invalid epsilon value: {eps}")
90 if not 0.0 <= betas[0] < 1.0:
91 raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
92 if not 0.0 <= betas[1] < 1.0:
93 raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
将超参数添加到默认值
96 defaults.update(dict(lr=lr, betas=betas, eps=eps))
初始化 PyTorch 优化器。这将使用默认的超参数创建参数组
99 super().__init__(params, defaults)
这应该被代码覆盖,以便初始 化参数 。 是所 属的参数组字典。
101 def init_state(self, state: Dict[str, any], group: Dict[str, any], param: nn.Parameter):
108 pass
这应该被重写并对 张量采取优化步骤,其中 是该参数的梯度, 是该参数的优化器状态字典, 也是参数组字典 所属的。
110 def step_param(self, state: Dict[str, any], group: Dict[str, any], grad: torch.Tensor, param: torch.Tensor):
119 pass
我们创建了一个模板方法,它可以完成每个基于 Adam 的优化器所需要的常用内容。
121 @torch.no_grad()
122 def step(self, closure=None):
133 loss = None
134 if closure is not None:
135 with torch.enable_grad():
136 loss = closure()
遍历参数组
139 for group in self.param_groups:
遍历参数组中的参数
141 for param in group['params']:
如果参数没有渐变,则跳过
143 if param.grad is None:
144 continue
获取梯度张量
146 grad = param.grad.data
我们不处理稀疏渐变
148 if grad.is_sparse:
149 raise RuntimeError('GenericAdaptiveOptimizer does not support sparse gradients,'
150 ' please consider SparseAdam instead')
获取参数的状态
153 state = self.state[param]
如果状态未初始化,则初始化状态
156 if len(state) == 0:
157 self.init_state(state, group, param)
对参数采取优化步骤
160 self.step_param(state, group, grad, param)
返回从闭包计算得出的损失
163 return loss
166class WeightDecay:
- 是衰减系数
- 是一个标志,指示是将权重衰减添加到梯度还是直接从参数中衰减。如果添加到渐变中,它将通过普通的优化器更新。
- 此标志指示权重衰减系数是否为绝对值。当直接对参数执行衰减时,这适用。如果此值为假,则实际衰减为
- 。
171 def __init__(self, weight_decay: float = 0., weight_decouple: bool = True, absolute: bool = False):
检查超参数
184 if not 0.0 <= weight_decay:
185 raise ValueError(f"Invalid weight_decay value: {weight_decay}")
186
187 self.absolute = absolute
188 self.weight_decouple = weight_decouple
189 self.weight_decay = weight_decay
返回参数组的默认值
191 def defaults(self):
195 return dict(weight_decay=self.weight_decay)
197 def __call__(self, param: torch.nn.Parameter, grad: torch.Tensor, group: Dict[str, any]):
如果我们直接对参数进行衰减
203 if self.weight_decouple:
如果权重衰减系数为绝对值
205 if self.absolute:
206 param.data.mul_(1.0 - group['weight_decay'])
否则,
208 else:
209 param.data.mul_(1.0 - group['lr'] * group['weight_decay'])
返回未修改的渐变
211 return grad
212 else:
213 if group['weight_decay'] != 0:
将权重衰减添加到渐变并返回修改后的渐变
215 return grad.add(param.data, alpha=group['weight_decay'])
216 else:
217 return grad