Add ValuError message for Adamax (#508)

* ValuError message added

* beta errors added

* some corrections and testing

* Learning rate limitation deleted
This commit is contained in:
Arda Orçun 2024-01-20 18:56:15 +03:00 committed by GitHub
parent b207c2c86b
commit 363d3add6d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -423,6 +423,8 @@ class Adamax(Adam):
self, learning_rate: float, betas: List[float] = [0.9, 0.999], eps: float = 1e-8
):
super().__init__(learning_rate, betas, eps)
if not 0.0 <= eps:
raise ValueError(f"Epsilon value should be >=0, {self.eps} was provided instead")
def apply_single(
self, gradient: mx.array, parameter: mx.array, state: OptimizerState