mirror of https://github.com/ml-explore/mlx.git

Merge branch 'ml-explore:main' into adding-Muon-optimizer
@@ -14,6 +14,8 @@ import time
 from collections import Counter
 from dataclasses import dataclass
 from pathlib import Path
+from queue import Empty as QueueEmpty
+from queue import Queue
 from select import select
 from subprocess import PIPE, Popen, run
 from typing import Optional
@@ -185,46 +187,54 @@ def parse_hostlist(parser, hostlist, repeats):


 def make_monitor_script(rank, hostfile, cwd, env, command, verbose):
+    # Imports that are used throughout
     script = ""
+    script += "import os\n"
+    script += "import sys\n"
+    script += "import tempfile\n"
+    script += "from pathlib import Path\n"

     # Write the PID to a file so we can kill the process if needed
-    script += "pidfile=$(mktemp)\n"
-    script += "echo $$ >$pidfile\n"
-    script += "echo $pidfile\n"
+    script += "_, pidfile = tempfile.mkstemp()\n"
+    script += "open(pidfile, 'w').write(str(os.getpid()))\n"
+    script += "print(pidfile, flush=True)\n"

     # Change the working directory if one was requested. Otherwise attempt to
-    # change to change to the current one but don't fail if it wasn't possible.
+    # change to the current one but don't fail if it wasn't possible.
     d = cwd or os.getcwd()
-    script += f"if [ -d {shlex.quote(d)} ]; then\n"
-    script += f"    cd {shlex.quote(d)}\n"
+    script += f"if Path({repr(d)}).exists():\n"
+    script += f"    os.chdir({repr(d)})\n"
     if cwd is not None:
-        script += "else\n"
-        script += f"    echo Failed to change directory to {shlex.quote(d)} 1>&2\n"
-        script += f"    exit 1\n"
-    script += "fi\n"
+        script += "else:\n"
+        script += (
+            f"    print('Failed to change directory to', {repr(d)}, file=sys.stderr)\n"
+        )
+        script += f"    sys.exit(1)\n"

     # Add the environment variables that were given to us
+    script += "env = dict(os.environ)\n"
     for e in env:
         key, *value = e.split("=", maxsplit=1)
         value = shlex.quote(value[0]) if len(value) > 0 else ""
         if not all(c.isalnum() or c == "_" for c in key):
             log_warning(f"'{e}' is an invalid environment variable so it is ignored")
             continue
-        script += f"export {key}={value}\n"
+        script += f"env[{repr(key)}] = {repr(value)}\n"

     # Add the environment variables to enable the ring distributed backend
     if hostfile != "":
-        script += "tmpfile=$(mktemp)\n"
-        script += f"echo {shlex.quote(hostfile)} >$tmpfile\n"
+        script += "_, hostfile = tempfile.mkstemp()\n"
+        script += "with open(hostfile, 'w') as f:\n"
+        script += f"    f.write({repr(hostfile)})\n"
         if verbose:
-            script += "export MLX_RING_VERBOSE=1\n"
-        script += "export MLX_HOSTFILE=$tmpfile\n"
-        script += f"export MLX_RANK={rank}\n"
+            script += "env['MLX_RING_VERBOSE'] = '1'\n"
+        script += "env['MLX_HOSTFILE'] = hostfile\n"
+        script += f"env['MLX_RANK'] = '{rank}'\n"
         script += "\n"

     # Replace the process with the script
-    script += shlex.join(["exec", *command])
-    script += "\n"
+    script += f"command = [{','.join(map(repr, command))}]\n"
+    script += "os.execve(command[0], command, env)\n"

     return script
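For orientation, the monitor script built above is now plain Python rather than bash. For a hypothetical call such as make_monitor_script(0, hostfile_json, None, ["FOO=bar"], ["/usr/bin/python3", "train.py"], False) (the paths, command, and variable are made up for illustration), the emitted script would look roughly like:

# Hypothetical output of make_monitor_script; not a verbatim capture.
import os
import sys
import tempfile
from pathlib import Path
# Write our PID to a temp file and report its path to the launcher.
_, pidfile = tempfile.mkstemp()
open(pidfile, 'w').write(str(os.getpid()))
print(pidfile, flush=True)
# cwd was None, so chdir to the launcher's cwd if it exists, without an else branch.
if Path('/home/user').exists():
    os.chdir('/home/user')
env = dict(os.environ)
env['FOO'] = 'bar'
# Materialize the hostfile contents on the remote side.
_, hostfile = tempfile.mkstemp()
with open(hostfile, 'w') as f:
    f.write('...')
env['MLX_HOSTFILE'] = hostfile
env['MLX_RANK'] = '0'

# Replace this process with the user command.
command = ['/usr/bin/python3', 'train.py']
os.execve(command[0], command, env)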
@@ -233,28 +243,37 @@ def launch_ring(parser, hosts, args, command):
     stop = False
     exit_codes = [None] * len(hosts)

-    def node_thread(rank, host, hostfile):
+    def node_thread(rank, host, hostfile, input_queue):
         is_local = host == "127.0.0.1"
         script = make_monitor_script(
             rank, hostfile, args.cwd, args.env, command, args.verbose
         )
         script_b64 = base64.b64encode(script.encode()).decode()
-        cmd = f'echo "{script_b64}" | base64 -d | /bin/bash'
+        cmd = f'{sys.executable} -c "import base64; exec(base64.b64decode(\\"{script_b64}\\"));"'
         if not is_local:
             cmd = f"ssh {host} '{cmd}'"
         p = Popen(
             cmd,
             shell=True,
+            stdin=PIPE,
             stdout=PIPE,
             stderr=PIPE,
         )
         os.set_blocking(p.stdout.fileno(), False)
         os.set_blocking(p.stderr.fileno(), False)
+        os.set_blocking(p.stdin.fileno(), False)

         # Repeat the stdout and stderr to the local machine
+        to_read = [p.stdout.fileno(), p.stderr.fileno()]
+        to_write = [p.stdin.fileno()]
         pidfile = ""
+        stdin_buffer = b""
         while p.poll() is None:
-            rlist, _, _ = select([p.stdout.fileno(), p.stderr.fileno()], [], [], 1.0)
+            try:
+                stdin_buffer += input_queue.get_nowait()
+            except QueueEmpty:
+                pass
+            rlist, wlist, _ = select(to_read, to_write, [], 1.0)
             for fd in rlist:
                 is_stdout = fd == p.stdout.fileno()
                 outfile = sys.stdout if is_stdout else sys.stderr
@@ -266,6 +285,11 @@ def launch_ring(parser, hosts, args, command):
                 msg = msg[0] if msg else ""

                 outfile.write(msg)
                 outfile.flush()
+            for fd in wlist:
+                if len(stdin_buffer) > 0:
+                    n = os.write(fd, stdin_buffer)
+                    stdin_buffer = stdin_buffer[n:]
             if stop:
                 p.terminate()
                 break
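The base64 detour above matters because the script travels through a shell and possibly ssh; encoding it makes it opaque to both layers of quoting. A minimal standalone sketch of the same launch pattern (the script body here is made up):

import base64
import subprocess
import sys

script = "print('hello from the launched script')"
b64 = base64.b64encode(script.encode()).decode()
# The remote side decodes and execs; the script itself never needs shell escaping.
cmd = f'{sys.executable} -c "import base64; exec(base64.b64decode(\\"{b64}\\"));"'
subprocess.run(cmd, shell=True, check=True)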
@@ -310,16 +334,25 @@ def launch_ring(parser, hosts, args, command):

     log(args.verbose, "Running", shlex.join(command))

+    input_queues = []
     threads = []
     for i, h in enumerate(hosts):
         if i + 1 == len(hosts):
             time.sleep(1.0)
-        t = threading.Thread(target=node_thread, args=(i, h.ssh_hostname, hostfile))
+        input_queues.append(Queue())
+        t = threading.Thread(
+            target=node_thread, args=(i, h.ssh_hostname, hostfile, input_queues[-1])
+        )
         t.start()
         threads.append(t)

+    os.set_blocking(sys.stdin.fileno(), False)
     while not stop:
-        time.sleep(1.0)
+        rlist, _, _ = select([sys.stdin.fileno()], [], [], 1.0)
+        for fd in rlist:
+            stdin_buffer = os.read(fd, 8192)
+            for q in input_queues:
+                q.put(stdin_buffer)
         if any(t.is_alive() for t in threads):
             for i, t in enumerate(threads):
                 if not t.is_alive():
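The pattern here: only the main thread reads the real stdin, and each node thread owns a Queue, so one chunk of local input reaches every rank. A self-contained sketch of that fan-out under those assumptions (names are illustrative):

import os
import sys
import threading
from queue import Queue
from select import select

def worker(q: Queue):
    # Each worker consumes from its own queue instead of touching sys.stdin.
    data = q.get()
    print(f"worker got {data!r}")

queues = [Queue() for _ in range(2)]
threads = [threading.Thread(target=worker, args=(q,)) for q in queues]
for t in threads:
    t.start()

# The main thread polls stdin without blocking and broadcasts to every queue.
os.set_blocking(sys.stdin.fileno(), False)
rlist, _, _ = select([sys.stdin.fileno()], [], [], 5.0)
data = os.read(rlist[0], 8192) if rlist else b""
for q in queues:
    q.put(data)
for t in threads:
    t.join()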
@@ -730,6 +763,8 @@ def main():

     if len(rest) == 0:
         parser.error("No script is provided")
+    if rest[0] == "--":
+        rest.pop(0)

     # Try to extract a list of hosts and corresponding ips
     if args.hostfile is not None:
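In practice this lets callers separate launcher flags from the program with a bare "--", along the lines of the following hypothetical invocation (assuming the mlx.launch entry point this file backs; the flags shown mirror args.hostfile and args.verbose but are illustrative):

mlx.launch --hostfile hosts.json --verbose -- python train.py --epochs 10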
@@ -5,7 +5,7 @@ from typing import Callable, List, Optional, Tuple, Union

 import mlx.core as mx
 from mlx.nn import Module
-from mlx.utils import tree_map, tree_reduce
+from mlx.utils import tree_flatten, tree_map, tree_merge, tree_reduce, tree_unflatten


 class Optimizer:
@@ -154,6 +154,79 @@ class Optimizer:
         self.state[name] = param


+class MultiOptimizer(Optimizer):
+    """Wraps a list of optimizers with corresponding weight predicates/filters
+    to make it easy to use different optimizers for different weights.
+
+    The predicates take the full "path" of the weight and the weight itself and
+    return True if it should be considered for this optimizer. The last
+    optimizer in the list is a fallback optimizer and no predicate should be
+    given for it.
+
+    Args:
+        optimizers (list[Optimizer]): A list of optimizers to delegate to
+        filters (list[Callable[[str, array], bool]]): A list of predicates that
+            should be one less than the provided optimizers.
+    """
+
+    def __init__(self, optimizers, filters: list = []):
+        super().__init__()
+        self._state = {}
+
+        if len(filters) != len(optimizers) - 1:
+            raise ValueError(
+                f"Given {len(filters)} filters but {len(optimizers)-1} needed."
+            )
+
+        self.optimizers = optimizers
+        self.filters = filters + [lambda *args, **kwargs: True]
+
+    def _split_dictionary(self, gradients: dict):
+        if len(self.optimizers) == 1:
+            return [gradients]
+
+        parts = [[] for _ in range(len(self.optimizers))]
+        flat_gradients = tree_flatten(gradients)
+        for k, g in flat_gradients:
+            for i, fn in enumerate(self.filters):
+                if fn(k, g):
+                    parts[i].append((k, g))
+                    break
+
+        return [tree_unflatten(p) for p in parts]
+
+    def init(self, parameters: dict):
+        for o, p in zip(self.optimizers, self._split_dictionary(parameters)):
+            o.init(p)
+
+    def apply_gradients(self, gradients: dict, parameters: dict):
+        tree = {}
+        for o, g in zip(self.optimizers, self._split_dictionary(gradients)):
+            tree = tree_merge(tree, o.apply_gradients(g, parameters))
+        return tree
+
+    @property
+    def state(self):
+        return {"states": [o.state for o in self.optimizers]}
+
+    @state.setter
+    def state(self, state: dict):
+        if "states" not in state or len(state["states"]) != len(self.optimizers):
+            raise ValueError("Invalid state provided")
+
+        for o, s in zip(self.optimizers, state["states"]):
+            o.state = s
+
+    @property
+    def learning_rate(self):
+        return self.optimizers[0].learning_rate
+
+    @learning_rate.setter
+    def learning_rate(self, learning_rate: Union[float, mx.array]):
+        for o in self.optimizers:
+            o.learning_rate = learning_rate
+
+
 class SGD(Optimizer):
     r"""The stochastic gradient descent optimizer.
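A short usage sketch, mirroring the test added later in this commit: weights with more than one dimension go to the first optimizer, everything else falls through to the last one.

import mlx.nn as nn
import mlx.optimizers as opt

model = nn.Linear(2, 2)
optimizer = opt.MultiOptimizer(
    [opt.Adam(learning_rate=0.001), opt.SGD(learning_rate=0.1)],
    [lambda name, weight: weight.ndim > 1],  # matrices to Adam; biases fall through to SGD
)
optimizer.init(model.trainable_parameters())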
@@ -1,5 +1,6 @@
 # Copyright © 2023 Apple Inc.
 from collections import defaultdict
+from itertools import zip_longest
 from typing import Any, Callable, List, Optional, Tuple

@@ -244,3 +245,46 @@ def tree_reduce(fn, tree, initializer=None, is_leaf=None):
         return tree if accumulator is None else fn(accumulator, tree)

     return accumulator
+
+
+def tree_merge(tree_a, tree_b, merge_fn=None):
+    """Merge two Python trees into one containing the values of both. It can be
+    thought of as a deep dict.update method.
+
+    Args:
+        tree_a (Any): The first Python tree.
+        tree_b (Any): The second Python tree.
+        merge_fn (callable, optional): A function to merge leaves.
+
+    Returns:
+        The Python tree containing the values of both ``tree_a`` and
+        ``tree_b``.
+    """
+    if isinstance(tree_a, (dict, list, tuple)) and len(tree_a) == 0:
+        tree_a = None
+    if isinstance(tree_b, (dict, list, tuple)) and len(tree_b) == 0:
+        tree_b = None
+    if tree_a is None and tree_b is not None:
+        return tree_b
+    if tree_a is not None and tree_b is None:
+        return tree_a
+
+    if isinstance(tree_a, (list, tuple)) and isinstance(tree_b, (list, tuple)):
+        TreeType = type(tree_a)
+        return TreeType(
+            tree_merge(a, b, merge_fn) for a, b in zip_longest(tree_a, tree_b)
+        )
+    elif isinstance(tree_a, dict) and isinstance(tree_b, dict):
+        return {
+            k: tree_merge(tree_a.get(k, None), tree_b.get(k, None), merge_fn)
+            for k in set(tree_a.keys()) | set(tree_b.keys())
+        }
+    else:
+        if merge_fn is None:
+            raise ValueError(
+                (
+                    "Trees contain elements at the same locations but no merge "
+                    "function was provided"
+                )
+            )
+        return merge_fn(tree_a, tree_b)
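A quick sketch of the semantics, matching the tests added below: disjoint keys merge, while overlapping leaves need a merge_fn.

from mlx.utils import tree_merge

print(tree_merge({"a": 0}, {"b": 1}))                      # {'a': 0, 'b': 1}
print(tree_merge({"a": 0}, {"a": 1}, lambda x, y: y))      # {'a': 1}
# Overlapping leaves without a merge_fn raise a ValueError.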
@@ -25,12 +25,12 @@ void init_fast(nb::module_& parent_module) {
       "rms_norm",
       &mx::fast::rms_norm,
       "x"_a,
-      "weight"_a,
+      "weight"_a.none(),
      "eps"_a,
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def rms_norm(x: array, weight: array, eps: float, *, stream: Union[None, Stream, Device] = None) -> array"),
+          "def rms_norm(x: array, weight: Optional[array], eps: float, *, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Root Mean Square normalization (RMS norm).

@@ -38,9 +38,9 @@ void init_fast(nb::module_& parent_module) {

         Args:
           x (array): Input array.
-          weight (array): A multiplicative weight to scale the result by.
+          weight (array, optional): A multiplicative weight to scale the result by.
             The ``weight`` should be one-dimensional with the same size
-            as the last axis of ``x``.
+            as the last axis of ``x``. If set to ``None`` then no scaling happens.
           eps (float): A small additive constant for numerical stability.

         Returns:
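As a reference for what the new ``None`` case computes, here is an unweighted RMS norm in plain mlx ops (a sketch of the definition; mx.fast.rms_norm is the fused kernel and may differ in accumulation precision):

import mlx.core as mx

def rms_norm_ref(x, weight, eps):
    # Normalize by the root mean square over the last axis.
    out = x * mx.rsqrt(mx.mean(mx.square(x), axis=-1, keepdims=True) + eps)
    return out if weight is None else out * weight

x = mx.random.uniform(shape=(2, 8))
print(mx.allclose(rms_norm_ref(x, None, 1e-5), mx.fast.rms_norm(x, None, 1e-5), atol=1e-5))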
@@ -92,6 +92,7 @@ void init_linalg(nb::module_& parent_module) {
         =====  ============================  ==========================
         None   Frobenius norm                2-norm
         'fro'  Frobenius norm                --
+        'nuc'  nuclear norm                  --
         inf    max(sum(abs(x), axis=1))      max(abs(x))
         -inf   min(sum(abs(x), axis=1))      min(abs(x))
         0      --                            sum(x != 0)

@@ -102,9 +103,6 @@ void init_linalg(nb::module_& parent_module) {
         other  --                            sum(abs(x)**ord)**(1./ord)
         =====  ============================  ==========================

-        .. warning::
-          Nuclear norm and norms based on singular values are not yet implemented.
-
         The Frobenius norm is given by [1]_:

           :math:`||A||_F = [\sum_{i,j} abs(a_{i,j})^2]^{1/2}`
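With the warning removed, singular-value-based matrix norms become available; per the tests below they run on the CPU stream. A usage sketch:

import mlx.core as mx

A = mx.array([[1.0, 2.0], [3.0, 4.0]])
print(mx.linalg.norm(A, ord="fro"))                 # Frobenius norm
print(mx.linalg.norm(A, ord="nuc", stream=mx.cpu))  # nuclear norm: sum of singular values
print(mx.linalg.norm(A, ord=2, stream=mx.cpu))      # largest singular value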
@@ -206,15 +204,22 @@ void init_linalg(nb::module_& parent_module) {
       )pbdoc");
   m.def(
       "svd",
-      [](const mx::array& a, mx::StreamOrDevice s /* = {} */) {
-        const auto result = mx::linalg::svd(a, s);
-        return nb::make_tuple(result.at(0), result.at(1), result.at(2));
+      [](const mx::array& a,
+         bool compute_uv /* = true */,
+         mx::StreamOrDevice s /* = {} */) -> nb::object {
+        const auto result = mx::linalg::svd(a, compute_uv, s);
+        if (result.size() == 1) {
+          return nb::cast(result.at(0));
+        } else {
+          return nb::make_tuple(result.at(0), result.at(1), result.at(2));
+        }
       },
       "a"_a,
+      "compute_uv"_a = true,
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def svd(a: array, *, stream: Union[None, Stream, Device] = None) -> Tuple[array, array, array]"),
+          "def svd(a: array, compute_uv: bool = True, *, stream: Union[None, Stream, Device] = None) -> Tuple[array, array, array]"),
       R"pbdoc(
         The Singular Value Decomposition (SVD) of the input matrix.

@@ -224,12 +229,15 @@ void init_linalg(nb::module_& parent_module) {

         Args:
           a (array): Input array.
+          compute_uv (bool, optional): If ``True``, return the ``U``, ``S``, and ``Vt`` components.
+            If ``False``, return only the ``S`` array. Default: ``True``.
           stream (Stream, optional): Stream or device. Defaults to ``None``
             in which case the default stream of the default device is used.

         Returns:
-            tuple(array, array, array): The ``U``, ``S``, and ``Vt`` matrices, such that
-              ``A = U @ diag(S) @ Vt``
+            Union[tuple(array, ...), array]:
+              If compute_uv is ``True`` returns the ``U``, ``S``, and ``Vt`` matrices, such that
+              ``A = U @ diag(S) @ Vt``. If compute_uv is ``False`` returns the singular values array ``S``.
     )pbdoc");
   m.def(
       "inv",
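Usage after this change, per the updated tests (SVD runs on the CPU stream):

import mlx.core as mx

A = mx.array([[1.0, 2.0], [3.0, 4.0]])
U, S, Vt = mx.linalg.svd(A, stream=mx.cpu)                  # full decomposition; compute_uv defaults to True
S_only = mx.linalg.svd(A, compute_uv=False, stream=mx.cpu)  # just the singular values
print(mx.allclose(U @ mx.diag(S) @ Vt, A, atol=1e-5))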
@@ -298,6 +298,9 @@ class TestFast(mlx_tests.MLXTestCase):
             rx = rms_norm(x, weight, eps)
             rx_fast = mx.fast.rms_norm(x, weight, eps)
             self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
+            rx = rms_norm(x, mx.ones_like(weight), eps)
+            rx_fast = mx.fast.rms_norm(x, None, eps)
+            self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])

         for eps in epss:
             dtype, _, dims = defaults

@@ -306,6 +309,9 @@ class TestFast(mlx_tests.MLXTestCase):
             rx = rms_norm(x, weight, eps)
             rx_fast = mx.fast.rms_norm(x, weight, eps)
             self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
+            rx = rms_norm(x, mx.ones_like(weight), eps)
+            rx_fast = mx.fast.rms_norm(x, None, eps)
+            self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])

         for dims in dimss:
             dtype, eps, _ = defaults

@@ -314,6 +320,9 @@ class TestFast(mlx_tests.MLXTestCase):
             rx = rms_norm(x, weight, eps)
             rx_fast = mx.fast.rms_norm(x, weight, eps)
             self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
+            rx = rms_norm(x, mx.ones_like(weight), eps)
+            rx_fast = mx.fast.rms_norm(x, None, eps)
+            self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])

         # Test > 4096
         dims, dtype, eps = 4099, mx.float32, 1e-5

@@ -333,6 +342,8 @@ class TestFast(mlx_tests.MLXTestCase):
         eps = 1e-5
         f1 = lambda x, w, y: (rms_norm(x, w, eps) * y).sum()
         f2 = lambda x, w, y: (mx.fast.rms_norm(x, w, eps) * y).sum()
+        f3 = lambda x, y: (rms_norm(x, mx.ones((x.shape[-1],)), eps) * y).sum()
+        f4 = lambda x, y: (mx.fast.rms_norm(x, None, eps) * y).sum()

         x = mx.random.uniform(shape=(8, 100, D))
         w = mx.random.uniform(shape=(D,))

@@ -341,6 +352,9 @@ class TestFast(mlx_tests.MLXTestCase):
         gx2, gw2 = mx.grad(f2, argnums=(0, 1))(x, w, y)
         self.assertLess(mx.abs(gx1 - gx2).max(), 1e-5)
         self.assertLess(mx.abs(gw1 - gw2).max() / mx.abs(gw1).mean(), 1e-5)
+        gx1 = mx.grad(f3, argnums=(0,))(x, y)
+        gx2 = mx.grad(f4, argnums=(0,))(x, y)
+        self.assertLess(mx.abs(gx1 - gx2).max(), 1e-5)

         D = 8192
         x = mx.random.uniform(shape=(2, 2, D))

@@ -350,6 +364,9 @@ class TestFast(mlx_tests.MLXTestCase):
         gx2, gw2 = mx.grad(f2, argnums=(0, 1))(x, w, y)
         self.assertLess(mx.abs(gx1 - gx2).max(), 1e-5)
         self.assertLess(mx.abs(gw1 - gw2).max() / mx.abs(gw1).mean(), 1e-5)
+        gx1 = mx.grad(f3, argnums=(0,))(x, y)
+        gx2 = mx.grad(f4, argnums=(0,))(x, y)
+        self.assertLess(mx.abs(gx1 - gx2).max(), 1e-5)

         def gf(f):
             def inner(x, w, y):
@@ -262,6 +262,61 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
             )
             self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))

+    def test_fast_sdpa_few_query(self):
+        D = 64
+        L = 43
+        Lq = 4
+        Nq = 8
+        Nkv = 1
+        scale = 1.0
+        mx.random.seed(0)
+        q = 5e-1 * mx.random.normal(shape=(1, Lq, Nq, D))
+        q = q.swapaxes(1, 2)
+        k = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D))
+        v = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D))
+
+        masks = [
+            mx.array(True),
+            mx.array([True] * (L - 10) + [False] * 10),
+            mx.random.uniform(shape=(Nq, 1, L)) > 0.2,
+            mx.random.uniform(shape=(L, 1, Nq)).T > 0.2,
+        ]
+        for m in masks:
+            ref = mlx_primitives_sdpa(q, k, v, scale, mask=m)
+            out = mx.fast.scaled_dot_product_attention(
+                q,
+                k,
+                v,
+                scale=scale,
+                mask=m,
+            )
+            self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
+
+        return
+        L = 4096
+        scale = 1.0
+        mx.random.seed(0)
+        q = 5e-1 * mx.random.normal(shape=(1, Nq, Lq, D))
+        k = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D))
+        v = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D))
+
+        masks = [
+            mx.array(True),
+            mx.array([True] * (L - 10) + [False] * 10),
+            mx.random.uniform(shape=(Nq, 1, L)) > 0.2,
+            mx.random.uniform(shape=(L, 1, Nq)).T > 0.2,
+        ]
+        for m in masks:
+            ref = mlx_primitives_sdpa(q, k, v, scale, mask=m)
+            out = mx.fast.scaled_dot_product_attention(
+                q,
+                k,
+                v,
+                scale=scale,
+                mask=m,
+            )
+            self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
+
     @unittest.skip("Different head and value dims is not enabled")
     def test_fast_sdpa_vector_value_dims(self):
         D = 192
@@ -12,11 +12,11 @@ import numpy as np
 class TestLinalg(mlx_tests.MLXTestCase):
     def test_norm(self):
         vector_ords = [None, 0.5, 0, 1, 2, 3, -1, float("inf"), -float("inf")]
-        matrix_ords = [None, "fro", -1, 1, float("inf"), -float("inf")]
+        matrix_ords = [None, "fro", "nuc", -1, 1, -2, 2, float("inf"), -float("inf")]

         for shape in [(3,), (2, 3), (2, 3, 3)]:
-            x_mx = mx.arange(1, math.prod(shape) + 1).reshape(shape)
-            x_np = np.arange(1, math.prod(shape) + 1).reshape(shape)
+            x_mx = mx.arange(1, math.prod(shape) + 1, dtype=mx.float32).reshape(shape)
+            x_np = np.arange(1, math.prod(shape) + 1, dtype=np.float32).reshape(shape)
             # Test when at least one axis is provided
             for num_axes in range(1, len(shape)):
                 if num_axes == 1:

@@ -26,11 +26,14 @@ class TestLinalg(mlx_tests.MLXTestCase):
                 for axis in itertools.combinations(range(len(shape)), num_axes):
                     for keepdims in [True, False]:
                         for o in ords:
+                            stream = (
+                                mx.cpu if o in ["nuc", -2, 2] else mx.default_device()
+                            )
                             out_np = np.linalg.norm(
                                 x_np, ord=o, axis=axis, keepdims=keepdims
                             )
                             out_mx = mx.linalg.norm(
-                                x_mx, ord=o, axis=axis, keepdims=keepdims
+                                x_mx, ord=o, axis=axis, keepdims=keepdims, stream=stream
                             )
                             with self.subTest(
                                 shape=shape, ord=o, axis=axis, keepdims=keepdims
@@ -133,20 +136,38 @@ class TestLinalg(mlx_tests.MLXTestCase):

     def test_svd_decomposition(self):
         A = mx.array([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]], dtype=mx.float32)
-        U, S, Vt = mx.linalg.svd(A, stream=mx.cpu)
+        U, S, Vt = mx.linalg.svd(A, compute_uv=True, stream=mx.cpu)
         self.assertTrue(
             mx.allclose(U[:, : len(S)] @ mx.diag(S) @ Vt, A, rtol=1e-5, atol=1e-7)
         )

+        S = mx.linalg.svd(A, compute_uv=False, stream=mx.cpu)
+        self.assertTrue(
+            mx.allclose(
+                mx.linalg.norm(S), mx.linalg.norm(A, ord="fro"), rtol=1e-5, atol=1e-7
+            )
+        )
+
         # Multiple matrices
         B = A + 10.0
         AB = mx.stack([A, B])
-        Us, Ss, Vts = mx.linalg.svd(AB, stream=mx.cpu)
+        Us, Ss, Vts = mx.linalg.svd(AB, compute_uv=True, stream=mx.cpu)
         for M, U, S, Vt in zip([A, B], Us, Ss, Vts):
             self.assertTrue(
                 mx.allclose(U[:, : len(S)] @ mx.diag(S) @ Vt, M, rtol=1e-5, atol=1e-7)
             )

+        Ss = mx.linalg.svd(AB, compute_uv=False, stream=mx.cpu)
+        for M, S in zip([A, B], Ss):
+            self.assertTrue(
+                mx.allclose(
+                    mx.linalg.norm(S),
+                    mx.linalg.norm(M, ord="fro"),
+                    rtol=1e-5,
+                    atol=1e-7,
+                )
+            )
+
     def test_inverse(self):
         A = mx.array([[1, 2, 3], [6, -5, 4], [-9, 8, 7]], dtype=mx.float32)
         A_inv = mx.linalg.inv(A, stream=mx.cpu)
@@ -1894,6 +1894,22 @@ class TestOps(mlx_tests.MLXTestCase):
         expected = mx.repeat(expected[:, None], 2, axis=1)
         self.assertTrue(mx.array_equal(expected, out))

+        # Test donation
+        def fn(its):
+            x = mx.ones((32,))
+            for _ in range(its):
+                x = mx.cumsum(x)
+            return x
+
+        mx.synchronize(mx.default_stream(mx.default_device()))
+        mx.eval(fn(2))
+        mx.synchronize(mx.default_stream(mx.default_device()))
+        mem2 = mx.metal.get_peak_memory()
+        mx.eval(fn(4))
+        mx.synchronize(mx.default_stream(mx.default_device()))
+        mem4 = mx.metal.get_peak_memory()
+        self.assertEqual(mem2, mem4)
+
     def test_squeeze_expand(self):
         a = mx.zeros((2, 1, 2, 1))
         self.assertEqual(mx.squeeze(a).shape, (2, 2))
@@ -2846,6 +2862,11 @@ class TestOps(mlx_tests.MLXTestCase):
         b[::2] = 0
         self.assertTrue(mx.array_equal(b, mx.array([0, 3, 0, 1])))

+    def test_slice_with_negative_stride(self):
+        a = mx.random.uniform(shape=(128, 4))
+        out = a[::-1]
+        self.assertTrue(mx.array_equal(out[-1, :], a[0, :]))
+

 if __name__ == "__main__":
     unittest.main()
@@ -39,6 +39,7 @@ def tree_equal(fn, *args):


 optimizers_dict = get_all_optimizers()
+del optimizers_dict["MultiOptimizer"]


 class TestOptimizers(mlx_tests.MLXTestCase):

@@ -500,6 +501,30 @@ class TestSchedulers(unittest.TestCase):
         grads = model.trainable_parameters()
         optimizer.update(model, grads)

+    def test_multi_optimizer(self):
+        class Model(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.l1 = nn.Linear(2, 2)
+                self.drop = nn.Dropout(p=0.5)
+                self.l2 = nn.Linear(2, 2)
+                self.vals = [nn.Linear(2, 2), nn.ReLU(), nn.ReLU()]
+
+        model = Model()
+        optimizer = opt.MultiOptimizer(
+            [opt.Adam(learning_rate=0.001), opt.SGD(learning_rate=0.1)],
+            [lambda name, weight: weight.ndim > 1],
+        )
+        optimizer.init(model.trainable_parameters())
+
+        self.assertEqual(len(optimizer.state["states"]), 2)
+
+        adam_states = tree_flatten(optimizer.state["states"][0])
+        sgd_states = tree_flatten(optimizer.state["states"][1])
+        self.assertEqual((len(sgd_states) - 2) * 2, len(adam_states) - 2)
+        self.assertFalse(any("bias" in k for k, v in adam_states))
+        self.assertFalse(any("weight" in k for k, v in sgd_states))
+

 if __name__ == "__main__":
     unittest.main()
@@ -3,6 +3,7 @@
 import unittest

 import mlx.core as mx
+import mlx.nn as nn
 import mlx.utils
 import mlx_tests

@@ -22,6 +23,29 @@ class TestTreeUtils(mlx_tests.MLXTestCase):
         self.assertEqual(list(zip(*flat_tree))[1], vals)
         self.assertEqual(mlx.utils.tree_unflatten(flat_tree), tree)

+    def test_merge(self):
+        t1 = {"a": 0}
+        t2 = {"b": 1}
+        t = mlx.utils.tree_merge(t1, t2)
+        self.assertEqual({"a": 0, "b": 1}, t)
+        with self.assertRaises(ValueError):
+            mlx.utils.tree_merge(t1, t1)
+        with self.assertRaises(ValueError):
+            mlx.utils.tree_merge(t, t1)
+
+        mod1 = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
+        mod2 = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
+        mod = nn.Sequential(mod1, mod2)
+
+        params1 = {"layers": [mod1.parameters()]}
+        params2 = {"layers": [None, mod2.parameters()]}
+        params = mlx.utils.tree_merge(params1, params2)
+        for (k1, v1), (k2, v2) in zip(
+            mlx.utils.tree_flatten(params), mlx.utils.tree_flatten(mod.parameters())
+        ):
+            self.assertEqual(k1, k2)
+            self.assertTrue(mx.array_equal(v1, v2))
+

 if __name__ == "__main__":
     unittest.main()
@@ -316,35 +316,59 @@ class TestVmap(mlx_tests.MLXTestCase):
     def test_vmap_svd(self):
         a = mx.random.uniform(shape=(3, 4, 2))

-        cpu_svd = lambda x: mx.linalg.svd(x, stream=mx.cpu)
+        cpu_svd_full = lambda x: mx.linalg.svd(x, compute_uv=True, stream=mx.cpu)
+        cpu_svd_singular = lambda x: mx.linalg.svd(x, compute_uv=False, stream=mx.cpu)

         # Vmap over the first axis (this is already supported natively by the primitive).
-        Us, Ss, Vts = mx.vmap(cpu_svd, in_axes=(0,))(a)
+        Us, Ss, Vts = mx.vmap(cpu_svd_full, in_axes=(0,))(a)
         self.assertEqual(Us.shape, (a.shape[0], a.shape[1], a.shape[1]))
         self.assertEqual(Ss.shape, (a.shape[0], a.shape[2]))
         self.assertEqual(Vts.shape, (a.shape[0], a.shape[2], a.shape[2]))

+        Sv = mx.vmap(cpu_svd_singular, in_axes=(0,))(a)
+        self.assertEqual(Sv.shape, (a.shape[0], a.shape[2]))
+
         for i in range(a.shape[0]):
             M = a[i]
             U, S, Vt = Us[i], Ss[i], Vts[i]
             self.assertTrue(
                 mx.allclose(U[:, : len(S)] @ mx.diag(S) @ Vt, M, rtol=1e-5, atol=1e-7)
             )
+            self.assertTrue(
+                mx.allclose(
+                    mx.linalg.norm(Sv[i]),
+                    mx.linalg.norm(M, ord="fro"),
+                    rtol=1e-5,
+                    atol=1e-7,
+                )
+            )

         # Vmap over the second axis.
-        Us, Ss, Vts = mx.vmap(cpu_svd, in_axes=(1,))(a)
+        Us, Ss, Vts = mx.vmap(cpu_svd_full, in_axes=(1,))(a)
         self.assertEqual(Us.shape, (a.shape[1], a.shape[0], a.shape[0]))
         self.assertEqual(Ss.shape, (a.shape[1], a.shape[2]))
         self.assertEqual(Vts.shape, (a.shape[1], a.shape[2], a.shape[2]))

+        Sv = mx.vmap(cpu_svd_singular, in_axes=(1,))(a)
+        self.assertEqual(Sv.shape, (a.shape[1], a.shape[2]))
+
         for i in range(a.shape[1]):
             M = a[:, i, :]
             U, S, Vt = Us[i], Ss[i], Vts[i]
             self.assertTrue(
                 mx.allclose(U[:, : len(S)] @ mx.diag(S) @ Vt, M, rtol=1e-5, atol=1e-7)
             )
+            self.assertTrue(
+                mx.allclose(
+                    mx.linalg.norm(Sv[i]),
+                    mx.linalg.norm(M, ord="fro"),
+                    rtol=1e-5,
+                    atol=1e-7,
+                )
+            )

     def test_vmap_inverse(self):
         mx.random.seed(42)
         a = mx.random.uniform(shape=(3, 4, 4))

         cpu_inv = lambda x: mx.linalg.inv(x, stream=mx.cpu)