mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Merge branch 'ml-explore:main' into adding-Muon-optimizer
This commit is contained in:
@@ -1,7 +1,6 @@
|
|||||||
# Copyright © 2023-2024 Apple Inc.
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
from time import time
|
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import torch
|
import torch
|
||||||
|
|||||||
@@ -81,7 +81,7 @@ class Module(dict):
|
|||||||
"""
|
"""
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def _extra_repr(self):
|
def _extra_repr(self) -> str:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
@@ -210,7 +210,7 @@ class Module(dict):
|
|||||||
mx.save_safetensors(file, params_dict)
|
mx.save_safetensors(file, params_dict)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Unsupported file extension. Use '.npz' or '.safetensors'."
|
f"Unsupported file extension for {file}. Use '.npz' or '.safetensors'."
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -598,9 +598,7 @@ class Module(dict):
|
|||||||
parameters to the new dtype.
|
parameters to the new dtype.
|
||||||
"""
|
"""
|
||||||
if predicate is None:
|
if predicate is None:
|
||||||
|
predicate = lambda _: True
|
||||||
def predicate(_):
|
|
||||||
return True
|
|
||||||
|
|
||||||
self.apply(lambda x: x.astype(dtype) if predicate(x.dtype) else x)
|
self.apply(lambda x: x.astype(dtype) if predicate(x.dtype) else x)
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
# Copyright © 2023 Apple Inc.
|
# Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from typing import Literal, Optional
|
from typing import Literal, Optional, get_args
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
|
|
||||||
@@ -9,14 +9,15 @@ Reduction = Literal["none", "mean", "sum"]
|
|||||||
|
|
||||||
|
|
||||||
def _reduce(loss: mx.array, reduction: Reduction = "none"):
|
def _reduce(loss: mx.array, reduction: Reduction = "none"):
|
||||||
|
if reduction not in get_args(Reduction):
|
||||||
|
raise ValueError(f"Invalid reduction. Must be one of {get_args(Reduction)}.")
|
||||||
|
|
||||||
if reduction == "mean":
|
if reduction == "mean":
|
||||||
return mx.mean(loss)
|
return mx.mean(loss)
|
||||||
elif reduction == "sum":
|
elif reduction == "sum":
|
||||||
return mx.sum(loss)
|
return mx.sum(loss)
|
||||||
elif reduction == "none":
|
elif reduction == "none":
|
||||||
return loss
|
return loss
|
||||||
else:
|
|
||||||
raise ValueError("Invalid reduction. Must be 'none', 'mean', or 'sum'.")
|
|
||||||
|
|
||||||
|
|
||||||
def cross_entropy(
|
def cross_entropy(
|
||||||
|
|||||||
Reference in New Issue
Block a user