mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
G.ndim >= 2 to assert G.ndim == 2
This commit is contained in:
@@ -894,7 +894,7 @@ class Muon(Optimizer):
|
|||||||
state["v"] = mx.zeros_like(parameter)
|
state["v"] = mx.zeros_like(parameter)
|
||||||
|
|
||||||
def _zeropower_via_newtonschulz5(self, G, steps: int):
|
def _zeropower_via_newtonschulz5(self, G, steps: int):
|
||||||
assert G.ndim >= 2
|
assert G.ndim == 2, f"Expected a 2D matrix for Newton-Schulz iteration, got shape {G.shape} instead."
|
||||||
a, b, c = (3.4445, -4.7750, 2.0315)
|
a, b, c = (3.4445, -4.7750, 2.0315)
|
||||||
X = G.astype(G.dtype)
|
X = G.astype(G.dtype)
|
||||||
transpose_needed = G.shape[-2] > G.shape[-1]
|
transpose_needed = G.shape[-2] > G.shape[-1]
|
||||||
|
|||||||
Reference in New Issue
Block a user