@staticmethod
PYL-R0201 78 )
79 super(Adahessian, self).__init__(params, defaults)
80
81 def get_trace(self, params: Params, grads: Grads) -> List[torch.Tensor]: 82 """Get an estimate of Hessian Trace.
83 This is done by computing the Hessian vector product with a random
84 vector v at the current gradient point, to estimate Hessian trace by
106 def _rms(self, tensor: torch.Tensor) -> float:
107 return tensor.norm(2) / (tensor.numel() ** 0.5)
108
109 def _approx_sq_grad(110 self,
111 exp_avg_sq_row: torch.Tensor,
112 exp_avg_sq_col: torch.Tensor,
103 use_first_moment = param_group["beta1"] is not None
104 return factored, use_first_moment
105
106 def _rms(self, tensor: torch.Tensor) -> float:107 return tensor.norm(2) / (tensor.numel() ** 0.5)
108
109 def _approx_sq_grad(
96 param_scale = max(param_group["eps2"][1], param_state["RMS"])
97 return param_scale * rel_step_sz
98
99 def _get_options(100 self, param_group: ParamGroup, param_shape: Tuple[int, ...]
101 ) -> Tuple[bool, bool]:
102 factored = len(param_shape) >= 2
82 )
83 super(Adafactor, self).__init__(params, defaults)
84
85 def _get_lr(self, param_group: ParamGroup, param_state: State) -> float: 86 rel_step_sz = param_group["lr"]
87 if param_group["relative_step"]:
88 min_step = (
This method does not use its bound instance. Decorate it with the @staticmethod
decorator so that Python does not have to create a bound method for every instance of this class, saving memory and computation. Read more about static methods in the Python documentation.