Merge branch 'ml-explore:main' into adding-Muon-optimizer

This commit is contained in:
Gökdeniz Gülmez
2025-03-05 10:05:10 +01:00
committed by GitHub
39 changed files with 989 additions and 280 deletions

View File

@@ -14,6 +14,8 @@ import time
from collections import Counter
from dataclasses import dataclass
from pathlib import Path
from queue import Empty as QueueEmpty
from queue import Queue
from select import select
from subprocess import PIPE, Popen, run
from typing import Optional
@@ -185,46 +187,54 @@ def parse_hostlist(parser, hostlist, repeats):
def make_monitor_script(rank, hostfile, cwd, env, command, verbose):
# Imports that are used throughout
script = ""
script += "import os\n"
script += "import sys\n"
script += "import tempfile\n"
script += "from pathlib import Path\n"
# Write the PID to a file so we can kill the process if needed
script += "pidfile=$(mktemp)\n"
script += "echo $$ >$pidfile\n"
script += "echo $pidfile\n"
script += "_, pidfile = tempfile.mkstemp() \n"
script += "open(pidfile, 'w').write(str(os.getpid()))\n"
script += "print(pidfile, flush=True)\n"
# Change the working directory if one was requested. Otherwise attempt to
# change to change to the current one but don't fail if it wasn't possible.
# change to the current one but don't fail if it wasn't possible.
d = cwd or os.getcwd()
script += f"if [ -d {shlex.quote(d)} ]; then\n"
script += f" cd {shlex.quote(d)}\n"
script += f"if Path({repr(d)}).exists():\n"
script += f" os.chdir({repr(d)})\n"
if cwd is not None:
script += "else\n"
script += f" echo Failed to change directory to {shlex.quote(d)} 1>&2\n"
script += f" exit 1\n"
script += "fi\n"
script += "else:\n"
script += (
f" print('Failed to change directory to', {repr(d)}, file=sys.stderr)\n"
)
script += f" sys.exit(1)\n"
# Add the environment variables that were given to us
script += "env = dict(os.environ)\n"
for e in env:
key, *value = e.split("=", maxsplit=1)
value = shlex.quote(value[0]) if len(value) > 0 else ""
if not all(c.isalnum() or c == "_" for c in key):
log_warning(f"'{e}' is an invalid environment variable so it is ignored")
continue
script += f"export {key}={value}\n"
script += f"env[{repr(key)}] = {repr(value)}\n"
# Add the environment variables to enable the ring distributed backend
if hostfile != "":
script += "tmpfile=$(mktemp)\n"
script += f"echo {shlex.quote(hostfile)} >$tmpfile\n"
script += "_, hostfile = tempfile.mkstemp()\n"
script += "with open(hostfile, 'w') as f:\n"
script += f" f.write({repr(hostfile)})\n"
if verbose:
script += "export MLX_RING_VERBOSE=1\n"
script += "export MLX_HOSTFILE=$tmpfile\n"
script += f"export MLX_RANK={rank}\n"
script += "env['MLX_RING_VERBOSE'] = '1'\n"
script += "env['MLX_HOSTFILE'] = hostfile\n"
script += f"env['MLX_RANK'] = '{rank}'\n"
script += "\n"
# Replace the process with the script
script += shlex.join(["exec", *command])
script += "\n"
script += f"command = [{','.join(map(repr, command))}]\n"
script += "os.execve(command[0], command, env)\n"
return script
@@ -233,28 +243,37 @@ def launch_ring(parser, hosts, args, command):
stop = False
exit_codes = [None] * len(hosts)
def node_thread(rank, host, hostfile):
def node_thread(rank, host, hostfile, input_queue):
is_local = host == "127.0.0.1"
script = make_monitor_script(
rank, hostfile, args.cwd, args.env, command, args.verbose
)
script_b64 = base64.b64encode(script.encode()).decode()
cmd = f'echo "{script_b64}" | base64 -d | /bin/bash'
cmd = f'{sys.executable} -c "import base64; exec(base64.b64decode(\\"{script_b64}\\"));"'
if not is_local:
cmd = f"ssh {host} '{cmd}'"
p = Popen(
cmd,
shell=True,
stdin=PIPE,
stdout=PIPE,
stderr=PIPE,
)
os.set_blocking(p.stdout.fileno(), False)
os.set_blocking(p.stderr.fileno(), False)
os.set_blocking(p.stdin.fileno(), False)
# Repeat the stdout and stderr to the local machine
to_read = [p.stdout.fileno(), p.stderr.fileno()]
to_write = [p.stdin.fileno()]
pidfile = ""
stdin_buffer = b""
while p.poll() is None:
rlist, _, _ = select([p.stdout.fileno(), p.stderr.fileno()], [], [], 1.0)
try:
stdin_buffer += input_queue.get_nowait()
except QueueEmpty:
pass
rlist, wlist, _ = select(to_read, to_write, [], 1.0)
for fd in rlist:
is_stdout = fd == p.stdout.fileno()
outfile = sys.stdout if is_stdout else sys.stderr
@@ -266,6 +285,11 @@ def launch_ring(parser, hosts, args, command):
msg = msg[0] if msg else ""
outfile.write(msg)
outfile.flush()
for fd in wlist:
if len(stdin_buffer) > 0:
n = os.write(fd, stdin_buffer)
stdin_buffer = stdin_buffer[n:]
if stop:
p.terminate()
break
@@ -310,16 +334,25 @@ def launch_ring(parser, hosts, args, command):
log(args.verbose, "Running", shlex.join(command))
input_queues = []
threads = []
for i, h in enumerate(hosts):
if i + 1 == len(hosts):
time.sleep(1.0)
t = threading.Thread(target=node_thread, args=(i, h.ssh_hostname, hostfile))
input_queues.append(Queue())
t = threading.Thread(
target=node_thread, args=(i, h.ssh_hostname, hostfile, input_queues[-1])
)
t.start()
threads.append(t)
os.set_blocking(sys.stdin.fileno(), False)
while not stop:
time.sleep(1.0)
rlist, _, _ = select([sys.stdin.fileno()], [], [], 1.0)
for fd in rlist:
stdin_buffer = os.read(fd, 8192)
for q in input_queues:
q.put(stdin_buffer)
if any(t.is_alive() for t in threads):
for i, t in enumerate(threads):
if not t.is_alive():
@@ -730,6 +763,8 @@ def main():
if len(rest) == 0:
parser.error("No script is provided")
if rest[0] == "--":
rest.pop(0)
# Try to extract a list of hosts and corresponding ips
if args.hostfile is not None:

View File

@@ -5,7 +5,7 @@ from typing import Callable, List, Optional, Tuple, Union
import mlx.core as mx
from mlx.nn import Module
from mlx.utils import tree_map, tree_reduce
from mlx.utils import tree_flatten, tree_map, tree_merge, tree_reduce, tree_unflatten
class Optimizer:
@@ -154,6 +154,79 @@ class Optimizer:
self.state[name] = param
class MultiOptimizer(Optimizer):
    """Wraps a list of optimizers with corresponding weight predicates/filters
    to make it easy to use different optimizers for different weights.

    The predicates take the full "path" of the weight and the weight itself and
    return True if it should be considered for this optimizer. Predicates are
    tried in order and the first one that matches wins. The last optimizer in
    the list is a fallback optimizer and no predicate should be given for it.

    Args:
        optimizers (list[Optimizer]): A non-empty list of optimizers to
            delegate to.
        filters (list[Callable[[str, array], bool]], optional): A list of
            predicates that should be one less than the provided optimizers.
            Default: ``None`` which means no predicates, i.e. a single
            fallback optimizer.

    Raises:
        ValueError: If no optimizer is given or if the number of filters is
            not exactly one less than the number of optimizers.
    """

    def __init__(self, optimizers, filters: Optional[list] = None):
        super().__init__()
        self._state = {}

        # Avoid a shared mutable default and accept any sequence (e.g. a
        # tuple) of predicates by copying into a fresh list.
        filters = [] if filters is None else list(filters)

        if len(optimizers) == 0:
            raise ValueError("At least one optimizer is required.")
        if len(filters) != len(optimizers) - 1:
            raise ValueError(
                f"Given {len(filters)} filters but {len(optimizers)-1} needed."
            )

        self.optimizers = optimizers
        # The catch-all predicate routes every unmatched weight to the last
        # (fallback) optimizer.
        self.filters = filters + [lambda *args, **kwargs: True]

    def _split_dictionary(self, gradients: dict):
        """Partition a tree into one sub-tree per optimizer.

        Each flattened ``(path, value)`` pair goes to the first optimizer
        whose predicate accepts it.
        """
        if len(self.optimizers) == 1:
            return [gradients]

        parts = [[] for _ in range(len(self.optimizers))]
        for k, g in tree_flatten(gradients):
            for i, fn in enumerate(self.filters):
                if fn(k, g):
                    parts[i].append((k, g))
                    break

        return [tree_unflatten(p) for p in parts]

    def init(self, parameters: dict):
        """Initialize every wrapped optimizer with its share of ``parameters``."""
        for o, p in zip(self.optimizers, self._split_dictionary(parameters)):
            o.init(p)

    def apply_gradients(self, gradients: dict, parameters: dict):
        """Apply each optimizer to its share of ``gradients`` and merge the
        per-optimizer results into a single tree."""
        tree = {}
        for o, g in zip(self.optimizers, self._split_dictionary(gradients)):
            tree = tree_merge(tree, o.apply_gradients(g, parameters))
        return tree

    @property
    def state(self):
        """The states of the wrapped optimizers, in order, under ``"states"``."""
        return {"states": [o.state for o in self.optimizers]}

    @state.setter
    def state(self, state: dict):
        if "states" not in state or len(state["states"]) != len(self.optimizers):
            raise ValueError("Invalid state provided")

        for o, s in zip(self.optimizers, state["states"]):
            o.state = s

    @property
    def learning_rate(self):
        """The learning rate of the first wrapped optimizer."""
        return self.optimizers[0].learning_rate

    @learning_rate.setter
    def learning_rate(self, learning_rate: Union[float, mx.array]):
        # Setting the learning rate broadcasts it to every wrapped optimizer.
        for o in self.optimizers:
            o.learning_rate = learning_rate
class SGD(Optimizer):
r"""The stochastic gradient descent optimizer.

View File

@@ -1,5 +1,6 @@
# Copyright © 2023 Apple Inc.
from collections import defaultdict
from itertools import zip_longest
from typing import Any, Callable, List, Optional, Tuple
@@ -244,3 +245,46 @@ def tree_reduce(fn, tree, initializer=None, is_leaf=None):
return tree if accumulator is None else fn(accumulator, tree)
return accumulator
def tree_merge(tree_a, tree_b, merge_fn=None):
    """Merge two Python trees in one containing the values of both. It can be
    thought of as a deep dict.update method.

    ``None`` and empty containers (``{}``, ``[]``, ``()``) are treated as
    missing sub-trees: when only one side is missing the other side is
    returned unchanged, and when both sides are missing there is nothing to
    merge so the (empty) value is returned as is.

    Args:
        tree_a (Any): The first Python tree.
        tree_b (Any): The second Python tree.
        merge_fn (callable, optional): A function to merge leaves.

    Returns:
        The Python tree containing the values of both ``tree_a`` and
        ``tree_b``. Dictionary keys keep the order of ``tree_a`` followed by
        the keys that only appear in ``tree_b``.

    Raises:
        ValueError: If both trees hold a value at the same location and no
            ``merge_fn`` was provided.
    """

    def is_missing(tree):
        return tree is None or (
            isinstance(tree, (dict, list, tuple)) and len(tree) == 0
        )

    a_missing = is_missing(tree_a)
    b_missing = is_missing(tree_b)
    if a_missing and b_missing:
        # Nothing on either side: this is not a conflict between leaves, so
        # hand back the empty placeholder instead of raising.
        return tree_b if tree_a is None else tree_a
    if a_missing:
        return tree_b
    if b_missing:
        return tree_a

    if isinstance(tree_a, (list, tuple)) and isinstance(tree_b, (list, tuple)):
        TreeType = type(tree_a)
        # zip_longest pads the shorter sequence with None, i.e. "missing".
        return TreeType(
            tree_merge(a, b, merge_fn) for a, b in zip_longest(tree_a, tree_b)
        )
    elif isinstance(tree_a, dict) and isinstance(tree_b, dict):
        # dict.fromkeys gives an ordered, duplicate-free union of the keys so
        # the result does not depend on set iteration order.
        return {
            k: tree_merge(tree_a.get(k, None), tree_b.get(k, None), merge_fn)
            for k in dict.fromkeys((*tree_a, *tree_b))
        }
    else:
        if merge_fn is None:
            raise ValueError(
                (
                    "Trees contain elements at the same locations but no merge "
                    "function was provided"
                )
            )
        return merge_fn(tree_a, tree_b)

View File

@@ -25,12 +25,12 @@ void init_fast(nb::module_& parent_module) {
"rms_norm",
&mx::fast::rms_norm,
"x"_a,
"weight"_a,
"weight"_a.none(),
"eps"_a,
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def rms_norm(x: array, weight: array, eps: float, *, stream: Union[None, Stream, Device] = None) -> array"),
"def rms_norm(x: array, weight: Optional[array], eps: float, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Root Mean Square normalization (RMS norm).
@@ -38,9 +38,9 @@ void init_fast(nb::module_& parent_module) {
Args:
x (array): Input array.
weight (array): A multiplicative weight to scale the result by.
weight (array, optional): A multiplicative weight to scale the result by.
The ``weight`` should be one-dimensional with the same size
as the last axis of ``x``.
as the last axis of ``x``. If set to ``None`` then no scaling happens.
eps (float): A small additive constant for numerical stability.
Returns:

View File

@@ -92,6 +92,7 @@ void init_linalg(nb::module_& parent_module) {
===== ============================ ==========================
None Frobenius norm 2-norm
'fro' Frobenius norm --
'nuc' nuclear norm --
inf max(sum(abs(x), axis=1)) max(abs(x))
-inf min(sum(abs(x), axis=1)) min(abs(x))
0 -- sum(x != 0)
@@ -102,9 +103,6 @@ void init_linalg(nb::module_& parent_module) {
other -- sum(abs(x)**ord)**(1./ord)
===== ============================ ==========================
.. warning::
Nuclear norm and norms based on singular values are not yet implemented.
The Frobenius norm is given by [1]_:
:math:`||A||_F = [\sum_{i,j} abs(a_{i,j})^2]^{1/2}`
@@ -206,15 +204,22 @@ void init_linalg(nb::module_& parent_module) {
)pbdoc");
m.def(
"svd",
[](const mx::array& a, mx::StreamOrDevice s /* = {} */) {
const auto result = mx::linalg::svd(a, s);
return nb::make_tuple(result.at(0), result.at(1), result.at(2));
[](const mx::array& a,
bool compute_uv /* = true */,
mx::StreamOrDevice s /* = {} */) -> nb::object {
const auto result = mx::linalg::svd(a, compute_uv, s);
if (result.size() == 1) {
return nb::cast(result.at(0));
} else {
return nb::make_tuple(result.at(0), result.at(1), result.at(2));
}
},
"a"_a,
"compute_uv"_a = true,
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def svd(a: array, *, stream: Union[None, Stream, Device] = None) -> Tuple[array, array, array]"),
"def svd(a: array, compute_uv: bool = True, *, stream: Union[None, Stream, Device] = None) -> Tuple[array, array, array]"),
R"pbdoc(
The Singular Value Decomposition (SVD) of the input matrix.
@@ -224,12 +229,15 @@ void init_linalg(nb::module_& parent_module) {
Args:
a (array): Input array.
compute_uv (bool, optional): If ``True``, return the ``U``, ``S``, and ``Vt`` components.
If ``False``, return only the ``S`` array. Default: ``True``.
stream (Stream, optional): Stream or device. Defaults to ``None``
in which case the default stream of the default device is used.
Returns:
tuple(array, array, array): The ``U``, ``S``, and ``Vt`` matrices, such that
``A = U @ diag(S) @ Vt``
Union[tuple(array, ...), array]:
If compute_uv is ``True`` returns the ``U``, ``S``, and ``Vt`` matrices, such that
``A = U @ diag(S) @ Vt``. If compute_uv is ``False`` returns singular values array ``S``.
)pbdoc");
m.def(
"inv",

View File

@@ -298,6 +298,9 @@ class TestFast(mlx_tests.MLXTestCase):
rx = rms_norm(x, weight, eps)
rx_fast = mx.fast.rms_norm(x, weight, eps)
self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
rx = rms_norm(x, mx.ones_like(weight), eps)
rx_fast = mx.fast.rms_norm(x, None, eps)
self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
for eps in epss:
dtype, _, dims = defaults
@@ -306,6 +309,9 @@ class TestFast(mlx_tests.MLXTestCase):
rx = rms_norm(x, weight, eps)
rx_fast = mx.fast.rms_norm(x, weight, eps)
self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
rx = rms_norm(x, mx.ones_like(weight), eps)
rx_fast = mx.fast.rms_norm(x, None, eps)
self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
for dims in dimss:
dtype, eps, _ = defaults
@@ -314,6 +320,9 @@ class TestFast(mlx_tests.MLXTestCase):
rx = rms_norm(x, weight, eps)
rx_fast = mx.fast.rms_norm(x, weight, eps)
self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
rx = rms_norm(x, mx.ones_like(weight), eps)
rx_fast = mx.fast.rms_norm(x, None, eps)
self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
# Test > 4096
dims, dtype, eps = 4099, mx.float32, 1e-5
@@ -333,6 +342,8 @@ class TestFast(mlx_tests.MLXTestCase):
eps = 1e-5
f1 = lambda x, w, y: (rms_norm(x, w, eps) * y).sum()
f2 = lambda x, w, y: (mx.fast.rms_norm(x, w, eps) * y).sum()
f3 = lambda x, y: (rms_norm(x, mx.ones((x.shape[-1],)), eps) * y).sum()
f4 = lambda x, y: (mx.fast.rms_norm(x, None, eps) * y).sum()
x = mx.random.uniform(shape=(8, 100, D))
w = mx.random.uniform(shape=(D,))
@@ -341,6 +352,9 @@ class TestFast(mlx_tests.MLXTestCase):
gx2, gw2 = mx.grad(f2, argnums=(0, 1))(x, w, y)
self.assertLess(mx.abs(gx1 - gx2).max(), 1e-5)
self.assertLess(mx.abs(gw1 - gw2).max() / mx.abs(gw1).mean(), 1e-5)
gx1 = mx.grad(f3, argnums=(0,))(x, y)
gx2 = mx.grad(f4, argnums=(0,))(x, y)
self.assertLess(mx.abs(gx1 - gx2).max(), 1e-5)
D = 8192
x = mx.random.uniform(shape=(2, 2, D))
@@ -350,6 +364,9 @@ class TestFast(mlx_tests.MLXTestCase):
gx2, gw2 = mx.grad(f2, argnums=(0, 1))(x, w, y)
self.assertLess(mx.abs(gx1 - gx2).max(), 1e-5)
self.assertLess(mx.abs(gw1 - gw2).max() / mx.abs(gw1).mean(), 1e-5)
gx1 = mx.grad(f3, argnums=(0,))(x, y)
gx2 = mx.grad(f4, argnums=(0,))(x, y)
self.assertLess(mx.abs(gx1 - gx2).max(), 1e-5)
def gf(f):
def inner(x, w, y):

View File

@@ -262,6 +262,61 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
)
self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
# Compare mx.fast.scaled_dot_product_attention against the
# mlx_primitives_sdpa result for a handful of query positions (Lq = 4) and
# several mask layouts.
def test_fast_sdpa_few_query(self):
D = 64
L = 43
Lq = 4
Nq = 8
Nkv = 1
scale = 1.0
mx.random.seed(0)
# q is drawn as (1, Lq, Nq, D) and then swapped to (1, Nq, Lq, D).
q = 5e-1 * mx.random.normal(shape=(1, Lq, Nq, D))
q = q.swapaxes(1, 2)
k = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D))
v = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D))
# Mask layouts: a scalar True, a length-L vector whose last 10 entries are
# False, a random (Nq, 1, L) mask and a transposed (L, 1, Nq) random mask.
masks = [
mx.array(True),
mx.array([True] * (L - 10) + [False] * 10),
mx.random.uniform(shape=(Nq, 1, L)) > 0.2,
mx.random.uniform(shape=(L, 1, Nq)).T > 0.2,
]
for m in masks:
ref = mlx_primitives_sdpa(q, k, v, scale, mask=m)
out = mx.fast.scaled_dot_product_attention(
q,
k,
v,
scale=scale,
mask=m,
)
self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
# NOTE(review): the bare `return` below makes the rest of this method
# unreachable, so the L = 4096 case never runs -- confirm this is intentional.
return
L = 4096
scale = 1.0
mx.random.seed(0)
q = 5e-1 * mx.random.normal(shape=(1, Nq, Lq, D))
k = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D))
v = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D))
masks = [
mx.array(True),
mx.array([True] * (L - 10) + [False] * 10),
mx.random.uniform(shape=(Nq, 1, L)) > 0.2,
mx.random.uniform(shape=(L, 1, Nq)).T > 0.2,
]
for m in masks:
ref = mlx_primitives_sdpa(q, k, v, scale, mask=m)
out = mx.fast.scaled_dot_product_attention(
q,
k,
v,
scale=scale,
mask=m,
)
self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
@unittest.skip("Different head and value dims is not enabled")
def test_fast_sdpa_vector_value_dims(self):
D = 192

View File

@@ -12,11 +12,11 @@ import numpy as np
class TestLinalg(mlx_tests.MLXTestCase):
def test_norm(self):
vector_ords = [None, 0.5, 0, 1, 2, 3, -1, float("inf"), -float("inf")]
matrix_ords = [None, "fro", -1, 1, float("inf"), -float("inf")]
matrix_ords = [None, "fro", "nuc", -1, 1, -2, 2, float("inf"), -float("inf")]
for shape in [(3,), (2, 3), (2, 3, 3)]:
x_mx = mx.arange(1, math.prod(shape) + 1).reshape(shape)
x_np = np.arange(1, math.prod(shape) + 1).reshape(shape)
x_mx = mx.arange(1, math.prod(shape) + 1, dtype=mx.float32).reshape(shape)
x_np = np.arange(1, math.prod(shape) + 1, dtype=np.float32).reshape(shape)
# Test when at least one axis is provided
for num_axes in range(1, len(shape)):
if num_axes == 1:
@@ -26,11 +26,14 @@ class TestLinalg(mlx_tests.MLXTestCase):
for axis in itertools.combinations(range(len(shape)), num_axes):
for keepdims in [True, False]:
for o in ords:
stream = (
mx.cpu if o in ["nuc", -2, 2] else mx.default_device()
)
out_np = np.linalg.norm(
x_np, ord=o, axis=axis, keepdims=keepdims
)
out_mx = mx.linalg.norm(
x_mx, ord=o, axis=axis, keepdims=keepdims
x_mx, ord=o, axis=axis, keepdims=keepdims, stream=stream
)
with self.subTest(
shape=shape, ord=o, axis=axis, keepdims=keepdims
@@ -133,20 +136,38 @@ class TestLinalg(mlx_tests.MLXTestCase):
def test_svd_decomposition(self):
A = mx.array([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]], dtype=mx.float32)
U, S, Vt = mx.linalg.svd(A, stream=mx.cpu)
U, S, Vt = mx.linalg.svd(A, compute_uv=True, stream=mx.cpu)
self.assertTrue(
mx.allclose(U[:, : len(S)] @ mx.diag(S) @ Vt, A, rtol=1e-5, atol=1e-7)
)
S = mx.linalg.svd(A, compute_uv=False, stream=mx.cpu)
self.assertTrue(
mx.allclose(
mx.linalg.norm(S), mx.linalg.norm(A, ord="fro"), rtol=1e-5, atol=1e-7
)
)
# Multiple matrices
B = A + 10.0
AB = mx.stack([A, B])
Us, Ss, Vts = mx.linalg.svd(AB, stream=mx.cpu)
Us, Ss, Vts = mx.linalg.svd(AB, compute_uv=True, stream=mx.cpu)
for M, U, S, Vt in zip([A, B], Us, Ss, Vts):
self.assertTrue(
mx.allclose(U[:, : len(S)] @ mx.diag(S) @ Vt, M, rtol=1e-5, atol=1e-7)
)
Ss = mx.linalg.svd(AB, compute_uv=False, stream=mx.cpu)
for M, S in zip([A, B], Ss):
self.assertTrue(
mx.allclose(
mx.linalg.norm(S),
mx.linalg.norm(M, ord="fro"),
rtol=1e-5,
atol=1e-7,
)
)
def test_inverse(self):
A = mx.array([[1, 2, 3], [6, -5, 4], [-9, 8, 7]], dtype=mx.float32)
A_inv = mx.linalg.inv(A, stream=mx.cpu)

View File

@@ -1894,6 +1894,22 @@ class TestOps(mlx_tests.MLXTestCase):
expected = mx.repeat(expected[:, None], 2, axis=1)
self.assertTrue(mx.array_equal(expected, out))
# Test donation
def fn(its):
x = mx.ones((32,))
for _ in range(its):
x = mx.cumsum(x)
return x
mx.synchronize(mx.default_stream(mx.default_device()))
mx.eval(fn(2))
mx.synchronize(mx.default_stream(mx.default_device()))
mem2 = mx.metal.get_peak_memory()
mx.eval(fn(4))
mx.synchronize(mx.default_stream(mx.default_device()))
mem4 = mx.metal.get_peak_memory()
self.assertEqual(mem2, mem4)
def test_squeeze_expand(self):
a = mx.zeros((2, 1, 2, 1))
self.assertEqual(mx.squeeze(a).shape, (2, 2))
@@ -2846,6 +2862,11 @@ class TestOps(mlx_tests.MLXTestCase):
b[::2] = 0
self.assertTrue(mx.array_equal(b, mx.array([0, 3, 0, 1])))
def test_slice_with_negative_stride(self):
    """Reversing along the first axis must place the first row last."""
    rows = mx.random.uniform(shape=(128, 4))
    flipped = rows[::-1]
    last_of_flipped = flipped[-1, :]
    first_of_rows = rows[0, :]
    self.assertTrue(mx.array_equal(last_of_flipped, first_of_rows))
if __name__ == "__main__":
unittest.main()

View File

@@ -39,6 +39,7 @@ def tree_equal(fn, *args):
optimizers_dict = get_all_optimizers()
del optimizers_dict["MultiOptimizer"]
class TestOptimizers(mlx_tests.MLXTestCase):
@@ -500,6 +501,30 @@ class TestSchedulers(unittest.TestCase):
grads = model.trainable_parameters()
optimizer.update(model, grads)
def test_multi_optimizer(self):
    """Weights with more than one dimension go to Adam, the rest to SGD."""

    class Model(nn.Module):
        def __init__(self):
            super().__init__()
            self.l1 = nn.Linear(2, 2)
            self.drop = nn.Dropout(p=0.5)
            self.l2 = nn.Linear(2, 2)
            self.vals = [nn.Linear(2, 2), nn.ReLU(), nn.ReLU()]

    def is_matrix(name, weight):
        return weight.ndim > 1

    net = Model()
    multi = opt.MultiOptimizer(
        [opt.Adam(learning_rate=0.001), opt.SGD(learning_rate=0.1)],
        [is_matrix],
    )
    multi.init(net.trainable_parameters())

    self.assertEqual(len(multi.state["states"]), 2)

    adam_flat = tree_flatten(multi.state["states"][0])
    sgd_flat = tree_flatten(multi.state["states"][1])

    # NOTE(review): presumably each state carries two non-parameter entries
    # (hence the "- 2") and Adam keeps twice as many arrays per weight as
    # SGD -- confirm against the optimizer implementations.
    self.assertEqual((len(sgd_flat) - 2) * 2, len(adam_flat) - 2)
    self.assertFalse(any("bias" in key for key, _ in adam_flat))
    self.assertFalse(any("weight" in key for key, _ in sgd_flat))
if __name__ == "__main__":
unittest.main()

View File

@@ -3,6 +3,7 @@
import unittest
import mlx.core as mx
import mlx.nn as nn
import mlx.utils
import mlx_tests
@@ -22,6 +23,29 @@ class TestTreeUtils(mlx_tests.MLXTestCase):
self.assertEqual(list(zip(*flat_tree))[1], vals)
self.assertEqual(mlx.utils.tree_unflatten(flat_tree), tree)
def test_merge(self):
    """Exercise ``tree_merge`` on disjoint dicts, clashing leaves and
    module parameter trees."""
    merge = mlx.utils.tree_merge
    flatten = mlx.utils.tree_flatten

    first = {"a": 0}
    second = {"b": 1}
    combined = merge(first, second)
    self.assertEqual({"a": 0, "b": 1}, combined)

    # Leaves at the same location and no merge function: must be rejected.
    for lhs, rhs in ((first, first), (combined, first)):
        with self.assertRaises(ValueError):
            merge(lhs, rhs)

    seq_a = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
    seq_b = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
    stacked = nn.Sequential(seq_a, seq_b)

    # Each partial tree fills a different slot of the "layers" list.
    left = {"layers": [seq_a.parameters()]}
    right = {"layers": [None, seq_b.parameters()]}
    merged = merge(left, right)

    merged_items = flatten(merged)
    expected_items = flatten(stacked.parameters())
    for (got_key, got_val), (want_key, want_val) in zip(merged_items, expected_items):
        self.assertEqual(got_key, want_key)
        self.assertTrue(mx.array_equal(got_val, want_val))
if __name__ == "__main__":
unittest.main()

View File

@@ -316,35 +316,59 @@ class TestVmap(mlx_tests.MLXTestCase):
def test_vmap_svd(self):
a = mx.random.uniform(shape=(3, 4, 2))
cpu_svd = lambda x: mx.linalg.svd(x, stream=mx.cpu)
cpu_svd_full = lambda x: mx.linalg.svd(x, compute_uv=True, stream=mx.cpu)
cpu_svd_singular = lambda x: mx.linalg.svd(x, compute_uv=False, stream=mx.cpu)
# Vmap over the first axis (this is already supported natively by the primitive).
Us, Ss, Vts = mx.vmap(cpu_svd, in_axes=(0,))(a)
Us, Ss, Vts = mx.vmap(cpu_svd_full, in_axes=(0,))(a)
self.assertEqual(Us.shape, (a.shape[0], a.shape[1], a.shape[1]))
self.assertEqual(Ss.shape, (a.shape[0], a.shape[2]))
self.assertEqual(Vts.shape, (a.shape[0], a.shape[2], a.shape[2]))
Sv = mx.vmap(cpu_svd_singular, in_axes=(0,))(a)
self.assertEqual(Sv.shape, (a.shape[0], a.shape[2]))
for i in range(a.shape[0]):
M = a[i]
U, S, Vt = Us[i], Ss[i], Vts[i]
self.assertTrue(
mx.allclose(U[:, : len(S)] @ mx.diag(S) @ Vt, M, rtol=1e-5, atol=1e-7)
)
self.assertTrue(
mx.allclose(
mx.linalg.norm(Sv[i]),
mx.linalg.norm(M, ord="fro"),
rtol=1e-5,
atol=1e-7,
)
)
# Vmap over the second axis.
Us, Ss, Vts = mx.vmap(cpu_svd, in_axes=(1,))(a)
Us, Ss, Vts = mx.vmap(cpu_svd_full, in_axes=(1,))(a)
self.assertEqual(Us.shape, (a.shape[1], a.shape[0], a.shape[0]))
self.assertEqual(Ss.shape, (a.shape[1], a.shape[2]))
self.assertEqual(Vts.shape, (a.shape[1], a.shape[2], a.shape[2]))
Sv = mx.vmap(cpu_svd_singular, in_axes=(1,))(a)
self.assertEqual(Sv.shape, (a.shape[1], a.shape[2]))
for i in range(a.shape[1]):
M = a[:, i, :]
U, S, Vt = Us[i], Ss[i], Vts[i]
self.assertTrue(
mx.allclose(U[:, : len(S)] @ mx.diag(S) @ Vt, M, rtol=1e-5, atol=1e-7)
)
self.assertTrue(
mx.allclose(
mx.linalg.norm(Sv[i]),
mx.linalg.norm(M, ord="fro"),
rtol=1e-5,
atol=1e-7,
)
)
def test_vmap_inverse(self):
mx.random.seed(42)
a = mx.random.uniform(shape=(3, 4, 4))
cpu_inv = lambda x: mx.linalg.inv(x, stream=mx.cpu)