Merge branch 'ml-explore:main' into adding-Muon-optimizer

This commit is contained in:
Gökdeniz Gülmez
2025-03-10 17:10:50 +01:00
committed by GitHub
3 changed files with 7 additions and 9 deletions

View File

@@ -1,7 +1,6 @@
# Copyright © 2023-2024 Apple Inc. # Copyright © 2023-2024 Apple Inc.
import argparse import argparse
from time import time
import mlx.core as mx import mlx.core as mx
import torch import torch

View File

@@ -81,7 +81,7 @@ class Module(dict):
""" """
return self return self
def _extra_repr(self): def _extra_repr(self) -> str:
return "" return ""
def __repr__(self): def __repr__(self):
@@ -210,7 +210,7 @@ class Module(dict):
mx.save_safetensors(file, params_dict) mx.save_safetensors(file, params_dict)
else: else:
raise ValueError( raise ValueError(
"Unsupported file extension. Use '.npz' or '.safetensors'." f"Unsupported file extension for {file}. Use '.npz' or '.safetensors'."
) )
@staticmethod @staticmethod
@@ -598,9 +598,7 @@ class Module(dict):
parameters to the new dtype. parameters to the new dtype.
""" """
if predicate is None: if predicate is None:
predicate = lambda _: True
def predicate(_):
return True
self.apply(lambda x: x.astype(dtype) if predicate(x.dtype) else x) self.apply(lambda x: x.astype(dtype) if predicate(x.dtype) else x)

View File

@@ -1,7 +1,7 @@
# Copyright © 2023 Apple Inc. # Copyright © 2023 Apple Inc.
import math import math
from typing import Literal, Optional from typing import Literal, Optional, get_args
import mlx.core as mx import mlx.core as mx
@@ -9,14 +9,15 @@ Reduction = Literal["none", "mean", "sum"]
def _reduce(loss: mx.array, reduction: Reduction = "none"): def _reduce(loss: mx.array, reduction: Reduction = "none"):
if reduction not in get_args(Reduction):
raise ValueError(f"Invalid reduction. Must be one of {get_args(Reduction)}.")
if reduction == "mean": if reduction == "mean":
return mx.mean(loss) return mx.mean(loss)
elif reduction == "sum": elif reduction == "sum":
return mx.sum(loss) return mx.sum(loss)
elif reduction == "none": elif reduction == "none":
return loss return loss
else:
raise ValueError("Invalid reduction. Must be 'none', 'mean', or 'sum'.")
def cross_entropy( def cross_entropy(