mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Merge branch 'ml-explore:main' into adding-Muon-optimizer
This commit is contained in:
@@ -81,7 +81,7 @@ class Module(dict):
|
||||
"""
|
||||
return self
|
||||
|
||||
def _extra_repr(self):
|
||||
def _extra_repr(self) -> str:
|
||||
return ""
|
||||
|
||||
def __repr__(self):
|
||||
@@ -210,7 +210,7 @@ class Module(dict):
|
||||
mx.save_safetensors(file, params_dict)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Unsupported file extension. Use '.npz' or '.safetensors'."
|
||||
f"Unsupported file extension for {file}. Use '.npz' or '.safetensors'."
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -598,9 +598,7 @@ class Module(dict):
|
||||
parameters to the new dtype.
|
||||
"""
|
||||
if predicate is None:
|
||||
|
||||
def predicate(_):
|
||||
return True
|
||||
predicate = lambda _: True
|
||||
|
||||
self.apply(lambda x: x.astype(dtype) if predicate(x.dtype) else x)
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
|
||||
import math
|
||||
from typing import Literal, Optional
|
||||
from typing import Literal, Optional, get_args
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
@@ -9,14 +9,15 @@ Reduction = Literal["none", "mean", "sum"]
|
||||
|
||||
|
||||
def _reduce(loss: mx.array, reduction: Reduction = "none"):
|
||||
if reduction not in get_args(Reduction):
|
||||
raise ValueError(f"Invalid reduction. Must be one of {get_args(Reduction)}.")
|
||||
|
||||
if reduction == "mean":
|
||||
return mx.mean(loss)
|
||||
elif reduction == "sum":
|
||||
return mx.sum(loss)
|
||||
elif reduction == "none":
|
||||
return loss
|
||||
else:
|
||||
raise ValueError("Invalid reduction. Must be 'none', 'mean', or 'sum'.")
|
||||
|
||||
|
||||
def cross_entropy(
|
||||
|
||||
Reference in New Issue
Block a user