Mirror of https://github.com/ml-explore/mlx.git (synced 2025-12-16 01:49:05 +08:00)
Merge branch 'ml-explore:main' into adding-Muon-optimizer
@@ -270,9 +270,11 @@ def launch_ring(parser, hosts, args, command):
 
         # Repeat the stdout and stderr to the local machine
         to_read = [p.stdout.fileno(), p.stderr.fileno()]
-        to_write = [p.stdin.fileno()]
+        to_write = [p.stdin.fileno(), sys.stdout.fileno(), sys.stderr.fileno()]
         pidfile = ""
         stdin_buffer = b""
+        stdout_buffer = b""
+        stderr_buffer = b""
         while p.poll() is None:
             try:
                 stdin_buffer += input_queue.get_nowait()
@@ -280,8 +282,6 @@ def launch_ring(parser, hosts, args, command):
                 pass
             rlist, wlist, _ = select(to_read, to_write, [], 1.0)
             for fd in rlist:
-                is_stdout = fd == p.stdout.fileno()
-                outfile = sys.stdout if is_stdout else sys.stderr
                 msg = os.read(fd, 8192).decode(errors="ignore")
 
                 # Fetch the PID file first if we haven't already
@@ -289,12 +289,21 @@ def launch_ring(parser, hosts, args, command):
                     pidfile, *msg = msg.split("\n", maxsplit=1)
                     msg = msg[0] if msg else ""
 
-                outfile.write(msg)
-                outfile.flush()
+                is_stdout = fd == p.stdout.fileno()
+                if is_stdout:
+                    stdout_buffer += msg.encode()
+                else:
+                    stderr_buffer += msg.encode()
             for fd in wlist:
-                if len(stdin_buffer) > 0:
+                if fd == p.stdin.fileno() and len(stdin_buffer) > 0:
                     n = os.write(fd, stdin_buffer)
                     stdin_buffer = stdin_buffer[n:]
+                elif fd == sys.stdout.fileno() and len(stdout_buffer) > 0:
+                    n = os.write(fd, stdout_buffer)
+                    stdout_buffer = stdout_buffer[n:]
+                elif fd == sys.stderr.fileno() and len(stderr_buffer) > 0:
+                    n = os.write(fd, stderr_buffer)
+                    stderr_buffer = stderr_buffer[n:]
             if stop:
                 p.terminate()
                 break
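Note on the hunk above: the old loop wrote remote output to the local stdout/stderr synchronously, so a blocked local descriptor could stall the whole ring monitor. The new version buffers the bytes and flushes them only when select() reports the local descriptors writable. A minimal, self-contained sketch of that pattern (POSIX-only; the child command and the single combined buffer are illustrative, not the mlx launcher code):

import os
import subprocess
import sys
from select import select

p = subprocess.Popen(
    [sys.executable, "-c", "print('hello from child')"],
    stdout=subprocess.PIPE,
    stderr=subprocess.PIPE,
)
to_read = [p.stdout.fileno(), p.stderr.fileno()]
out_buffer = b""  # the real code keeps separate stdout/stderr buffers

while p.poll() is None:
    # Wait at most 1s for readable child pipes or a writable local stdout;
    # no single slow descriptor can wedge the loop.
    rlist, wlist, _ = select(to_read, [sys.stdout.fileno()], [], 1.0)
    for fd in rlist:
        out_buffer += os.read(fd, 8192)  # bounded read; select said it's ready
    for fd in wlist:
        if out_buffer:
            n = os.write(fd, out_buffer)  # write only what the fd accepts
            out_buffer = out_buffer[n:]

# Flush anything still buffered once the child has exited.
if out_buffer:
    os.write(sys.stdout.fileno(), out_buffer)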
@@ -53,11 +53,7 @@ class CMakeBuild(build_ext):
         # Set CMAKE_BUILD_PARALLEL_LEVEL to control the parallel build level
         # across all generators.
         if "CMAKE_BUILD_PARALLEL_LEVEL" not in os.environ:
-            # self.parallel is a Python 3 only way to set parallel jobs by hand
-            # using -j in the build_ext call, not supported by pip or PyPA-build.
-            if hasattr(self, "parallel") and self.parallel:
-                # CMake 3.12+ only.
-                build_args += [f"-j{self.parallel}"]
+            build_args += [f"-j{os.cpu_count()}"]
 
         build_temp = Path(self.build_temp) / ext.name
         if not build_temp.exists():
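Note: the simplified branch always asks CMake for os.cpu_count() parallel jobs; setting CMAKE_BUILD_PARALLEL_LEVEL in the environment remains the supported way to cap parallelism, since the branch is skipped entirely. A small sketch of the resulting precedence (assumed behavior, not the full build_ext flow):

import os

def parallel_flags():
    # When CMAKE_BUILD_PARALLEL_LEVEL is set, CMake reads it directly and
    # no -j flag is appended, so the variable wins.
    if "CMAKE_BUILD_PARALLEL_LEVEL" in os.environ:
        return []
    return [f"-j{os.cpu_count()}"]

print(parallel_flags())  # e.g. ['-j8'] on an 8-core machine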
@@ -546,7 +546,7 @@ class GELU(Module):
 
     See :func:`gelu`, :func:`gelu_approx` and :func:`gelu_fast_approx` for the
     functional equivalents and information regarding error bounds.
 
 
     Args:
-        approx ('none' | 'precise' | 'fast'): Which approximation to gelu to use if any.
+        approx ('none' | 'precise' | 'tanh' | 'fast'): Which approximation to gelu to use if any.
@@ -554,20 +554,19 @@ class GELU(Module):
 
     def __init__(self, approx="none"):
         super().__init__()
-
-        if approx == "none":
-            self._act = gelu
-        elif approx == "precise" or approx == "tanh":
-            self._act = gelu_approx
-        elif approx == "fast":
-            self._act = gelu_fast_approx
-        else:
+        self._approx = approx
+        allowed = ["none", "precise", "tanh", "fast"]
+        if approx not in allowed:
             raise ValueError(
-                f"The approximation should be in ['none', 'precise', 'tanh', 'fast'] but '{approx}' was given"
+                f"The approximation should be in {allowed} but '{approx}' was given"
             )
 
     def __call__(self, x):
-        return self._act(x)
+        if self._approx == "none":
+            return gelu(x)
+        elif self._approx in ["precise", "tanh"]:
+            return gelu_approx(x)
+        return gelu_fast_approx(x)
 
 
 @_make_activation_module(tanh)
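Note: the refactor replaces a function object stored on the instance (self._act) with a plain string plus dispatch in __call__, which, among other things, lets the module serialize cleanly. A standalone sketch of the pattern with stand-in activations (TinyGELU is hypothetical; gelu, gelu_approx and gelu_fast_approx are the real mlx functions):

import pickle

class TinyGELU:
    def __init__(self, approx="none"):
        allowed = ["none", "precise", "tanh", "fast"]
        if approx not in allowed:
            raise ValueError(
                f"The approximation should be in {allowed} but '{approx}' was given"
            )
        self._approx = approx  # a plain string: pickles cleanly

    def __call__(self, x):
        # Dispatch at call time instead of storing a function object.
        if self._approx == "none":
            return f"gelu({x})"
        elif self._approx in ["precise", "tanh"]:
            return f"gelu_approx({x})"
        return f"gelu_fast_approx({x})"

m = pickle.loads(pickle.dumps(TinyGELU("tanh")))  # round-trips fine
assert m("x") == "gelu_approx(x)"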
@@ -114,6 +114,12 @@ class Module(dict):
             super(Module, self).__setattr__(key, val)
             self.pop(key, None)
 
+    def __delattr__(self, name):
+        if (val := self.get(name, None)) is not None:
+            del self[name]
+        else:
+            super().__delattr__(name)
+
     def load_weights(
         self,
         file_or_weights: Union[str, List[Tuple[str, mx.array]]],
@@ -174,11 +180,15 @@ class Module(dict):
         new_weights = dict(weights)
         curr_weights = dict(tree_flatten(self.parameters()))
         if extras := (new_weights.keys() - curr_weights.keys()):
-            extras = " ".join(extras)
-            raise ValueError(f"Received parameters not in model: {extras}.")
+            num_extra = len(extras)
+            extras = ",\n".join(sorted(extras))
+            raise ValueError(
+                f"Received {num_extra} parameters not in model: \n{extras}."
+            )
         if missing := (curr_weights.keys() - new_weights.keys()):
-            missing = " ".join(missing)
-            raise ValueError(f"Missing parameters: {missing}.")
+            num_missing = len(missing)
+            missing = ",\n".join(sorted(missing))
+            raise ValueError(f"Missing {num_missing} parameters: \n{missing}.")
         for k, v in curr_weights.items():
             v_new = new_weights[k]
             if not isinstance(v_new, mx.array):
@@ -193,7 +203,7 @@ class Module(dict):
             )
 
         if len(weights) != 0:
-            self.update(tree_unflatten(weights))
+            self.update(tree_unflatten(weights), strict=False)
         return self
 
     def save_weights(self, file: str):
@@ -291,7 +301,7 @@ class Module(dict):
 
         return self.filter_and_map(self.valid_child_filter, is_leaf_fn=_is_leaf_module)
 
-    def update(self, parameters: dict) -> Module:
+    def update(self, parameters: dict, strict: bool = True) -> Module:
         """Replace the parameters of this Module with the provided ones in the
         dict of dicts and lists.
 
@@ -305,7 +315,9 @@ class Module(dict):
 
         Args:
            parameters (dict): A complete or partial dictionary of the modules
-                parameters.
+                parameters.
+            strict (bool): If ``True`` checks that ``parameters`` is a
+                subset of the module's parameters. Default: ``True``.
         Returns:
            The module instance after updating the parameters.
         """
@@ -317,21 +329,29 @@ class Module(dict):
                         current_value = dst[k]
                         new_value = parameters[k]
                         if isinstance(current_value, mx.array):
+                            if strict and not isinstance(new_value, mx.array):
+                                raise ValueError(
+                                    f"Received invalid type: {type(new_value).__name__}."
+                                )
                             dst[k] = new_value
-                        elif isinstance(current_value, Module):
-                            current_value.update(new_value)
-                        elif isinstance(current_value, (dict, list)):
+                        else:
                             apply(current_value, new_value)
+                    elif strict:
+                        raise ValueError(f'Module does not have parameter named "{k}".')
             elif isinstance(parameters, list):
                 for i in range(len(parameters)):
                     current_value = dst[i]
                     new_value = parameters[i]
                     if isinstance(current_value, mx.array):
+                        if strict and not isinstance(new_value, mx.array):
+                            raise ValueError(
+                                f"Received invalid type: {type(new_value).__name__}."
+                            )
                         dst[i] = new_value
-                    elif isinstance(current_value, Module):
-                        current_value.update(new_value)
-                    elif isinstance(current_value, (dict, list)):
+                    else:
                        apply(current_value, new_value)
+            elif strict:
+                raise ValueError(f"Received invalid type: {type(parameters).__name__}.")
 
         apply(self, parameters)
         return self
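Note: update() now takes a strict flag. With the default strict=True it raises on keys the model does not have and on leaves that are not mx.array, instead of writing them silently; load_weights passes strict=False because it has already checked for extra and missing keys itself. A hedged usage sketch (assumes an mlx build that includes this change):

import mlx.core as mx
import mlx.nn as nn

model = nn.Linear(4, 4)

# strict=True (the default): an unknown key now raises.
try:
    model.update({"not_a_param": mx.zeros((4, 4))})
except ValueError as e:
    print(e)  # Module does not have parameter named "not_a_param".

# strict=False: unknown locations are skipped rather than rejected.
model.update({"not_a_param": mx.zeros((4, 4))}, strict=False)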
@@ -359,7 +379,7 @@ class Module(dict):
 
         self.update(self.filter_and_map(filter_fn, map_fn))
         return self
 
-    def update_modules(self, modules: dict) -> Module:
+    def update_modules(self, modules: dict, strict: bool = True) -> Module:
         """Replace the child modules of this :class:`Module` instance with the
         provided ones in the dict of dicts and lists.
@@ -368,12 +388,14 @@ class Module(dict):
         programmatically swapping layers.
 
         The passed in parameters dictionary need not be a full dictionary
-        similar to :meth:`parameters`. Only the provided locations will be
+        similar to :meth:`modules`. Only the provided locations will be
         updated.
 
         Args:
-            modules (dict): A complete or partial dictionary of the modules
+            modules (dict): A complete or partial dictionary of the module's
                 submodules.
+            strict (bool): If ``True`` checks that ``modules`` is a
+                subset of the child modules of this instance. Default: ``True``.
         Returns:
             The module instance after updating the submodules.
         """
@@ -388,14 +410,28 @@ class Module(dict):
                         dst[k] = new_value
                     elif isinstance(current_value, (dict, list)):
                         apply(current_value, new_value)
+                    elif strict and new_value != {}:
+                        raise ValueError(
+                            f"Received invalid type: {type(new_value).__name__}."
+                        )
+                elif strict:
+                    raise ValueError(
+                        f'Module does not have sub-module named "{k}".'
+                    )
             elif isinstance(modules, list):
-                for i in range(len(dst)):
+                for i in range(len(modules)):
                     current_value = dst[i]
                     new_value = modules[i]
                     if self.is_module(current_value) and self.is_module(new_value):
                         dst[i] = new_value
                     elif isinstance(current_value, (dict, list)):
                         apply(current_value, new_value)
+                    elif strict and new_value != {}:
+                        raise ValueError(
+                            f"Received invalid type: {type(new_value).__name__}."
+                        )
+            elif strict:
+                raise ValueError(f"Received invalid type: {type(modules).__name__}.")
 
         apply(self, modules)
         return self
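Note: update_modules() gains the same strict flag, and the list branch now iterates over the provided list rather than the destination, so a shorter list updates a prefix of the children instead of indexing past the end of modules. A hedged sketch of a prefix swap (assumes an mlx build with this change; the printed repr is approximate):

import mlx.nn as nn

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))
# A one-element list replaces only layers[0]; the old range(len(dst))
# would have tried to read modules[1] and modules[2] as well.
model.update_modules({"layers": [nn.Linear(8, 8)]})
print(model.layers[0])  # e.g. Linear(input_dims=8, output_dims=8, bias=True)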
@@ -25,6 +25,8 @@ class ConvTranspose1d(Module):
         padding (int, optional): How many positions to 0-pad the input with.
             Default: ``0``.
         dilation (int, optional): The dilation of the convolution.
+        output_padding(int, optional): Additional size added to one side of the
+            output shape. Default: ``0``.
         bias (bool, optional): If ``True`` add a learnable bias to the output.
             Default: ``True``
     """
@@ -37,6 +39,7 @@ class ConvTranspose1d(Module):
         stride: int = 1,
         padding: int = 0,
         dilation: int = 1,
+        output_padding: int = 0,
         bias: bool = True,
     ):
         super().__init__()
@@ -53,18 +56,25 @@ class ConvTranspose1d(Module):
         self.padding = padding
         self.dilation = dilation
         self.stride = stride
+        self.output_padding = output_padding
 
     def _extra_repr(self):
         return (
             f"{self.weight.shape[-1]}, {self.weight.shape[0]}, "
             f"kernel_size={self.weight.shape[1]}, stride={self.stride}, "
             f"padding={self.padding}, dilation={self.dilation}, "
+            f"output_padding={self.output_padding}, "
             f"bias={'bias' in self}"
         )
 
     def __call__(self, x):
         y = mx.conv_transpose1d(
-            x, self.weight, self.stride, self.padding, self.dilation
+            x,
+            self.weight,
+            self.stride,
+            self.padding,
+            self.dilation,
+            self.output_padding,
         )
         if "bias" in self:
             y = y + self.bias
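Note: output_padding resolves the output-size ambiguity of strided transposed convolutions. Assuming the standard size formula L_out = (L_in - 1)*stride - 2*padding + dilation*(kernel_size - 1) + 1 + output_padding, a quick shape check (values computed from that formula; assumes an mlx build with this change):

import mlx.core as mx
import mlx.nn as nn

x = mx.zeros((1, 10, 8))  # NLC layout: (batch, length, channels)
for op in (0, 1):
    layer = nn.ConvTranspose1d(8, 16, kernel_size=3, stride=2, output_padding=op)
    # L_out = (10 - 1) * 2 - 0 + 1 * (3 - 1) + 1 + op = 21 + op
    print(layer(x).shape)  # (1, 21, 16), then (1, 22, 16)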
@@ -90,6 +100,8 @@ class ConvTranspose2d(Module):
         padding (int or tuple, optional): How many positions to 0-pad
             the input with. Default: ``0``.
         dilation (int or tuple, optional): The dilation of the convolution.
+        output_padding(int or tuple, optional): Additional size added to one
+            side of the output shape. Default: ``0``.
         bias (bool, optional): If ``True`` add a learnable bias to the
             output. Default: ``True``
     """
@@ -102,13 +114,14 @@ class ConvTranspose2d(Module):
         stride: Union[int, tuple] = 1,
         padding: Union[int, tuple] = 0,
         dilation: Union[int, tuple] = 1,
+        output_padding: Union[int, tuple] = 0,
         bias: bool = True,
     ):
         super().__init__()
 
-        kernel_size, stride, padding = map(
+        kernel_size, stride, padding, output_padding = map(
             lambda x: (x, x) if isinstance(x, int) else x,
-            (kernel_size, stride, padding),
+            (kernel_size, stride, padding, output_padding),
         )
         scale = math.sqrt(1 / (in_channels * kernel_size[0] * kernel_size[1]))
         self.weight = mx.random.uniform(
@@ -122,18 +135,25 @@ class ConvTranspose2d(Module):
         self.padding = padding
         self.stride = stride
         self.dilation = dilation
+        self.output_padding = output_padding
 
     def _extra_repr(self):
         return (
             f"{self.weight.shape[-1]}, {self.weight.shape[0]}, "
             f"kernel_size={self.weight.shape[1:2]}, stride={self.stride}, "
             f"padding={self.padding}, dilation={self.dilation}, "
+            f"output_padding={self.output_padding}, "
             f"bias={'bias' in self}"
         )
 
     def __call__(self, x):
         y = mx.conv_transpose2d(
-            x, self.weight, self.stride, self.padding, self.dilation
+            x,
+            self.weight,
+            self.stride,
+            self.padding,
+            self.dilation,
+            self.output_padding,
         )
         if "bias" in self:
             y = y + self.bias
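Note: in the 2d (and 3d) variants the only subtlety is that output_padding joins the existing int-to-tuple promotion, so a scalar is replicated per spatial dimension. The promotion shown in isolation for the 2d case:

kernel_size, stride, padding, output_padding = map(
    lambda x: (x, x) if isinstance(x, int) else x,
    (3, 2, 1, 1),
)
print(kernel_size, stride, padding, output_padding)  # (3, 3) (2, 2) (1, 1) (1, 1)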
@@ -160,6 +180,8 @@ class ConvTranspose3d(Module):
         padding (int or tuple, optional): How many positions to 0-pad
             the input with. Default: ``0``.
         dilation (int or tuple, optional): The dilation of the convolution.
+        output_padding(int or tuple, optional): Additional size added to one
+            side of the output shape. Default: ``0``.
         bias (bool, optional): If ``True`` add a learnable bias to the
             output. Default: ``True``
     """
@@ -172,13 +194,14 @@ class ConvTranspose3d(Module):
         stride: Union[int, tuple] = 1,
         padding: Union[int, tuple] = 0,
         dilation: Union[int, tuple] = 1,
+        output_padding: Union[int, tuple] = 0,
         bias: bool = True,
     ):
         super().__init__()
 
-        kernel_size, stride, padding = map(
+        kernel_size, stride, padding, output_padding = map(
             lambda x: (x, x, x) if isinstance(x, int) else x,
-            (kernel_size, stride, padding),
+            (kernel_size, stride, padding, output_padding),
         )
         scale = math.sqrt(
             1 / (in_channels * kernel_size[0] * kernel_size[1] * kernel_size[2])
@@ -194,18 +217,25 @@ class ConvTranspose3d(Module):
         self.padding = padding
         self.stride = stride
         self.dilation = dilation
+        self.output_padding = output_padding
 
     def _extra_repr(self):
         return (
             f"{self.weight.shape[-1]}, {self.weight.shape[0]}, "
             f"kernel_size={self.weight.shape[1:3]}, stride={self.stride}, "
             f"padding={self.padding}, dilation={self.dilation}, "
+            f"output_padding={self.output_padding}, "
             f"bias={'bias' in self}"
         )
 
     def __call__(self, x):
         y = mx.conv_transpose3d(
-            x, self.weight, self.stride, self.padding, self.dilation
+            x,
+            self.weight,
+            self.stride,
+            self.padding,
+            self.dilation,
+            self.output_padding,
         )
         if "bias" in self:
             y = y + self.bias
@@ -193,12 +193,6 @@ class QuantizedLinear(Module):
         # Freeze this model's parameters
         self.freeze()
 
-    def unfreeze(self, *args, **kwargs):
-        """Wrap unfreeze so that we unfreeze any layers we might contain but
-        our parameters will remain frozen."""
-        super().unfreeze(*args, **kwargs)
-        self.freeze(recurse=False)
-
     def _extra_repr(self):
         out_dims, in_dims = self.weight.shape
         in_dims *= 32 // self.bits
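Note: the deleted override made unfreeze() silently re-freeze the layer's own (quantized) parameters. If I read the removal correctly, QuantizedLinear now follows plain Module.unfreeze semantics; a hedged sketch of the observable difference (assumes an mlx build with this change):

import mlx.nn as nn
from mlx.utils import tree_flatten

layer = nn.QuantizedLinear(64, 64)  # __init__ still calls self.freeze()
print(len(tree_flatten(layer.trainable_parameters())))  # 0 while frozen
layer.unfreeze()
print(len(tree_flatten(layer.trainable_parameters())))  # non-zero once unfrozen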
@@ -25,7 +25,16 @@ def _scaled_indices(N, scale, align_corners, dim, ndims):
 
 
 def _nearest_indices(N, scale, dim, ndims):
-    return _scaled_indices(N, scale, True, dim, ndims).astype(mx.uint32)
+    M = int(scale * N)
+    indices = mx.arange(M, dtype=mx.float32)
+    if M > N:
+        indices = (indices + 0.5) * (N / M) - 0.5
+        indices = indices.round()
+    else:
+        indices = indices * (N / M)
+    shape = [1] * ndims
+    shape[dim] = -1
+    return indices.astype(mx.uint32).reshape(shape)
 
 
 def _linear_indices(N, scale, align_corners, dim, ndims):
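Note: the old implementation reused _scaled_indices with align_corners=True, which places samples incorrectly for nearest-neighbor interpolation; the new code distinguishes upsampling (half-pixel centers, then round) from downsampling (plain scaling, truncated by the uint32 cast). A pure-Python rendering of the same arithmetic, for intuition (rounding at exact .5 boundaries may differ slightly from mx.round):

def nearest_indices(N, scale):
    M = int(scale * N)
    out = []
    for i in range(M):
        if M > N:  # upsampling: half-pixel alignment, then round
            out.append(int(round((i + 0.5) * (N / M) - 0.5)))
        else:      # downsampling: plain scaling, floor via int()
            out.append(int(i * (N / M)))
    return out

print(nearest_indices(2, 3.0))  # [0, 0, 0, 1, 1, 1]: each source pixel repeated 3x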
@@ -526,8 +526,10 @@ class Adam(Optimizer):
         state["v"] = v
 
         if bias_correction:
-            numerator = lr / (1 - b1**step) * m
-            denominator = mx.sqrt(v) / mx.sqrt(1 - b2**step) + eps
+            c1 = (lr / (1 - b1**step)).astype(gradient.dtype)
+            c2 = mx.rsqrt(1 - b2**step).astype(gradient.dtype)
+            numerator = c1 * m
+            denominator = mx.sqrt(v) * c2 + eps
             return parameter - numerator / denominator
         else:
             return parameter - lr * m / (mx.sqrt(v) + eps)
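Note: the rewrite computes the two bias-correction factors once as scalars cast to the gradient's dtype, avoiding an implicit float32 upcast of m and v in half-precision training, and replaces a division by mx.sqrt with mx.rsqrt. The algebra is unchanged, since sqrt(v) * rsqrt(1 - b2**step) == sqrt(v / (1 - b2**step)). A hedged numeric check with illustrative values:

import mlx.core as mx

lr, b1, b2, eps, step = 0.001, 0.9, 0.999, 1e-8, 10
m = mx.array([0.02], dtype=mx.float16)  # first-moment estimate
v = mx.array([4e-4], dtype=mx.float16)  # second-moment estimate
p = mx.array([1.0], dtype=mx.float16)   # parameter

# New formulation: corrections computed once as scalars.
c1 = lr / (1 - b1**step)
c2 = mx.rsqrt(mx.array(1 - b2**step))
new = p - (c1 * m) / (mx.sqrt(v) * c2 + eps)

# Textbook bias correction: m_hat = m / (1 - b1**t), v_hat = v / (1 - b2**t).
ref = p - lr * (m / (1 - b1**step)) / (mx.sqrt(v / (1 - b2**step)) + eps)
print(new.item(), ref.item())  # agree to float16 precision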