Merge branch 'ml-explore:main' into adding-Muon-optimizer

This commit is contained in:
Gökdeniz Gülmez
2025-03-10 17:10:50 +01:00
committed by GitHub
3 changed files with 7 additions and 9 deletions

View File

@@ -81,7 +81,7 @@ class Module(dict):
"""
return self
def _extra_repr(self):
def _extra_repr(self) -> str:
return ""
def __repr__(self):
@@ -210,7 +210,7 @@ class Module(dict):
mx.save_safetensors(file, params_dict)
else:
raise ValueError(
"Unsupported file extension. Use '.npz' or '.safetensors'."
f"Unsupported file extension for {file}. Use '.npz' or '.safetensors'."
)
@staticmethod
@@ -598,9 +598,7 @@ class Module(dict):
parameters to the new dtype.
"""
if predicate is None:
def predicate(_):
return True
predicate = lambda _: True
self.apply(lambda x: x.astype(dtype) if predicate(x.dtype) else x)

View File

@@ -1,7 +1,7 @@
# Copyright © 2023 Apple Inc.
import math
from typing import Literal, Optional
from typing import Literal, Optional, get_args
import mlx.core as mx
@@ -9,14 +9,15 @@ Reduction = Literal["none", "mean", "sum"]
def _reduce(loss: mx.array, reduction: Reduction = "none"):
if reduction not in get_args(Reduction):
raise ValueError(f"Invalid reduction. Must be one of {get_args(Reduction)}.")
if reduction == "mean":
return mx.mean(loss)
elif reduction == "sum":
return mx.sum(loss)
elif reduction == "none":
return loss
else:
raise ValueError("Invalid reduction. Must be 'none', 'mean', or 'sum'.")
def cross_entropy(