Merge branch 'ml-explore:main' into adding-Muon-optimizer

Gökdeniz Gülmez
2025-03-21 08:50:43 +01:00
committed by GitHub
84 changed files with 901 additions and 484 deletions

View File

@@ -373,7 +373,7 @@ def smooth_l1_loss(
f"targets shape {targets.shape}."
)
diff = predictions - targets
diff = mx.abs(predictions - targets)
loss = mx.where(
diff < beta, 0.5 * mx.square(diff) / beta, mx.abs(diff) - 0.5 * beta
)
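A minimal sketch (not part of the diff) of why the absolute difference matters here: with a signed difference, a large negative error still satisfies ``diff < beta`` and is routed into the quadratic branch, which then explodes instead of falling back to the linear branch. The values below are illustrative only.

import mlx.core as mx

# Illustrative inputs: a prediction 10 below its target, beta = 1.0.
predictions, targets, beta = mx.array([0.0]), mx.array([10.0]), 1.0

signed = predictions - targets            # old behaviour: -10
absolute = mx.abs(predictions - targets)  # fixed behaviour: 10

old_loss = mx.where(signed < beta, 0.5 * mx.square(signed) / beta, mx.abs(signed) - 0.5 * beta)
new_loss = mx.where(absolute < beta, 0.5 * mx.square(absolute) / beta, mx.abs(absolute) - 0.5 * beta)
print(old_loss.item(), new_loss.item())   # 50.0 (quadratic branch) vs. 9.5 (linear branch)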

View File

@@ -50,19 +50,19 @@ class Optimizer:
            dict_keys(['step', 'learning_rate', 'weight', 'bias'])
        """

-        # Iniatilize the optimizer state to match the parameter state
+        # Initialize the optimizer state to match the parameter state
        def update_state(params, state):
            if isinstance(params, (list, tuple)):
                state = list(state)
                for i in range(len(state)):
                    state[i] = update_state(params[i], state[i])
                if len(state) != len(params):
-                    state.extend(tree_map(lambda x: {}, params[len(state) :]))
+                    state.extend(tree_map(lambda _: {}, params[len(state) :]))
                return type(params)(state)
            elif isinstance(params, dict):
                for k, v in params.items():
                    if k not in state:
-                        state[k] = tree_map(lambda x: {}, v)
+                        state[k] = tree_map(lambda _: {}, v)
                    else:
                        state[k] = update_state(v, state[k])
                return state
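For context, a hedged usage sketch of the ``init`` path this hunk touches, matching the docstring example above: the optimizer state tree is built to mirror the parameter tree, so per-parameter state can be populated lazily.

import mlx.nn as nn
import mlx.optimizers as optim

model = nn.Linear(2, 2)
optimizer = optim.SGD(learning_rate=1e-1, momentum=0.9)

# Build state entries whose layout mirrors the parameter tree.
optimizer.init(model.trainable_parameters())
print(optimizer.state.keys())  # step, learning_rate, weight, bias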
@@ -79,6 +79,7 @@ class Optimizer:
        Args:
            parameter (mx.array): A single parameter that will be optimized.
            state (dict): The optimizer's state.
        """
        raise NotImplementedError()
@@ -148,10 +149,10 @@ class Optimizer:
"""
if isinstance(param, Callable):
self._schedulers[name] = param
param = param(self.step)
parameter = param(self.step)
else:
param = mx.array(param)
self.state[name] = param
parameter = mx.array(param)
self.state[name] = parameter
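The rename above only stops ``param`` from carrying two meanings; the two code paths it separates can be sketched as follows (a plain float is stored directly as an array, while a callable is registered as a scheduler and evaluated at the optimizer's current step):

import mlx.optimizers as optim

opt_const = optim.SGD(learning_rate=1e-2)                           # stored as mx.array(1e-2)
opt_sched = optim.SGD(learning_rate=optim.cosine_decay(1e-2, 100))  # callable registered as a schedule
print(opt_const.learning_rate, opt_sched.learning_rate)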
class MultiOptimizer(Optimizer):

View File

@@ -80,12 +80,12 @@ def cosine_decay(init: float, decay_steps: int, end: float = 0.0) -> Callable:
        array(0.0999961, dtype=float32)
    """

-    def scheduler(step):
+    def schedule(step):
        s = mx.minimum(step, decay_steps)
        decay = 0.5 * (1.0 + mx.cos((math.pi / decay_steps) * s))
        return end + decay * (init - end)

-    return scheduler
+    return schedule
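A plain-Python restatement of the decay formula above (a sketch, not the library code), handy for checking the endpoints: the factor is 1 at step 0, so the schedule returns ``init``, and 0 once ``decay_steps`` is reached, so it returns ``end``.

import math

def cosine_decay_sketch(init, decay_steps, end=0.0):
    def schedule(step):
        s = min(step, decay_steps)
        decay = 0.5 * (1.0 + math.cos((math.pi / decay_steps) * s))
        return end + decay * (init - end)
    return schedule

lr = cosine_decay_sketch(1e-1, 200)
print(lr(0), lr(100), lr(200))  # 0.1, 0.05, 0.0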
def join_schedules(schedules: List[Callable], boundaries: List[int]) -> Callable:
@@ -99,9 +99,9 @@ def join_schedules(schedules: List[Callable], boundaries: List[int]) -> Callable
            that indicates when to transition between schedules.

    Example:
-        >>> warmup = optim.linear_schedule(0, 1e-1, steps=10)
+        >>> linear = optim.linear_schedule(0, 1e-1, steps=10)
        >>> cosine = optim.cosine_decay(1e-1, 200)
-        >>> lr_schedule = optim.join_schedules([warmup, cosine], [10])
+        >>> lr_schedule = optim.join_schedules([linear, cosine], [10])
        >>> optimizer = optim.Adam(learning_rate=lr_schedule)
        >>> optimizer.learning_rate
        array(0.0, dtype=float32)
@@ -139,8 +139,8 @@ def linear_schedule(init: float, end: float, steps: int) -> Callable:
    Example:
-        >>> warmup = optim.linear_schedule(0, 1e-1, 100)
-        >>> optimizer = optim.Adam(learning_rate=warmup)
+        >>> lr_schedule = optim.linear_schedule(0, 1e-1, 100)
+        >>> optimizer = optim.Adam(learning_rate=lr_schedule)
        >>> optimizer.learning_rate
        array(0.0, dtype=float32)
        >>> for _ in range(101): optimizer.update({}, {})
@@ -151,8 +151,8 @@ def linear_schedule(init: float, end: float, steps: int) -> Callable:
    if steps < 1:
        raise ValueError(f"steps must be greater than 0, but got {steps}.")

-    def step_fn(step):
+    def schedule(step):
        step = mx.minimum(step, steps)
        return step * ((end - init) / steps) + init

-    return step_fn
+    return schedule
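Putting the renamed pieces together, a hedged end-to-end sketch of the warmup-then-decay pattern the docstrings above describe; ``boundaries=[10]`` marks the step at which the joined schedule switches from the warmup to the cosine decay.

import mlx.optimizers as optim

warmup = optim.linear_schedule(0.0, 1e-1, steps=10)
cosine = optim.cosine_decay(1e-1, 200)
lr_schedule = optim.join_schedules([warmup, cosine], [10])

optimizer = optim.Adam(learning_rate=lr_schedule)
for _ in range(15):
    optimizer.update({}, {})    # empty trees: only advances the optimizer step
print(optimizer.learning_rate)  # past the warmup peak, now decaying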

View File

@@ -134,7 +134,7 @@ void init_fast(nb::module_& parent_module) {
"memory_efficient_threshold"_a = nb::none(),
"stream"_a = nb::none(),
nb::sig(
"def scaled_dot_product_attention(q: array, k: array, v: array, *, scale: float, mask: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
"def scaled_dot_product_attention(q: array, k: array, v: array, *, scale: float, mask: Union[None, str, array] = None, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
A fast implementation of multi-head attention: ``O = softmax(Q @ K.T, dim=-1) @ V``.
@@ -164,11 +164,11 @@ void init_fast(nb::module_& parent_module) {
            k (array): Keys with shape ``[B, N_kv, T_kv, D]``.
            v (array): Values with shape ``[B, N_kv, T_kv, D]``.
            scale (float): Scale for queries (typically ``1.0 / sqrt(q.shape(-1)``)
-            mask (array, optional): A boolean or additive mask to apply to the
-              query-key scores. The mask can have at most 4 dimensions and must
-              be broadcast-compatible with the shape ``[B, N, T_q, T_kv]``. If an
-              additive mask is given its type must promote to the promoted
-              type of ``q``, ``k``, and ``v``.
+            mask (Union[None, str, array], optional): A causal, boolean or additive
+              mask to apply to the query-key scores. The mask can have at most 4
+              dimensions and must be broadcast-compatible with the shape
+              ``[B, N, T_q, T_kv]``. If an additive mask is given its type must
+              promote to the promoted type of ``q``, ``k``, and ``v``.

        Returns:
            array: The output array.
      )pbdoc");

View File

@@ -57,23 +57,19 @@ void init_metal(nb::module_& m) {
"set_memory_limit",
&mx::metal::set_memory_limit,
"limit"_a,
nb::kw_only(),
"relaxed"_a = true,
R"pbdoc(
Set the memory limit.
Memory allocations will wait on scheduled tasks to complete if the limit
is exceeded. If there are no more scheduled tasks an error will be raised
if ``relaxed`` is ``False``. Otherwise memory will be allocated
(including the potential for swap) if ``relaxed`` is ``True``.
The memory limit is a guideline for the maximum amount of memory to use
during graph evaluation. If the memory limit is exceeded and there is no
more RAM (including swap when available) allocations will result in an
exception.
The memory limit defaults to 1.5 times the maximum recommended working set
size reported by the device.
When metal is available the memory limit defaults to 1.5 times the
maximum recommended working set size reported by the device.
Args:
limit (int): Memory limit in bytes.
relaxed (bool, optional): If `False`` an error is raised if the limit
is exceeded. Default: ``True``
Returns:
int: The previous memory limit in bytes.
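A brief usage sketch consistent with the updated docstring: with the ``relaxed`` flag gone the limit acts as a guideline, and the previous limit returned by the call can be used to restore things afterwards. The 2 GiB figure below is illustrative.

import mlx.core as mx

old_limit = mx.metal.set_memory_limit(2 * 1024**3)  # 2 GiB guideline (illustrative)
# ... evaluate a memory-hungry graph here ...
mx.metal.set_memory_limit(old_limit)                # restore the previous limit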

View File

@@ -176,7 +176,17 @@ auto py_value_and_grad(
    // Call the python function
    py_value_out = fun(*tree[0], **tree[1]);

-    tree_fill(tree, arrays);
+    // Replace the tracers with the originals. Don't overwrite
+    // locations which were written to during the call to fun
+    int index = 0;
+    tree_visit_update(tree, [&](nb::handle node) {
+      auto replace_arr = nb::cast<mx::array>(node);
+      if (replace_arr.id() == a[index].id()) {
+        return nb::cast(arrays[index++]);
+      } else {
+        return nb::cast(replace_arr);
+      }
+    });

    // Validate the return value of the python function
    if (!nb::isinstance<mx::array>(py_value_out)) {

View File

@@ -746,6 +746,7 @@ class TestAutograd(mlx_tests.MLXTestCase):
            mx.checkpoint,
        ]:
            if mx.metal.is_available():
+                mx.synchronize(mx.default_stream(mx.default_device()))
                mem_pre = mx.metal.get_active_memory()
            else:
                mem_pre = 0
@@ -790,6 +791,20 @@ class TestAutograd(mlx_tests.MLXTestCase):
        mx.grad(fun)(arrs)
        self.assertEqual(init_id, id(arrs[0]))

    def test_grad_with_inplace_update(self):
        def loss_fn(model):
            model[1] = mx.array(2.0)
            return model[0]

        model = [
            mx.array(0.0),
            mx.array(1.0),
        ]

        grad_fn = mx.grad(loss_fn)
        grad_fn(model)
        self.assertEqual(model[1].item(), 2.0)


if __name__ == "__main__":
    unittest.main()

View File

@@ -1146,6 +1146,18 @@ class TestBlas(mlx_tests.MLXTestCase):
        self.assertEqual(r.shape, t.shape)
        self.assertTrue(mx.allclose(r, t, atol=1e-4).item())

    def test_gemv_gemm_same_precision(self):
        mx.random.seed(0)
        N = 256
        if mx.metal.is_available():
            t = mx.bfloat16
            a = mx.random.normal([1, N]).astype(t)
            b = mx.concatenate([a, a], axis=0).astype(t)
            c = mx.random.normal([N, 64]).astype(t)

            out_gemv = a @ c
            out_gemm = (b @ c)[0]
            self.assertTrue(mx.allclose(out_gemv, out_gemm))


if __name__ == "__main__":
    unittest.main()

View File

@@ -174,6 +174,29 @@ class TestEval(mlx_tests.MLXTestCase):
        post = mx.metal.get_peak_memory()
        self.assertEqual(pre, post)

    @unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
    def test_multistream_deadlock(self):
        s1 = mx.default_stream(mx.gpu)
        s2 = mx.new_stream(mx.gpu)

        x = mx.array(1.0)
        x = mx.abs(x, stream=s1)
        for _ in range(1000):
            x = mx.abs(x, stream=s2)
        mx.eval(x)

        s1 = mx.default_stream(mx.gpu)
        s2 = mx.new_stream(mx.gpu)

        old_limit = mx.metal.set_memory_limit(1000)

        x = mx.ones((512, 512), stream=s2)
        for _ in range(80):
            x = mx.abs(x, stream=s1)
        y = mx.abs(x, stream=s2)
        z = mx.abs(y, stream=s2)
        mx.eval(z)

        mx.metal.set_memory_limit(old_limit)


if __name__ == "__main__":
    unittest.main()

View File

@@ -6,6 +6,91 @@ import mlx_tests
import numpy as np

def mlx_ref_attn(q, k, v, scale=1.0, mask=None):
    q_dtype = q.dtype
    q = q * mx.array(scale, q_dtype)
    n_q_heads = q.shape[-3]
    n_kv_heads = k.shape[-3]
    n_repeats = n_q_heads // n_kv_heads

    B = q.shape[0]
    L = q.shape[2]
    kL = k.shape[2]

    if n_repeats > 1:
        q = mx.reshape(q, [B, n_kv_heads, n_repeats, L, -1])
        k = mx.expand_dims(k, 2)
        v = mx.expand_dims(v, 2)

    scores = q @ mx.swapaxes(k, -1, -2)

    if mask is not None:
        if mask == "causal":
            q_offset = max(0, kL - L)
            q_indices = mx.arange(q_offset, q_offset + L)
            k_indices = mx.arange(kL)
            mask = q_indices[:, None] >= k_indices[None]

        if n_repeats > 1 and mask.ndim >= 3:
            if mask.shape[-3] == 1:
                mask = mx.expand_dims(mask, -3)
            else:
                mask = mx.unflatten(mask, -3, (n_kv_heads, n_repeats))

        if mask.dtype == mx.bool_:
            scores = mx.where(mask, scores, -np.float32(np.inf))
        else:
            scores += mask

    scores = mx.softmax(scores, axis=-1, precise=True)

    out = scores @ v
    if n_repeats > 1:
        out = mx.reshape(out, [B, n_q_heads, L, -1])

    return out


def do_attention(f, q, k, v, scale, mask=None, transpose=False):
    if transpose:
        q_t = mx.transpose(q, (0, 2, 1, 3))
        k_t = mx.transpose(k, (0, 2, 1, 3))
        v_t = mx.transpose(v, (0, 2, 1, 3))
        o_t = f(q_t, k_t, v_t, scale=scale, mask=mask)
        return mx.transpose(o_t, (0, 2, 1, 3))
    else:
        return f(q, k, v, scale=scale, mask=mask)


def prepare_inputs(B, qL, kL, D, qH, kH, mask, transpose, dtype):
    np.random.seed(0)
    np_dtype = getattr(np, dtype)

    shape_q = (B, qL, qH, D) if transpose else (B, qH, qL, D)
    shape_kv = (B, kL, kH, D) if transpose else (B, kH, kL, D)

    scale = 1.0 / math.sqrt(D)

    q_np = np.random.normal(0.0, 0.5, shape_q).astype(np_dtype)
    k_np = np.random.normal(0.0, 0.5, shape_kv).astype(np_dtype)
    v_np = np.random.normal(0.0, scale, shape_kv).astype(np_dtype)

    q_mx = mx.array(q_np)
    k_mx = mx.array(k_np)
    v_mx = mx.array(v_np)

    if mask is not None:
        if mask == "additive":
            mask_np = np.random.normal(0.0, 0.5, (B, qH, qL, kL)).astype(np_dtype)
            mask = mx.array(mask_np)
        elif mask == "bool":
            mask_np = np.random.uniform(0.0, 1.0, (B, qH, qL, kL)) < 0.5
            mask = mx.array(mask_np)

    return q_mx, k_mx, v_mx, scale, mask


# SDPA for MHA (n_heads == n_kv_heads)
def mlx_primitives_sdpa(q, k, v, scale, mask=None):
    p = (q * scale) @ k.transpose(0, 1, 3, 2)
@@ -365,5 +450,99 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
        self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))


class TestSDPA(mlx_tests.MLXTestCase):
    @property
    def dtypes(self):
        return ["float32", "float16"] if mx.metal.is_available() else ["float32"]

    def test_sdpa(self):
        if not mx.metal.is_available():
            return

        # fmt: off
        shapes_64 = (
            # ( B, qsl, ksl, head_dim, n_qh, n_kvh)
            ( 1, 128, 128, 64, 32, 32),
            ( 1, 64, 128, 64, 32, 32),
            ( 1, 65, 128, 64, 32, 8),
            ( 1, 64, 127, 64, 32, 8),
            ( 1, 65, 127, 64, 32, 8),
            ( 1, 127, 65, 64, 32, 8),
        )

        shapes_128 = (
            # ( B, qsl, ksl, head_dim, n_qh, n_kvh)
            ( 1, 128, 128, 128, 32, 8),
            ( 1, 64, 128, 128, 32, 8),
            ( 1, 65, 127, 128, 32, 8),
            ( 1, 127, 65, 128, 32, 8),
        )
        # fmt: on

        shapes = shapes_64 + shapes_128

        masks = [None, "additive", "bool", "causal"]
        transposes = (False, True)

        for dtype in self.dtypes:
            for t in transposes:
                for mask_str in masks:
                    for B, qL, kL, D, qH, kH in shapes:
                        with self.subTest(
                            B=B,
                            qsl=qL,
                            ksl=kL,
                            head_dim=D,
                            n_q_heads=qH,
                            n_kv_heads=kH,
                            mask=mask_str,
                            transpose=t,
                            dtype=dtype,
                        ):
                            np.random.seed(0)
                            q_mx, k_mx, v_mx, scale, mask = prepare_inputs(
                                B, qL, kL, D, qH, kH, mask_str, t, dtype
                            )

                            out_ref = do_attention(
                                mlx_ref_attn, q_mx, k_mx, v_mx, scale, mask, t
                            )
                            out_fst = do_attention(
                                mx.fast.scaled_dot_product_attention,
                                q_mx,
                                k_mx,
                                v_mx,
                                scale,
                                mask,
                                t,
                            )

                            atol = 2e-5 if dtype == "float32" else 3e-4

                            self.assertListEqual(
                                list(out_ref.shape), list(out_fst.shape)
                            )

                            diff = mx.abs(out_fst - out_ref) - atol * mx.abs(out_ref)
                            self.assertLessEqual(mx.max(diff).item(), atol)

    def test_sdpa_broadcast_mask(self):
        mask = mx.array(True)
        D = 64
        Nq = 4
        Nkv = 1
        scale = 1.0
        L = 256

        mx.random.seed(0)
        q = 5e-1 * mx.random.normal(shape=(1, Nq, L, D))
        k = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D))
        v = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D))

        ref = mlx_primitives_sdpa(q, k, v, scale, mask=mask)
        out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)

        self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))


if __name__ == "__main__":
    unittest.main(failfast=True)