Merge branch 'ml-explore:main' into adding-Muon-optimizer

This commit is contained in:
Gökdeniz Gülmez
2025-03-05 10:05:10 +01:00
committed by GitHub
39 changed files with 989 additions and 280 deletions

View File

@@ -14,6 +14,8 @@ import time
from collections import Counter
from dataclasses import dataclass
from pathlib import Path
from queue import Empty as QueueEmpty
from queue import Queue
from select import select
from subprocess import PIPE, Popen, run
from typing import Optional
@@ -185,46 +187,54 @@ def parse_hostlist(parser, hostlist, repeats):
def make_monitor_script(rank, hostfile, cwd, env, command, verbose):
    """Build the source of the Python monitor script that launches ``command``.

    The generated script records its PID in a temporary file (and prints that
    file's path), optionally changes the working directory, assembles the
    environment for the child and finally replaces itself with ``command``
    through ``os.execve``.

    Args:
        rank: Rank of the node; exported as ``MLX_RANK`` when a hostfile is
            given.
        hostfile (str): Contents to write into a temporary hostfile whose path
            is exported as ``MLX_HOSTFILE``. An empty string skips the ring
            backend variables altogether.
        cwd (str or None): Directory to change to. When ``None`` the current
            directory is tried on a best-effort basis; when given, failing to
            find it makes the generated script exit with status 1.
        env (iterable of str): ``KEY=VALUE`` entries to add to the
            environment. Entries whose key is not alphanumeric/underscore are
            reported through ``log_warning`` and ignored.
        command (list of str): The argv to execute; ``command[0]`` is the
            executable handed to ``os.execve``.
        verbose (bool): When true (and a hostfile is given) exports
            ``MLX_RING_VERBOSE=1``.

    Returns:
        str: The Python source code of the monitor script.
    """
    parts = []

    # Imports that are used throughout
    parts.append("import os\n")
    parts.append("import sys\n")
    parts.append("import tempfile\n")
    parts.append("from pathlib import Path\n")

    # Write the PID to a file so we can kill the process if needed
    parts.append("_, pidfile = tempfile.mkstemp() \n")
    parts.append("open(pidfile, 'w').write(str(os.getpid()))\n")
    parts.append("print(pidfile, flush=True)\n")

    # Change the working directory if one was requested. Otherwise attempt to
    # change to the current one but don't fail if it wasn't possible.
    d = cwd or os.getcwd()
    parts.append(f"if Path({repr(d)}).exists():\n")
    parts.append(f" os.chdir({repr(d)})\n")
    if cwd is not None:
        # Only an explicitly requested directory is a hard requirement.
        parts.append("else:\n")
        parts.append(
            f" print('Failed to change directory to', {repr(d)}, file=sys.stderr)\n"
        )
        parts.append(" sys.exit(1)\n")

    # Add the environment variables that were given to us
    parts.append("env = dict(os.environ)\n")
    for e in env:
        key, *value = e.split("=", maxsplit=1)
        if not all(c.isalnum() or c == "_" for c in key):
            log_warning(f"'{e}' is an invalid environment variable so it is ignored")
            continue
        # The value is embedded as a Python literal via repr(), so it must not
        # be shell-quoted: the quotes would become part of the value itself.
        value = value[0] if value else ""
        parts.append(f"env[{repr(key)}] = {repr(value)}\n")

    # Add the environment variables to enable the ring distributed backend
    if hostfile != "":
        parts.append("_, hostfile = tempfile.mkstemp()\n")
        parts.append("with open(hostfile, 'w') as f:\n")
        parts.append(f" f.write({repr(hostfile)})\n")
        if verbose:
            parts.append("env['MLX_RING_VERBOSE'] = '1'\n")
        parts.append("env['MLX_HOSTFILE'] = hostfile\n")
        parts.append(f"env['MLX_RANK'] = '{rank}'\n")
        parts.append("\n")

    # Replace the process with the script
    parts.append(f"command = [{','.join(map(repr, command))}]\n")
    parts.append("os.execve(command[0], command, env)\n")

    return "".join(parts)
@@ -233,28 +243,37 @@ def launch_ring(parser, hosts, args, command):
stop = False
exit_codes = [None] * len(hosts)
def node_thread(rank, host, hostfile):
def node_thread(rank, host, hostfile, input_queue):
is_local = host == "127.0.0.1"
script = make_monitor_script(
rank, hostfile, args.cwd, args.env, command, args.verbose
)
script_b64 = base64.b64encode(script.encode()).decode()
cmd = f'echo "{script_b64}" | base64 -d | /bin/bash'
cmd = f'{sys.executable} -c "import base64; exec(base64.b64decode(\\"{script_b64}\\"));"'
if not is_local:
cmd = f"ssh {host} '{cmd}'"
p = Popen(
cmd,
shell=True,
stdin=PIPE,
stdout=PIPE,
stderr=PIPE,
)
os.set_blocking(p.stdout.fileno(), False)
os.set_blocking(p.stderr.fileno(), False)
os.set_blocking(p.stdin.fileno(), False)
# Repeat the stdout and stderr to the local machine
to_read = [p.stdout.fileno(), p.stderr.fileno()]
to_write = [p.stdin.fileno()]
pidfile = ""
stdin_buffer = b""
while p.poll() is None:
rlist, _, _ = select([p.stdout.fileno(), p.stderr.fileno()], [], [], 1.0)
try:
stdin_buffer += input_queue.get_nowait()
except QueueEmpty:
pass
rlist, wlist, _ = select(to_read, to_write, [], 1.0)
for fd in rlist:
is_stdout = fd == p.stdout.fileno()
outfile = sys.stdout if is_stdout else sys.stderr
@@ -266,6 +285,11 @@ def launch_ring(parser, hosts, args, command):
msg = msg[0] if msg else ""
outfile.write(msg)
outfile.flush()
for fd in wlist:
if len(stdin_buffer) > 0:
n = os.write(fd, stdin_buffer)
stdin_buffer = stdin_buffer[n:]
if stop:
p.terminate()
break
@@ -310,16 +334,25 @@ def launch_ring(parser, hosts, args, command):
log(args.verbose, "Running", shlex.join(command))
input_queues = []
threads = []
for i, h in enumerate(hosts):
if i + 1 == len(hosts):
time.sleep(1.0)
t = threading.Thread(target=node_thread, args=(i, h.ssh_hostname, hostfile))
input_queues.append(Queue())
t = threading.Thread(
target=node_thread, args=(i, h.ssh_hostname, hostfile, input_queues[-1])
)
t.start()
threads.append(t)
os.set_blocking(sys.stdin.fileno(), False)
while not stop:
time.sleep(1.0)
rlist, _, _ = select([sys.stdin.fileno()], [], [], 1.0)
for fd in rlist:
stdin_buffer = os.read(fd, 8192)
for q in input_queues:
q.put(stdin_buffer)
if any(t.is_alive() for t in threads):
for i, t in enumerate(threads):
if not t.is_alive():
@@ -730,6 +763,8 @@ def main():
if len(rest) == 0:
parser.error("No script is provided")
if rest[0] == "--":
rest.pop(0)
# Try to extract a list of hosts and corresponding ips
if args.hostfile is not None:

View File

@@ -5,7 +5,7 @@ from typing import Callable, List, Optional, Tuple, Union
import mlx.core as mx
from mlx.nn import Module
from mlx.utils import tree_map, tree_reduce
from mlx.utils import tree_flatten, tree_map, tree_merge, tree_reduce, tree_unflatten
class Optimizer:
@@ -154,6 +154,79 @@ class Optimizer:
self.state[name] = param
class MultiOptimizer(Optimizer):
    """Wraps a list of optimizers with corresponding weight predicates/filters
    to make it easy to use different optimizers for different weights.

    The predicates take the full "path" of the weight and the weight itself and
    return True if it should be considered for this optimizer. The last
    optimizer in the list is a fallback optimizer and no predicate should be
    given for it.

    Args:
        optimizers (list[Optimizer]): A list of optimizers to delegate to
        filters (list[Callable[[str, array], bool]]): A list of predicates that
            should be one less than the provided optimizers.
    """

    def __init__(self, optimizers, filters: Optional[list] = None):
        super().__init__()
        self._state = {}

        # Avoid a shared mutable default and accept any sequence of predicates.
        filters = [] if filters is None else list(filters)
        if len(filters) != len(optimizers) - 1:
            raise ValueError(
                f"Given {len(filters)} filters but {len(optimizers)-1} needed."
            )

        self.optimizers = optimizers
        # The always-true predicate routes every remaining weight to the last
        # (fallback) optimizer.
        self.filters = filters + [lambda *args, **kwargs: True]

    def _split_dictionary(self, gradients: dict):
        """Split a tree into one sub-tree per optimizer.

        Each flattened ``(path, value)`` entry goes to the first optimizer
        whose predicate accepts it.
        """
        if len(self.optimizers) == 1:
            return [gradients]

        parts = [[] for _ in range(len(self.optimizers))]
        flat_gradients = tree_flatten(gradients)
        for k, g in flat_gradients:
            for i, fn in enumerate(self.filters):
                if fn(k, g):
                    parts[i].append((k, g))
                    break

        return [tree_unflatten(p) for p in parts]

    def init(self, parameters: dict):
        """Initialize every wrapped optimizer with its share of ``parameters``."""
        for o, p in zip(self.optimizers, self._split_dictionary(parameters)):
            o.init(p)

    def apply_gradients(self, gradients: dict, parameters: dict):
        """Apply each optimizer to its share of ``gradients`` and merge the
        results into a single tree.

        Every optimizer receives the full ``parameters`` tree together with
        only the gradients routed to it.
        """
        tree = {}
        for o, g in zip(self.optimizers, self._split_dictionary(gradients)):
            tree = tree_merge(tree, o.apply_gradients(g, parameters))
        return tree

    @property
    def state(self):
        """The states of all wrapped optimizers under the ``"states"`` key."""
        return {"states": [o.state for o in self.optimizers]}

    @state.setter
    def state(self, state: dict):
        if "states" not in state or len(state["states"]) != len(self.optimizers):
            raise ValueError("Invalid state provided")

        for o, s in zip(self.optimizers, state["states"]):
            o.state = s

    @property
    def learning_rate(self):
        """The learning rate of the first wrapped optimizer."""
        return self.optimizers[0].learning_rate

    @learning_rate.setter
    def learning_rate(self, learning_rate: Union[float, mx.array]):
        # Setting the learning rate applies it to every wrapped optimizer.
        for o in self.optimizers:
            o.learning_rate = learning_rate
class SGD(Optimizer):
r"""The stochastic gradient descent optimizer.

View File

@@ -1,5 +1,6 @@
# Copyright © 2023 Apple Inc.
from collections import defaultdict
from itertools import zip_longest
from typing import Any, Callable, List, Optional, Tuple
@@ -244,3 +245,46 @@ def tree_reduce(fn, tree, initializer=None, is_leaf=None):
return tree if accumulator is None else fn(accumulator, tree)
return accumulator
def tree_merge(tree_a, tree_b, merge_fn=None):
    """Merge two Python trees in one containing the values of both. It can be
    thought of as a deep dict.update method.

    ``None`` and empty ``dict``/``list``/``tuple`` values are treated as
    "absent": when only one side is present it is returned as is, and when
    both sides are absent ``tree_a`` is returned unchanged (no merge is
    attempted).

    Args:
        tree_a (Any): The first Python tree.
        tree_b (Any): The second Python tree.
        merge_fn (callable, optional): A function to merge leaves.

    Returns:
        The Python tree containing the values of both ``tree_a`` and
        ``tree_b``.

    Raises:
        ValueError: If both trees hold a value at the same location and no
            ``merge_fn`` was provided.
    """

    def _is_absent(tree):
        # None and empty containers carry no values to merge.
        return tree is None or (
            isinstance(tree, (dict, list, tuple)) and len(tree) == 0
        )

    a_absent = _is_absent(tree_a)
    b_absent = _is_absent(tree_b)
    if a_absent and b_absent:
        # Nothing on either side; merging two empty trees is not a conflict.
        return tree_a
    if a_absent:
        return tree_b
    if b_absent:
        return tree_a

    if isinstance(tree_a, (list, tuple)) and isinstance(tree_b, (list, tuple)):
        # The result takes the sequence type of the first tree; the shorter
        # sequence is padded with None so the longer one's tail is kept.
        TreeType = type(tree_a)
        return TreeType(
            tree_merge(a, b, merge_fn) for a, b in zip_longest(tree_a, tree_b)
        )
    elif isinstance(tree_a, dict) and isinstance(tree_b, dict):
        # Keep a deterministic key order: keys of tree_a first, then the keys
        # that only appear in tree_b.
        keys = list(tree_a) + [k for k in tree_b if k not in tree_a]
        return {
            k: tree_merge(tree_a.get(k, None), tree_b.get(k, None), merge_fn)
            for k in keys
        }
    else:
        if merge_fn is None:
            raise ValueError(
                (
                    "Trees contain elements at the same locations but no merge "
                    "function was provided"
                )
            )
        return merge_fn(tree_a, tree_b)