awni's commit files

Awni Hannun committed 2023-11-29 10:30:41 -08:00
parent e411fcae68
commit 8ca7f9e8e9
130 changed files with 30159 additions and 0 deletions

37
python/README.md Normal file

@@ -0,0 +1,37 @@
### Packaging for PyPI
Install `build` and `twine`:
```
pip install --user --upgrade build
pip install --user --upgrade twine
```
Generate the source distribution and wheel:
```
python -m build
```
*Warning*: use the test server first.
#### Test Upload
Upload to test server:
```
python -m twine upload --repository testpypi dist/*
```
Install from test server and check that it works:
```
python -m pip install --index-url https://test.pypi.org/simple/ --no-deps mlx
```
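A quick way to confirm the test install works (a sketch; any small op will do) is to import the package and evaluate something:
```
import mlx.core as mx

# Hypothetical smoke test after installing from the test index.
a = mx.ones((4, 4))
print(mx.sum(a))  # expect a scalar array equal to 16
```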
#### Upload
```
python -m twine upload dist/*
```


@@ -0,0 +1,18 @@
import array
import reprlib
class FixedRepr(reprlib.Repr):
"""Only route python array instances to repr_array."""
def repr_array(self, x, maxlevel):
if isinstance(x, array.array):
return super().repr_array(x, maxlevel)
else:
return self.repr_instance(x, maxlevel)
# We need to monkey-patch reprlib so that we can use the debugger without
# renaming the array to something else
fixed_repr = FixedRepr()
reprlib.repr = fixed_repr.repr
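For illustration, a minimal sketch of why the patch matters (assuming this module has been imported so the patch is active): ``reprlib`` dispatches on the class name, and MLX arrays are also named ``array``, so without the override they would be routed to ``repr_array``, which expects a builtin ``array.array``.
```
import mlx.core as mx

x = mx.zeros((4, 4))
# With the patch, anything that is not an array.array falls back to the
# generic (truncated) instance repr instead of repr_array.
print(fixed_repr.repr(x))
```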

94
python/mlx/extension.py Normal file

@@ -0,0 +1,94 @@
import os
import re
import subprocess
import sys
from pathlib import Path
from setuptools import Extension, setup, find_namespace_packages
from setuptools.command.build_ext import build_ext
import mlx
_MLX_PATH = str(mlx.__path__[0])
# A CMakeExtension needs a sourcedir instead of a file list.
class CMakeExtension(Extension):
def __init__(self, name: str, sourcedir: str = "") -> None:
super().__init__(name, sources=[])
self.sourcedir = os.fspath(Path(sourcedir).resolve())
class CMakeBuild(build_ext):
def build_extension(self, ext: CMakeExtension) -> None:
# Must be in this form due to bug in .resolve() only fixed in Python 3.10+
ext_fullpath = Path.cwd() / self.get_ext_fullpath(ext.name) # type: ignore[no-untyped-call]
extdir = ext_fullpath.parent.resolve()
debug = int(os.environ.get("DEBUG", 0)) if self.debug is None else self.debug
cfg = "Debug" if debug else "Release"
# CMake lets you override the generator - we need to check this.
# Can be set with Conda-Build, for example.
cmake_generator = os.environ.get("CMAKE_GENERATOR", "")
# Set Python_EXECUTABLE instead if you use PYBIND11_FINDPYTHON
# EXAMPLE_VERSION_INFO shows you how to pass a value into the C++ code
# from Python.
cmake_args = [
f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={extdir}{os.sep}",
f"-DCMAKE_BUILD_TYPE={cfg}",
"-DBUILD_SHARED_LIBS=ON",
]
build_args = []
# Adding CMake arguments set as environment variable
# (needed e.g. to build for ARM OSx on conda-forge)
if "CMAKE_ARGS" in os.environ:
cmake_args += [item for item in os.environ["CMAKE_ARGS"].split(" ") if item]
if sys.platform.startswith("darwin"):
# Cross-compile support for macOS - respect ARCHFLAGS if set
archs = re.findall(r"-arch (\S+)", os.environ.get("ARCHFLAGS", ""))
if archs:
cmake_args += ["-DCMAKE_OSX_ARCHITECTURES={}".format(";".join(archs))]
# Set CMAKE_BUILD_PARALLEL_LEVEL to control the parallel build level
# across all generators.
if "CMAKE_BUILD_PARALLEL_LEVEL" not in os.environ:
# self.parallel is a Python 3 only way to set parallel jobs by hand
# using -j in the build_ext call, not supported by pip or PyPA-build.
if hasattr(self, "parallel") and self.parallel:
# CMake 3.12+ only.
build_args += [f"-j{self.parallel}"]
build_temp = Path(self.build_temp) / ext.name
if not build_temp.exists():
build_temp.mkdir(parents=True)
# Make sure cmake can find MLX
os.environ["MLX_DIR"] = _MLX_PATH
subprocess.run(
["cmake", ext.sourcedir, *cmake_args], cwd=build_temp, check=True
)
subprocess.run(
["cmake", "--build", ".", *build_args], cwd=build_temp, check=True
)
def run(self):
super().run()
# Based on https://github.com/pypa/setuptools/blob/main/setuptools/command/build_ext.py#L102
if self.inplace:
for ext in self.extensions:
if isinstance(ext, CMakeExtension):
# Resolve inplace package dir
build_py = self.get_finalized_command("build_py")
inplace_file, regular_file = self._get_inplace_equivalent(
build_py, ext
)
inplace_dir = str(Path(inplace_file).parent.resolve())
regular_dir = str(Path(regular_file).parent.resolve())
self.copy_tree(regular_dir, inplace_dir)
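For context, a hypothetical ``setup.py`` for an out-of-tree extension built against MLX might wire these classes together as sketched below; the package name and its CMake source tree are assumptions, not part of this commit.
```
# Hypothetical setup.py; "mlx_sample_extensions" is an illustrative name.
from setuptools import setup
from mlx import extension

setup(
    name="mlx_sample_extensions",
    version="0.0.1",
    ext_modules=[extension.CMakeExtension("mlx_sample_extensions._ext")],
    cmdclass={"build_ext": extension.CMakeBuild},
    packages=["mlx_sample_extensions"],
    zip_safe=False,
)
```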


@@ -0,0 +1,401 @@
import textwrap
from typing import Any, Callable, List, Union, Optional
import mlx.core as mx
from mlx.utils import tree_flatten, tree_unflatten
class Module(dict):
"""Base class for building neural networks with MLX.
All the layers provided in :mod:`mlx.nn.layers` subclass this class and
your models should do the same.
A ``Module`` can contain other ``Module`` instances or :class:`mlx.core.array`
instances in arbitrary nesting of python lists or dicts. The ``Module``
then allows recursively extracting all the :class:`mlx.core.array` instances
using :meth:`mlx.nn.Module.parameters`.
In addition, the ``Module`` has the concept of trainable and non-trainable
parameters (called "frozen"). When using :func:`mlx.nn.value_and_grad`
the gradients are returned only with respect to the trainable parameters.
All arrays in a module are trainable unless they are added to the "frozen"
set by calling :meth:`freeze`.
.. code-block:: python
import mlx.core as mx
import mlx.nn as nn
class MyMLP(nn.Module):
def __init__(self, in_dims: int, out_dims: int, hidden_dims: int = 16):
super().__init__()
self.in_proj = nn.Linear(in_dims, hidden_dims)
self.out_proj = nn.Linear(hidden_dims, out_dims)
def __call__(self, x):
x = self.in_proj(x)
x = mx.maximum(x, 0)
return self.out_proj(x)
model = MyMLP(2, 1)
# All the model parameters are created but since MLX is lazy by
# default, they are not evaluated yet. Calling `mx.eval` actually
# allocates memory and initializes the parameters.
mx.eval(model.parameters())
# Setting a parameter to a new value is as simple as accessing that
# parameter and assigning a new array to it.
model.in_proj.weight = model.in_proj.weight * 2
mx.eval(model.parameters())
"""
def __init__(self):
"""Should be called by the subclasses of ``Module``."""
self._no_grad = set()
self._training = True
@property
def training(self):
return self._training
def _extra_repr(self):
return ""
def __repr__(self):
children = tree_flatten(self.children(), is_leaf=self.is_module)
value = f"{type(self).__name__}({self._extra_repr()}"
for k, v in children:
value += "\n"
value += textwrap.indent(f"({k}): {repr(v)}", prefix=" ")
if children:
value += "\n"
value += ")"
return value
def __getattr__(self, key: str):
if key in self:
return self[key]
else:
raise AttributeError(f"{type(self)!r} has no attribute {key!r}")
def __setattr__(self, key: str, val: Any):
self[key] = val
def load_weights(self, file: str):
"""
Load and update the model's weights from a `.npz` file.
"""
self.update(tree_unflatten(list(mx.load(file).items())))
def save_weights(self, file: str):
"""
Save the model's weights to a `.npz` file.
"""
mx.savez(file, **dict(tree_flatten(self.parameters())))
@staticmethod
def is_module(value):
return isinstance(value, Module)
@staticmethod
def valid_child_filter(module, key, value):
return isinstance(value, (dict, list))
@staticmethod
def valid_parameter_filter(module, key, value):
return isinstance(value, (dict, list, mx.array)) and not key.startswith("_")
@staticmethod
def trainable_parameter_filter(module, key, value):
return (
Module.valid_parameter_filter(module, key, value)
and key not in module._no_grad
)
def filter_and_map(
self,
filter_fn: Callable[["mlx.nn.Module", str, Any], bool],
map_fn: Optional[Callable] = None,
is_leaf_fn: Optional[Callable[["mlx.nn.Module", str, Any], bool]] = None,
):
"""Recursively filter the contents of the module using ``filter_fn``,
namely only select keys and values where ``filter_fn`` returns true.
This is used to implement :meth:`parameters` and :meth:`trainable_parameters`
but it can also be used to extract any subset of the module's parameters.
Args:
filter_fn (Callable): Given a value, the key in which it is found
and the containing module, decide whether to keep the value or
drop it.
map_fn (Callable, optional): Optionally transform the value before
returning it.
is_leaf_fn (Callable, optional): Given a value, the key in which it
is found and the containing module decide if it is a leaf.
Returns:
A dictionary containing the contents of the module recursively filtered
"""
map_fn = map_fn or (lambda x: x)
is_leaf_fn = is_leaf_fn or (
lambda m, k, v: not isinstance(v, (Module, dict, list))
)
def unwrap(vk, v):
if is_leaf_fn(self, vk, v):
return map_fn(v)
if isinstance(v, Module):
return v.filter_and_map(filter_fn, map_fn, is_leaf_fn)
if isinstance(v, dict):
nd = {}
for k, v in v.items():
tk = f"{vk}.{k}"
nd[k] = unwrap(tk, v) if filter_fn(self, tk, v) else {}
return nd
if isinstance(v, list):
nl = []
for i, vi in enumerate(v):
tk = f"{vk}.{i}"
nl.append(unwrap(tk, vi) if filter_fn(self, tk, vi) else {})
return nl
raise RuntimeError("Unexpected leaf found while traversing the module")
return {k: unwrap(k, v) for k, v in self.items() if filter_fn(self, k, v)}
def parameters(self):
"""Recursively return all the :class:`mlx.core.array` members of this Module
as a dict of dicts and lists."""
return self.filter_and_map(self.valid_parameter_filter)
def trainable_parameters(self):
"""Recursively return all the non frozen :class:`mlx.core.array` members of
this Module as a dict of dicts and lists."""
return self.filter_and_map(self.trainable_parameter_filter)
def children(self):
"""Return the direct descendants of this Module instance."""
return self.filter_and_map(
self.valid_child_filter, is_leaf_fn=lambda m, k, v: isinstance(v, Module)
)
def leaf_modules(self):
"""Return the submodules that do not contain other modules."""
def _is_leaf_module(m, k, v):
return isinstance(v, Module) and len(tree_flatten(v.children())) == 0
return self.filter_and_map(self.valid_child_filter, is_leaf_fn=_is_leaf_module)
def update(self, parameters: dict):
"""Replace the parameters of this Module with the provided ones in the
dict of dicts and lists.
Commonly used by the optimizer to change the model to the updated
(optimized) parameters. Also used by the :meth:`mlx.nn.value_and_grad` to set the
tracers in the model in order to compute gradients.
The passed-in parameters dictionary need not be a full dictionary
similar to :meth:`parameters`. Only the provided locations will be
updated.
Args:
parameters (dict): A complete or partial dictionary of the module's
parameters.
"""
def apply(dst, parameters):
if isinstance(parameters, dict):
for k in parameters:
if k in dst:
current_value = dst[k]
new_value = parameters[k]
if isinstance(current_value, mx.array):
dst[k] = new_value
elif isinstance(current_value, Module):
current_value.update(new_value)
elif isinstance(current_value, (dict, list)):
apply(current_value, new_value)
elif isinstance(parameters, list):
for i in range(len(dst)):
current_value = dst[i]
new_value = parameters[i]
if isinstance(current_value, mx.array):
dst[i] = new_value
elif isinstance(current_value, Module):
current_value.update(new_value)
elif isinstance(current_value, (dict, list)):
apply(current_value, new_value)
apply(self, parameters)
def apply(
self,
map_fn: Callable[[mx.array], mx.array],
filter_fn: Optional[Callable[["mlx.nn.Module", str, Any], bool]] = None,
):
"""Map all the parameters using the provided ``map_fn`` and immediately
update the module with the mapped parameters.
For instance running ``model.apply(lambda x: x.astype(mx.float16))``
casts all parameters to 16 bit floats.
Args:
map_fn (Callable): Maps an array to another array
filter_fn (Callable, optional): Filter to select which arrays to
map (default: :meth:`Module.valid_parameter_filter`).
"""
filter_fn = filter_fn or Module.valid_parameter_filter
self.update(self.filter_and_map(filter_fn, map_fn))
def apply_to_modules(self, apply_fn: Callable[[str, "mlx.nn.Module"], Any]):
"""Apply a function to all the modules in this instance (including this
instance).
Args:
apply_fn (Callable): The function to apply to the modules.
"""
module_stack = [("", self)]
while module_stack:
prefix, mod = module_stack.pop()
apply_fn(prefix, mod)
prefix = "." + prefix if prefix else ""
module_stack.extend(
tree_flatten(mod.children(), prefix=prefix, is_leaf=self.is_module)
)
def modules(self):
"""Return a list with all the modules in this instance.
Returns:
A list of :class:`mlx.nn.Module` instances.
"""
modulelist = []
self.apply_to_modules(lambda k, m: modulelist.append(m))
return modulelist
def named_modules(self):
"""Return a list with all the modules in this instance and their name
with dot notation.
Returns:
A list of tuples (str, :class:`mlx.nn.Module`).
"""
modulelist = []
self.apply_to_modules(lambda k, m: modulelist.append((k, m)))
return modulelist
def _validate_keys(self, keys, strict):
keys = keys if isinstance(keys, list) else [keys]
if strict:
for k in keys:
if k not in self:
raise KeyError(f"Module doesn't contain member {k}.")
return keys
def freeze(
self,
*,
recurse: bool = True,
keys: Optional[Union[str, List[str]]] = None,
strict: bool = False,
):
"""Freeze the Module's parameters or some of them. Freezing a parameter means not
computing gradients for it.
This function is idempotent, i.e., freezing a frozen model is a no-op.
For instance to only train the attention parameters from a transformer:
model = ...
model.freeze()
model.apply_to_modules(lambda k, v: v.unfreeze() if k.endswith("attention") else None)
Args:
recurse (bool, optional): If True then freeze the parameters of the
submodules as well (default: True).
keys (str or list[str], optional): If provided then only these
parameters will be frozen otherwise all the parameters of a
module. For instance freeze all biases by calling
``module.freeze(keys="bias")``.
strict (bool, optional): If set to True validate that the passed keys exist
(default: False).
"""
def _freeze_impl(_, m):
local_keys = keys
if local_keys is None:
local_keys = tree_flatten(
m.filter_and_map(
lambda m, k, v: (not isinstance(v, Module))
and m.valid_parameter_filter(m, k, v)
)
)
local_keys = [k for (k, v) in local_keys]
local_keys = m._validate_keys(local_keys, strict)
m._no_grad.update(local_keys)
if recurse:
self.apply_to_modules(_freeze_impl)
else:
_freeze_impl("", self)
def unfreeze(
self,
*,
recurse: bool = True,
keys: Optional[Union[str, List[str]]] = None,
strict: bool = False,
):
"""Unfreeze the Module's parameters or some of them.
This function is idempotent, i.e., unfreezing a model that is not frozen is
a no-op.
For instance to only train the biases one can do:
model = ...
model.freeze()
model.unfreeze(keys="bias")
Args:
recurse (bool, optional): If True then unfreeze the parameters of the
submodules as well (default: True).
keys (str or list[str], optional): If provided then only these
parameters will be unfrozen otherwise all the parameters of a
module. For instance unfreeze all biases by calling
``module.unfreeze(keys="bias")``.
strict (bool, optional): If set to True validate that the passed keys exist
(default: False).
"""
def _unfreeze_impl(_, m):
if keys is None:
m._no_grad.clear()
else:
local_keys = m._validate_keys(keys, strict)
m._no_grad.difference_update(local_keys)
if recurse:
self.apply_to_modules(_unfreeze_impl)
else:
_unfreeze_impl("", self)
def train(self, mode: bool = True):
def _set_train(_, m):
m._training = mode
self.apply_to_modules(_set_train)
def eval(self):
self.train(False)
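A short sketch of how :meth:`freeze`, :meth:`unfreeze` and :meth:`apply` compose, using the ``Linear`` layer from this commit (any ``Module`` behaves the same way, assuming the layers are exported under ``mlx.nn``):
```
import mlx.core as mx
import mlx.nn as nn

model = nn.Linear(4, 2)   # has a "weight" and a "bias" parameter
mx.eval(model.parameters())

# Freeze everything, then make only the bias trainable again.
model.freeze()
model.unfreeze(keys="bias")

# Cast every parameter (frozen or not) to 16-bit floats.
model.apply(lambda x: x.astype(mx.float16))

# Only the bias remains in the trainable set.
print(model.trainable_parameters())
```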


@@ -0,0 +1,22 @@
from mlx.nn.layers.base import Module
class Sequential(Module):
"""A layer that calls the passed callables in order.
We can pass either modules or plain callables to the Sequential module. If
our functions have learnable parameters they should be implemented as
``nn.Module`` instances.
Args:
modules (tuple of Callables): The modules to call in order
"""
def __init__(self, *modules):
super().__init__()
self.layers = list(modules)
def __call__(self, x):
for m in self.layers:
x = m(x)
return x
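A sketch mixing a module and a plain callable (assuming ``Linear`` and ``Sequential`` are exported under ``mlx.nn`` as in the rest of this commit):
```
import mlx.core as mx
import mlx.nn as nn

# Two Linear layers with a ReLU written as a plain lambda in between.
model = nn.Sequential(
    nn.Linear(4, 16),
    lambda x: mx.maximum(x, 0),
    nn.Linear(16, 1),
)
y = model(mx.random.normal((8, 4)))  # shape (8, 1)
```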


@@ -0,0 +1,33 @@
import mlx.core as mx
from mlx.nn.layers.base import Module
class Dropout(Module):
r"""Randomly zero a portion of the elements during training.
The remaining elements are multiplied with :math:`\frac{1}{1-p}` where
:math:`p` is the probability of zeroing an element. This is done so the
expected value of a given element will remain the same.
Args:
p (float): The probability to zero an element
"""
def __init__(self, p: float = 0.5):
super().__init__()
if p < 0 or p >= 1:
raise ValueError("The dropout probability should be in [0, 1)")
self._p_1 = 1 - p
def _extra_repr(self):
return f"p={1-self._p_1}"
def __call__(self, x):
if self._p_1 == 1 or not self.training:
return x
mask = mx.random.bernoulli(self._p_1, x.shape)
return (1 / self._p_1) * mask.astype(x.dtype) * x
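A sketch of the training/evaluation behaviour; ``eval()`` comes from the ``Module`` base class:
```
import mlx.core as mx
import mlx.nn as nn

drop = nn.Dropout(p=0.2)
x = mx.ones((4, 4))

y_train = drop(x)  # training mode (default): ~20% of entries zeroed, the rest scaled by 1/0.8
drop.eval()
y_eval = drop(x)   # evaluation mode: the input is returned unchanged
```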


@@ -0,0 +1,178 @@
import mlx.core as mx
from mlx.nn.layers.base import Module
class LayerNorm(Module):
r"""Applies layer normalization [1] on the inputs.
Computes
.. math::
y = \frac{x - E[x]}{\sqrt{Var[x] + \epsilon}} \gamma + \beta,
where :math:`\gamma` and :math:`\beta` are learned per feature dimension
parameters initialized at 1 and 0 respectively.
[1]: https://arxiv.org/abs/1607.06450
Args:
dims (int): The feature dimension of the input to normalize over
eps (float): A small additive constant for numerical stability
affine (bool): If True learn an affine transform to apply after the
normalization
"""
def __init__(self, dims: int, eps: float = 1e-5, affine: bool = True):
super().__init__()
if affine:
self.bias = mx.zeros((dims,))
self.weight = mx.ones((dims,))
self.eps = eps
self.dims = dims
def _extra_repr(self):
return f"{self.dims}, eps={self.eps}, affine={'weight' in self}"
def __call__(self, x):
means = mx.mean(x, axis=-1, keepdims=True)
var = mx.var(x, axis=-1, keepdims=True)
x = (x - means) * mx.rsqrt(var + self.eps)
return (self.weight * x + self.bias) if "weight" in self else x
class RMSNorm(Module):
r"""Applies Root Mean Square normalization [1] to the inputs.
Computes
.. math::
y = \frac{x}{\sqrt{E[x^2] + \epsilon}} \gamma
where :math:`\gamma` is a learned per feature dimension parameter initialized at
1.
[1]: https://arxiv.org/abs/1910.07467
Args:
dims (int): The feature dimension of the input to normalize over
eps (float): A small additive constant for numerical stability
"""
def __init__(self, dims: int, eps: float = 1e-5):
super().__init__()
self.weight = mx.ones((dims,))
self.eps = eps
def _extra_repr(self):
return f"{self.weight.shape[0]}, eps={self.eps}"
def __call__(self, x):
# S is 1/sqrt(N) where N is the size of the features of x and is used
# to compute a numerically more stable RMS of x by multiplying with S
# first and summing.
#
# This way we prefer underflow over overflow which is controlled with
# the parameter epsilon anyway.
S = 1 / x.shape[-1] ** 0.5
n = (x * S).square().sum(axis=-1, keepdims=True)
n = mx.rsqrt(n + self.eps)
return self.weight * x * n
class GroupNorm(Module):
r"""Applies Group Normalization [1] to the inputs.
Computes the same normalization as layer norm, namely
.. math::
y = \frac{x - E[x]}{\sqrt{Var[x] + \epsilon}} \gamma + \beta,
where :math:`\gamma` and :math:`\beta` are learned per feature dimension
parameters initialized at 1 and 0 respectively. However, the mean and
variance are computed over the spatial dimensions and each group of
features. In particular, the input is split into ``num_groups`` across the
feature dimension.
The feature dimension is assumed to be the last dimension and the dimensions
that precede it (except the first) are considered the spatial dimensions.
[1]: https://arxiv.org/abs/1803.08494
Args:
num_groups (int): Number of groups to separate the features into
dims (int): The feature dimensions of the input to normalize over
eps (float): A small additive constant for numerical stability
affine (bool): If True learn an affine transform to apply after the
normalization.
pytorch_compatible (bool): If True perform the group normalization in
the same order/grouping as PyTorch.
"""
def __init__(
self,
num_groups: int,
dims: int,
eps: float = 1e-5,
affine: bool = True,
pytorch_compatible: bool = False,
):
super().__init__()
if affine:
self.bias = mx.zeros((dims,))
self.weight = mx.ones((dims,))
self.num_groups = num_groups
self.dims = dims
self.eps = eps
self.pytorch_compatible = pytorch_compatible
def _extra_repr(self):
return (
f"{self.num_groups}, {self.dims}, eps={self.eps}, "
f"affine={'weight' in self}, pytorch_compatible={self.pytorch_compatible}"
)
def _pytorch_compatible_group_norm(self, x):
num_groups = self.num_groups
batch, *rest, dims = x.shape
# Split into groups
x = x.reshape(batch, -1, num_groups, dims // num_groups)
x = x.transpose(0, 1, 3, 2).reshape(batch, -1, num_groups)
# Normalize
means = mx.mean(x, axis=1, keepdims=True)
var = mx.var(x, axis=1, keepdims=True)
x = (x - means) * mx.rsqrt(var + self.eps)
x = x.reshape(batch, -1, dims // num_groups, num_groups)
x = x.transpose(0, 1, 3, 2).reshape(batch, *rest, dims)
return x
def _group_norm(self, x):
num_groups = self.num_groups
batch, *rest, dims = x.shape
# Split into groups
x = x.reshape(batch, -1, num_groups)
# Normalize
means = mx.mean(x, axis=1, keepdims=True)
var = mx.var(x, axis=1, keepdims=True)
x = (x - means) * mx.rsqrt(var + self.eps)
x = x.reshape(batch, *rest, dims)
return x
def __call__(self, x):
group_norm = (
self._pytorch_compatible_group_norm
if self.pytorch_compatible
else self._group_norm
)
x = group_norm(x)
return (self.weight * x + self.bias) if "weight" in self else x
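A sketch applying the three layers to a batch of sequences, with features on the last axis as the docstrings assume (and assuming the layers are exported under ``mlx.nn``):
```
import mlx.core as mx
import mlx.nn as nn

x = mx.random.normal((2, 10, 32))   # (batch, sequence, features)

ln = nn.LayerNorm(dims=32)
rms = nn.RMSNorm(dims=32)
gn = nn.GroupNorm(num_groups=4, dims=32)

y_ln = ln(x)    # per-position mean/variance over the 32 features, then weight * x + bias
y_rms = rms(x)  # x scaled by the reciprocal RMS of each feature vector
y_gn = gn(x)    # mean/variance per group of features, computed over the sequence dimension
```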


@@ -0,0 +1,142 @@
import math
from typing import Optional
import mlx.core as mx
from mlx.nn.layers.base import Module
class RoPE(Module):
"""Implements the rotary positional encoding [1].
The traditional implementation rotates consecutive pairs of elements in the
feature dimension while the default implementation rotates pairs with
stride half the feature dimensions for efficiency.
[1]: https://arxiv.org/abs/2104.09864
Args:
dims (int): The feature dimensions to be rotated. If the input feature
is larger than dims then the rest is left unchanged.
traditional (bool): If set to True choose the traditional
implementation which is slightly less efficient.
"""
def __init__(self, dims: int, traditional: bool = False):
super().__init__()
self.dims = dims
self.traditional = traditional
def _extra_repr(self):
return f"{self.dims}, traditional={self.traditional}"
def _compute_rope(self, costheta, sintheta, x):
x1 = x[..., : self.dims // 2]
x2 = x[..., self.dims // 2 : self.dims]
rx1 = x1 * costheta - x2 * sintheta
rx2 = x1 * sintheta + x2 * costheta
if self.dims < x.shape[-1]:
rx = mx.concatenate([rx1, rx2, x[..., self.dims :]], axis=-1)
else:
rx = mx.concatenate([rx1, rx2], axis=-1)
return rx
def _compute_traditional_rope(self, costheta, sintheta, x):
x1 = x[..., ::2]
x2 = x[..., 1::2]
rx1 = x1 * costheta - x2 * sintheta
rx2 = x1 * sintheta + x2 * costheta
if self.dims < x.shape[-1]:
raise NotImplementedError(
"RoPE doesn't implement partial traditional application"
)
rx = mx.concatenate([rx1[..., None], rx2[..., None]], axis=-1)
return rx
def __call__(self, x, offset: int = 0):
shape = x.shape
x = mx.reshape(x, (-1, shape[-2], shape[-1]))
N = x.shape[1] + offset
costheta, sintheta = RoPE.create_cos_sin_theta(
N, self.dims, offset=offset, dtype=x.dtype
)
rope = (
self._compute_traditional_rope if self.traditional else self._compute_rope
)
rx = rope(costheta, sintheta, x)
return mx.reshape(rx, shape)
@staticmethod
def create_cos_sin_theta(
N: int, D: int, offset: int = 0, base: float = 10000, dtype=mx.float32
):
D = D // 2
positions = mx.arange(offset, N, dtype=dtype)
freqs = mx.exp(-mx.arange(0.0, D, dtype=dtype) * (math.log(base) / D))
theta = mx.reshape(positions, (-1, 1)) * mx.reshape(freqs, (1, -1))
costheta = mx.cos(theta)
sintheta = mx.sin(theta)
return costheta, sintheta
class SinusoidalPositionalEncoding(Module):
"""Implements sinusoidal positional encoding similar to [1].
[1]: https://arxiv.org/abs/1706.03762
Args:
dims (int): The dimensionality of the resulting positional embeddings.
min_freq (float): The minimum frequency expected (default: 0.0001)
max_freq (float): The maximum frequency expected (default: 1)
scale (float): Scale the embeddings by that number (default: sqrt(dims//2))
cos_first (bool): If set to True embed using ``[cos(x); sin(x)]``
instead of the other way around (default: False)
full_turns (bool): If set to True multiply the frequencies
with ``2 pi`` (default: False)
"""
def __init__(
self,
dims: int,
min_freq: float = 0.0001,
max_freq: float = 1,
scale: Optional[float] = None,
cos_first: bool = False,
full_turns: bool = False,
):
super().__init__()
one_zero = 1 - mx.arange(0, dims // 2) / (dims // 2 - 1)
min_freq = math.log(min_freq)
max_freq = math.log(max_freq)
# Start with underscore so it is not included in the parameters
self._sigmas = mx.exp(one_zero * (max_freq - min_freq) + min_freq)
if full_turns:
self._sigmas = self._sigmas * (2 * math.pi)
# Save some constants that define the implementation
self.scale = scale or (2 / dims) ** 0.5
self.cos_first = cos_first
def __call__(self, x):
y = x[..., None] * self._sigmas
cosy = mx.cos(y)
siny = mx.sin(y)
if self.cos_first:
y = mx.concatenate([cosy, siny], axis=-1)
else:
y = mx.concatenate([siny, cosy], axis=-1)
if self.scale != 1:
y = y * self.scale
return y
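A sketch of rotating per-head queries with ``RoPE``; the ``offset`` argument continues the position count when decoding incrementally with cached keys (assuming the layer is exported under ``mlx.nn``):
```
import mlx.core as mx
import mlx.nn as nn

rope = nn.RoPE(dims=64)
q = mx.random.normal((1, 8, 64))   # (batch, sequence, head_dim)

q_rot = rope(q)                    # positions 0..7
q_next = rope(mx.random.normal((1, 1, 64)), offset=8)  # the 9th position during generation
```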


@@ -0,0 +1,136 @@
import math
from typing import Optional
import mlx.core as mx
from mlx.nn.layers.base import Module
from mlx.nn.layers.linear import Linear
from mlx.nn.layers.normalization import LayerNorm
class MultiHeadAttention(Module):
"""Implements the scaled dot product attention with multiple heads.
Given inputs for queries, keys and values the ``MultiHeadAttention`` produces
new values by aggregating information from the input values according to
the similarities of the input queries and keys.
All inputs as well as the output are linearly projected without biases.
MultiHeadAttention also expects an additive attention mask that should be
broadcastable with (batch, num_heads, # queries, # keys). The mask should
have ``-inf`` or very large negative values at the positions that should *not* be
attended to.
Args:
dims (int): The model dimensions. If no other dims are provided then
dims is used for queries, keys, values and the output.
num_heads (int): How many attention heads to use
query_input_dims (int, optional): The input dimensions of the queries (default: dims).
key_input_dims (int, optional): The input dimensions of the keys (default: dims).
value_input_dims (int, optional): The input dimensions of the values (default: key_input_dims).
value_dims (int, optional): The dimensions of the values after the projection (default: dims).
value_output_dims (int, optional): The dimensions the new values will be projected to (default: dims).
"""
def __init__(
self,
dims: int,
num_heads: int,
query_input_dims: Optional[int] = None,
key_input_dims: Optional[int] = None,
value_input_dims: Optional[int] = None,
value_dims: Optional[int] = None,
value_output_dims: Optional[int] = None,
):
super().__init__()
if (dims % num_heads) != 0:
raise ValueError(
f"The input feature dimensions should be divisible by the number of heads ({dims} % {num_heads}) != 0"
)
query_input_dims = query_input_dims or dims
key_input_dims = key_input_dims or dims
value_input_dims = value_input_dims or key_input_dims
value_dims = value_dims or dims
value_output_dims = value_output_dims or dims
self.num_heads = num_heads
self.query_proj = Linear(query_input_dims, dims, False)
self.key_proj = Linear(key_input_dims, dims, False)
self.value_proj = Linear(value_input_dims, value_dims, False)
self.out_proj = Linear(value_dims, value_output_dims, False)
def __call__(self, queries, keys, values, mask=None):
queries = self.query_proj(queries)
keys = self.key_proj(keys)
values = self.value_proj(values)
num_heads = self.num_heads
B, L, D = queries.shape
_, S, _ = keys.shape
queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 3, 1)
values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
# Dimensions are [batch x num heads x sequence x hidden dim]
scale = math.sqrt(1 / queries.shape[-1])
scores = (queries * scale) @ keys
if mask is not None:
scores = scores + mask.astype(scores.dtype)
scores = mx.softmax(scores, axis=-1)
values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.out_proj(values_hat)
@staticmethod
def create_additive_causal_mask(N: int, dtype: mx.Dtype = mx.float32):
indices = mx.arange(N)
mask = indices[:, None] < indices[None]
# usually inf but 1e9 is as good and softmax(full(1e9)) != nan
# TODO: Should replace this with finfo(dtype).min
mask = mask.astype(dtype) * -1e9
return mask
class TransformerEncoderLayer(Module):
def __init__(self, dims: int, num_heads: int, mlp_dims: Optional[int] = None):
super().__init__()
mlp_dims = mlp_dims or dims * 4
self.attention = MultiHeadAttention(dims, num_heads)
self.ln1 = LayerNorm(dims)
self.ln2 = LayerNorm(dims)
self.linear1 = Linear(dims, mlp_dims)
self.linear2 = Linear(mlp_dims, dims)
def __call__(self, x, mask):
y = self.ln1(x)
y = self.attention(y, y, y, mask)
x = x + y
y = self.ln2(x)
y = self.linear1(y)
y = mx.maximum(y, 0)
y = self.linear2(y)
x = x + y
return x
class TransformerEncoder(Module):
def __init__(
self, num_layers: int, dims: int, num_heads: int, mlp_dims: Optional[int] = None
):
super().__init__()
self.layers = [
TransformerEncoderLayer(dims, num_heads, mlp_dims)
for i in range(num_layers)
]
self.ln = LayerNorm(dims)
def __call__(self, x, mask):
for l in self.layers:
x = l(x, mask)
x = self.ln(x)
return x
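A sketch wiring the encoder together with the causal-mask helper defined on ``MultiHeadAttention`` (assuming both classes are exported under ``mlx.nn``):
```
import mlx.core as mx
import mlx.nn as nn

encoder = nn.TransformerEncoder(num_layers=2, dims=64, num_heads=4)
mask = nn.MultiHeadAttention.create_additive_causal_mask(16)

x = mx.random.normal((1, 16, 64))   # (batch, sequence, dims)
y = encoder(x, mask)                # same shape as x
```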

31
python/mlx/nn/utils.py Normal file

@@ -0,0 +1,31 @@
from typing import Callable
import mlx.core as mx
def value_and_grad(model: "mlx.nn.Module", fn: Callable):
"""Transform the passed function ``fn`` to a function that computes the
gradients of ``fn`` wrt the model's trainable parameters and also its
value.
Args:
model (mlx.nn.Module): The model whose trainable parameters to compute
gradients for
fn (Callable): The scalar function to compute gradients for
Returns:
A callable that returns the value of ``fn`` and the gradients wrt the
trainable parameters of ``model``
"""
def inner_fn(params, *args, **kwargs):
model.update(params)
return fn(*args, **kwargs)
value_grad_fn = mx.value_and_grad(inner_fn)
def wrapped_value_grad_fn(*args, **kwargs):
value, grad = value_grad_fn(model.trainable_parameters(), *args, **kwargs)
return value, grad
return wrapped_value_grad_fn
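A sketch of the intended usage: the loss closes over the model, and the returned gradients mirror ``model.trainable_parameters()`` (the data here is a placeholder):
```
import mlx.core as mx
import mlx.nn as nn

model = nn.Linear(4, 1)
X, y = mx.random.normal((32, 4)), mx.random.normal((32, 1))

def loss_fn(X, y):
    return mx.mean(mx.square(model(X) - y))

loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
loss, grads = loss_and_grad_fn(X, y)
```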

152
python/mlx/optimizers.py Normal file

@@ -0,0 +1,152 @@
import math
from typing import List
import mlx.core as mx
from mlx.utils import tree_map
class OptimizerState(dict):
"""The optimizer state implements a recursively defined
:class:`collections.defaultdict`, namely a missing key in an optimizer
state is an :class:`OptimizerState`.
.. note::
:meth:`OptimizerState.get` in contrast to a normal dictionary also sets
the key to the ``default`` value if the ``key`` was not present in the
dictionary.
"""
def __getitem__(self, key):
if key not in self:
self[key] = OptimizerState()
return super().__getitem__(key)
def get(self, key, default):
"""If ``key`` doesn't exist set its value to ``default`` and then return it."""
if key not in self:
self[key] = default
return super().__getitem__(key)
class Optimizer:
"""The base class for all optimizers. It allows us to implement an
optimizer on a per-parameter basis and apply it to a parameter tree.
Attributes:
state (OptimizerState): It holds the optimizer's state dictionary.
"""
def __init__(self):
self.state = OptimizerState()
def update(self, model: "mlx.nn.Module", gradients: dict):
"""Apply the gradients to the parameters of the model and update the
model with the new parameters.
Args:
model (mlx.nn.Module): An mlx module to be updated.
gradients (dict): A Python tree of gradients, most likely computed
via :func:`mlx.nn.value_and_grad`.
"""
model.update(self.apply_gradients(gradients, model))
def apply_gradients(self, gradients: dict, model: dict):
"""Apply the gradients to the parameters and return the updated parameters.
Can be used to update a model via
``model.update(opt.apply_gradients(grads, model))`` which is precisely
how :meth:`Optimizer.update` is implemented.
Args:
gradients (dict): A Python tree of gradients.
model (dict): A Python tree of parameters. It can be a superset of
the gradients. In that case the returned python tree
will be of the same structure as the gradients.
"""
return tree_map(self.apply_single, gradients, model, self.state)
def apply_single(
self, gradient: mx.array, parameter: mx.array, state: OptimizerState
):
"""To be extended by the children classes to implement each optimizer's
update."""
raise NotImplementedError()
class SGD(Optimizer):
r"""Stochastic gradient descent optimizer.
Updates a parameter :math:`w` with a gradient :math:`g` as follows
.. math::
v_{t+1} &= \mu v_t + (1 - \mu) g_t \\
w_{t+1} &= w_t - \lambda v_{t+1}
Args:
learning_rate (float): The learning rate :math:`\lambda` for the update
momentum (float): The momentum strength :math:`\mu`
"""
def __init__(self, learning_rate: float, momentum: float = 0.0):
super().__init__()
self.learning_rate = learning_rate
self.momentum = momentum
def apply_single(
self, gradient: mx.array, parameter: mx.array, state: OptimizerState
):
"""Performs the SGD parameter update and stores :math:`v` in the
optimizer state."""
if self.momentum <= 0:
return parameter - self.learning_rate * gradient
v = state.get("v", mx.zeros_like(gradient))
v = self.momentum * v + (1 - self.momentum) * gradient
state["v"] = v
return parameter - self.learning_rate * v
class Adam(Optimizer):
r"""Implementation of the Adam optimizer [1].
Our Adam implementation follows the original paper and omits the bias
correction in the first and second moment estimates. In detail,
.. math::
m_{t+1} &= \beta_1 m_t + (1 - \beta_1) g_t \\
v_{t+1} &= \beta_2 v_t + (1 - \beta_2) g_t^2 \\
w_{t+1} &= w_t - \lambda \frac{m_{t+1}}{\sqrt{v_{t+1}} + \epsilon}
[1]: Kingma, D.P. and Ba, J., 2015. Adam: A method for stochastic
optimization. ICLR 2015.
"""
def __init__(
self, learning_rate: float, betas: List[float] = [0.9, 0.999], eps: float = 1e-8
):
super().__init__()
self.learning_rate = learning_rate
self.betas = betas
self.eps = eps
def apply_single(
self, gradient: mx.array, parameter: mx.array, state: OptimizerState
):
"""Performs the Adam parameter update and stores :math:`v` and
:math:`m` in the optimizer state."""
lr = self.learning_rate
b1, b2 = self.betas
eps = self.eps
m = state.get("m", gradient)
v = state.get("v", mx.square(gradient))
m = b1 * m + (1 - b1) * gradient
v = b2 * v + (1 - b2) * mx.square(gradient)
state["m"] = m
state["v"] = v
return parameter - lr * m / (mx.sqrt(v) + eps)
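A sketch of a single optimization step combining the module utilities above with ``SGD`` (the data is a placeholder; ``Adam`` is used the same way):
```
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

model = nn.Linear(4, 1)
optimizer = optim.SGD(learning_rate=0.1, momentum=0.9)

def loss_fn(X, y):
    return mx.mean(mx.square(model(X) - y))

loss_and_grad_fn = nn.value_and_grad(model, loss_fn)

X, y = mx.random.normal((32, 4)), mx.random.normal((32, 1))
loss, grads = loss_and_grad_fn(X, y)
optimizer.update(model, grads)                # writes the updated parameters back into the model
mx.eval(model.parameters(), optimizer.state)  # force the lazy computation
```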

32
python/src/CMakeLists.txt Normal file

@@ -0,0 +1,32 @@
pybind11_add_module(
core
${CMAKE_CURRENT_SOURCE_DIR}/mlx.cpp
${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
${CMAKE_CURRENT_SOURCE_DIR}/stream.cpp
${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp
${CMAKE_CURRENT_SOURCE_DIR}/random.cpp
)
if (NOT MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY)
set(MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY ${CMAKE_LIBRARY_OUTPUT_DIRECTORY})
endif()
set_target_properties(
core
PROPERTIES
LIBRARY_OUTPUT_DIRECTORY
${MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY}
)
target_link_libraries(core PRIVATE mlx)
target_compile_definitions(core PRIVATE _VERSION_=${MLX_VERSION})
if(BUILD_SHARED_LIBS)
target_link_options(core PRIVATE -Wl,-rpath,@loader_path/lib)
endif()

468
python/src/fft.cpp Normal file

@@ -0,0 +1,468 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "python/src/utils.h"
#include "mlx/fft.h"
#include "mlx/ops.h"
namespace py = pybind11;
using namespace py::literals;
using namespace mlx::core;
void init_fft(py::module_& parent_module) {
auto m = parent_module.def_submodule(
"fft", "mlx.core.fft: Fast Fourier Transforms.");
m.def(
"fft",
[](const array& a,
const std::optional<int>& n,
int axis,
StreamOrDevice s) {
if (n.has_value()) {
return fft::fft(a, n.value(), axis, s);
} else {
return fft::fft(a, axis, s);
}
},
"a"_a,
"n"_a = none,
"axis"_a = -1,
"stream"_a = none,
R"pbdoc(
One dimensional discrete Fourier Transform.
Args:
a (array): The input array.
n (int, optional): Size of the transformed axis. The
corresponding axis in the input is truncated or padded with
zeros to match ``n``. The default value is ``a.shape[axis]``.
axis (int, optional): Axis along which to perform the FFT. The
default is ``-1``.
Returns:
array: The DFT of the input along the given axis.
)pbdoc");
m.def(
"ifft",
[](const array& a,
const std::optional<int>& n,
int axis,
StreamOrDevice s) {
if (n.has_value()) {
return fft::ifft(a, n.value(), axis, s);
} else {
return fft::ifft(a, axis, s);
}
},
"a"_a,
"n"_a = none,
"axis"_a = -1,
"stream"_a = none,
R"pbdoc(
One dimensional inverse discrete Fourier Transform.
Args:
a (array): The input array.
n (int, optional): Size of the transformed axis. The
corresponding axis in the input is truncated or padded with
zeros to match ``n``. The default value is ``a.shape[axis]``.
axis (int, optional): Axis along which to perform the FFT. The
default is ``-1``.
Returns:
array: The inverse DFT of the input along the given axis.
)pbdoc");
m.def(
"fft2",
[](const array& a,
const std::optional<std::vector<int>>& n,
const std::optional<std::vector<int>>& axes,
StreamOrDevice s) {
if (axes.has_value() && n.has_value()) {
return fft::fftn(a, n.value(), axes.value(), s);
} else if (axes.has_value()) {
return fft::fftn(a, axes.value(), s);
} else if (n.has_value()) {
std::vector<int> axes_(n.value().size());
std::iota(axes_.begin(), axes_.end(), -n.value().size());
return fft::fftn(a, n.value(), axes_, s);
} else {
return fft::fftn(a, s);
}
},
"a"_a,
"s"_a = none,
"axes"_a = std::vector<int>{-2, -1},
"stream"_a = none,
R"pbdoc(
Two dimensional discrete Fourier Transform.
Args:
a (array): The input array.
s (list(int), optional): Sizes of the transformed axes. The
corresponding axes in the input are truncated or padded with
zeros to match the sizes in ``s``. The default value is the
sizes of ``a`` along ``axes``.
axes (list(int), optional): Axes along which to perform the FFT.
The default is ``[-2, -1]``.
Returns:
array: The DFT of the input along the given axes.
)pbdoc");
m.def(
"ifft2",
[](const array& a,
const std::optional<std::vector<int>>& n,
const std::optional<std::vector<int>>& axes,
StreamOrDevice s) {
if (axes.has_value() && n.has_value()) {
return fft::ifftn(a, n.value(), axes.value(), s);
} else if (axes.has_value()) {
return fft::ifftn(a, axes.value(), s);
} else if (n.has_value()) {
std::vector<int> axes_(n.value().size());
std::iota(axes_.begin(), axes_.end(), -n.value().size());
return fft::ifftn(a, n.value(), axes_, s);
} else {
return fft::ifftn(a, s);
}
},
"a"_a,
"s"_a = none,
"axes"_a = std::vector<int>{-2, -1},
"stream"_a = none,
R"pbdoc(
Two dimensional inverse discrete Fourier Transform.
Args:
a (array): The input array.
s (list(int), optional): Sizes of the transformed axes. The
corresponding axes in the input are truncated or padded with
zeros to match the sizes in ``s``. The default value is the
sizes of ``a`` along ``axes``.
axes (list(int), optional): Axes along which to perform the FFT.
The default is ``[-2, -1]``.
Returns:
array: The inverse DFT of the input along the given axes.
)pbdoc");
m.def(
"fftn",
[](const array& a,
const std::optional<std::vector<int>>& n,
const std::optional<std::vector<int>>& axes,
StreamOrDevice s) {
if (axes.has_value() && n.has_value()) {
return fft::fftn(a, n.value(), axes.value(), s);
} else if (axes.has_value()) {
return fft::fftn(a, axes.value(), s);
} else if (n.has_value()) {
std::vector<int> axes_(n.value().size());
std::iota(axes_.begin(), axes_.end(), -n.value().size());
return fft::fftn(a, n.value(), axes_, s);
} else {
return fft::fftn(a, s);
}
},
"a"_a,
"s"_a = none,
"axes"_a = none,
"stream"_a = none,
R"pbdoc(
n-dimensional discrete Fourier Transform.
Args:
a (array): The input array.
s (list(int), optional): Sizes of the transformed axes. The
corresponding axes in the input are truncated or padded with
zeros to match the sizes in ``s``. The default value is the
sizes of ``a`` along ``axes``.
axes (list(int), optional): Axes along which to perform the FFT.
The default is ``None`` in which case the FFT is over the last
``len(s)`` axes or all axes if ``s`` is also ``None``.
Returns:
array: The DFT of the input along the given axes.
)pbdoc");
m.def(
"ifftn",
[](const array& a,
const std::optional<std::vector<int>>& n,
const std::optional<std::vector<int>>& axes,
StreamOrDevice s) {
if (axes.has_value() && n.has_value()) {
return fft::ifftn(a, n.value(), axes.value(), s);
} else if (axes.has_value()) {
return fft::ifftn(a, axes.value(), s);
} else if (n.has_value()) {
std::vector<int> axes_(n.value().size());
std::iota(axes_.begin(), axes_.end(), -n.value().size());
return fft::ifftn(a, n.value(), axes_, s);
} else {
return fft::ifftn(a, s);
}
},
"a"_a,
"s"_a = none,
"axes"_a = none,
"stream"_a = none,
R"pbdoc(
n-dimensional inverse discrete Fourier Transform.
Args:
a (array): The input array.
s (list(int), optional): Sizes of the transformed axes. The
corresponding axes in the input are truncated or padded with
zeros to match the sizes in ``s``. The default value is the
sizes of ``a`` along ``axes``.
axes (list(int), optional): Axes along which to perform the FFT.
The default is ``None`` in which case the FFT is over the last
``len(s)`` axes or all axes if ``s`` is also ``None``.
Returns:
array: The inverse DFT of the input along the given axes.
)pbdoc");
m.def(
"rfft",
[](const array& a,
const std::optional<int>& n,
int axis,
StreamOrDevice s) {
if (n.has_value()) {
return fft::rfft(a, n.value(), axis, s);
} else {
return fft::rfft(a, axis, s);
}
},
"a"_a,
"n"_a = none,
"axis"_a = -1,
"stream"_a = none,
R"pbdoc(
One dimensional discrete Fourier Transform on a real input.
The output has the same shape as the input except along ``axis`` in
which case it has size ``n // 2 + 1``.
Args:
a (array): The input array. If the array is complex it will be silently
cast to a real type.
n (int, optional): Size of the transformed axis. The
corresponding axis in the input is truncated or padded with
zeros to match ``n``. The default value is ``a.shape[axis]``.
axis (int, optional): Axis along which to perform the FFT. The
default is ``-1``.
Returns:
array: The DFT of the input along the given axis. The output
data type will be complex.
)pbdoc");
m.def(
"irfft",
[](const array& a,
const std::optional<int>& n,
int axis,
StreamOrDevice s) {
if (n.has_value()) {
return fft::irfft(a, n.value(), axis, s);
} else {
return fft::irfft(a, axis, s);
}
},
"a"_a,
"n"_a = none,
"axis"_a = -1,
"stream"_a = none,
R"pbdoc(
The inverse of :func:`rfft`.
The output has the same shape as the input except along ``axis`` in
which case it has size ``n``.
Args:
a (array): The input array.
n (int, optional): Size of the transformed axis. The
corresponding axis in the input is truncated or padded with
zeros to match ``n // 2 + 1``. The default value is
``a.shape[axis] // 2 + 1``.
axis (int, optional): Axis along which to perform the FFT. The
default is ``-1``.
Returns:
array: The real array containing the inverse of :func:`rfft`.
)pbdoc");
m.def(
"rfft2",
[](const array& a,
const std::optional<std::vector<int>>& n,
const std::optional<std::vector<int>>& axes,
StreamOrDevice s) {
if (axes.has_value() && n.has_value()) {
return fft::rfftn(a, n.value(), axes.value(), s);
} else if (axes.has_value()) {
return fft::rfftn(a, axes.value(), s);
} else if (n.has_value()) {
std::vector<int> axes_(n.value().size());
std::iota(axes_.begin(), axes_.end(), -n.value().size());
return fft::rfftn(a, n.value(), axes_, s);
} else {
return fft::rfftn(a, s);
}
},
"a"_a,
"s"_a = none,
"axes"_a = std::vector<int>{-2, -1},
"stream"_a = none,
R"pbdoc(
Two dimensional real discrete Fourier Transform.
The output has the same shape as the input except along the dimensions in
``axes`` in which case it has sizes from ``s``. The last axis in ``axes`` is
treated as the real axis and will have size ``s[-1] // 2 + 1``.
Args:
a (array): The input array. If the array is complex it will be silently
cast to a real type.
s (list(int), optional): Sizes of the transformed axes. The
corresponding axes in the input are truncated or padded with
zeros to match the sizes in ``s``. The default value is the
sizes of ``a`` along ``axes``.
axes (list(int), optional): Axes along which to perform the FFT.
The default is ``[-2, -1]``.
Returns:
array: The real DFT of the input along the given axes. The output
data type will be complex.
)pbdoc");
m.def(
"irfft2",
[](const array& a,
const std::optional<std::vector<int>>& n,
const std::optional<std::vector<int>>& axes,
StreamOrDevice s) {
if (axes.has_value() && n.has_value()) {
return fft::irfftn(a, n.value(), axes.value(), s);
} else if (axes.has_value()) {
return fft::irfftn(a, axes.value(), s);
} else if (n.has_value()) {
std::vector<int> axes_(n.value().size());
std::iota(axes_.begin(), axes_.end(), -n.value().size());
return fft::irfftn(a, n.value(), axes_, s);
} else {
return fft::irfftn(a, s);
}
},
"a"_a,
"s"_a = none,
"axes"_a = std::vector<int>{-2, -1},
"stream"_a = none,
R"pbdoc(
The inverse of :func:`rfft2`.
Note the input is generally complex. The dimensions of the input
specified in ``axes`` are padded or truncated to match the sizes
from ``s``. The last axis in ``axes`` is treated as the real axis
and will have size ``s[-1] // 2 + 1``.
Args:
a (array): The input array.
s (list(int), optional): Sizes of the transformed axes. The
corresponding axes in the input are truncated or padded with
zeros to match the sizes in ``s`` except for the last axis
which has size ``s[-1] // 2 + 1``. The default value is the
sizes of ``a`` along ``axes``.
axes (list(int), optional): Axes along which to perform the FFT.
The default is ``[-2, -1]``.
Returns:
array: The real array containing the inverse of :func:`rfft2`.
)pbdoc");
m.def(
"rfftn",
[](const array& a,
const std::optional<std::vector<int>>& n,
const std::optional<std::vector<int>>& axes,
StreamOrDevice s) {
if (axes.has_value() && n.has_value()) {
return fft::rfftn(a, n.value(), axes.value(), s);
} else if (axes.has_value()) {
return fft::rfftn(a, axes.value(), s);
} else if (n.has_value()) {
std::vector<int> axes_(n.value().size());
std::iota(axes_.begin(), axes_.end(), -n.value().size());
return fft::rfftn(a, n.value(), axes_, s);
} else {
return fft::rfftn(a, s);
}
},
"a"_a,
"s"_a = none,
"axes"_a = none,
"stream"_a = none,
R"pbdoc(
n-dimensional real discrete Fourier Transform.
The output has the same shape as the input except along the dimensions in
``axes`` in which case it has sizes from ``s``. The last axis in ``axes`` is
treated as the real axis and will have size ``s[-1] // 2 + 1``.
Args:
a (array): The input array. If the array is complex it will be silently
cast to a real type.
s (list(int), optional): Sizes of the transformed axes. The
corresponding axes in the input are truncated or padded with
zeros to match the sizes in ``s``. The default value is the
sizes of ``a`` along ``axes``.
axes (list(int), optional): Axes along which to perform the FFT.
The default is ``None`` in which case the FFT is over the last
``len(s)`` axes or all axes if ``s`` is also ``None``.
Returns:
array: The real DFT of the input along the given axes. The output
)pbdoc");
m.def(
"irfftn",
[](const array& a,
const std::optional<std::vector<int>>& n,
const std::optional<std::vector<int>>& axes,
StreamOrDevice s) {
if (axes.has_value() && n.has_value()) {
return fft::irfftn(a, n.value(), axes.value(), s);
} else if (axes.has_value()) {
return fft::irfftn(a, axes.value(), s);
} else if (n.has_value()) {
std::vector<int> axes_(n.value().size());
std::iota(axes_.begin(), axes_.end(), -n.value().size());
return fft::irfftn(a, n.value(), axes_, s);
} else {
return fft::irfftn(a, s);
}
},
"a"_a,
"s"_a = none,
"axes"_a = none,
"stream"_a = none,
R"pbdoc(
The inverse of :func:`rfftn`.
Note the input is generally complex. The dimensions of the input
specified in ``axes`` are padded or truncated to match the sizes
from ``s``. The last axis in ``axes`` is treated as the real axis
and will have size ``s[-1] // 2 + 1``.
Args:
a (array): The input array.
s (list(int), optional): Sizes of the transformed axes. The
corresponding axes in the input are truncated or padded with
zeros to match the sizes in ``s``. The default value is the
sizes of ``a`` along ``axes``.
axes (list(int), optional): Axes along which to perform the FFT.
The default is ``None`` in which case the FFT is over the last
``len(s)`` axes or all axes if ``s`` is also ``None``.
Returns:
array: The real array containing the inverse of :func:`rfftn`.
)pbdoc");
}
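On the Python side these bindings surface as ``mlx.core.fft``; a sketch of the real-input transforms, with shapes following the docstrings above:
```
import mlx.core as mx

x = mx.random.normal((8, 16))

X = mx.fft.rfft(x)              # shape (8, 9): the last axis becomes n // 2 + 1
x_back = mx.fft.irfft(X, n=16)  # shape (8, 16): inverse of rfft with n given explicitly

X2 = mx.fft.fft2(x)             # complex 2-D DFT over the last two axes
```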

635
python/src/indexing.cpp Normal file

@@ -0,0 +1,635 @@
#include <numeric>
#include <sstream>
#include "python/src/indexing.h"
#include "mlx/ops.h"
bool is_none_slice(const py::slice& in_slice) {
return (
py::getattr(in_slice, "start").is_none() &&
py::getattr(in_slice, "stop").is_none() &&
py::getattr(in_slice, "step").is_none());
}
int get_slice_int(py::object obj, int default_val) {
if (!obj.is_none()) {
if (!py::isinstance<py::int_>(obj)) {
throw std::invalid_argument("Slice indices must be integers or None.");
}
return py::cast<int>(py::cast<py::int_>(obj));
}
return default_val;
}
void get_slice_params(
int& starts,
int& ends,
int& strides,
const py::slice& in_slice,
int axis_size) {
// Following numpy's convention
// Assume n is the number of elements in the dimension being sliced.
// Then, if i is not given it defaults to 0 for k > 0 and n - 1 for
// k < 0 . If j is not given it defaults to n for k > 0 and -n-1 for
// k < 0 . If k is not given it defaults to 1
strides = get_slice_int(py::getattr(in_slice, "step"), 1);
starts = get_slice_int(
py::getattr(in_slice, "start"), strides < 0 ? axis_size - 1 : 0);
ends = get_slice_int(
py::getattr(in_slice, "stop"), strides < 0 ? -axis_size - 1 : axis_size);
// starts = (starts < 0) ? starts + axis_size : starts;
// ends = (ends < 0) ? ends + axis_size : ends;
}
array get_int_index(py::object idx, int axis_size) {
int idx_ = py::cast<int>(idx);
idx_ = (idx_ < 0) ? idx_ + axis_size : idx_;
return array(idx_, uint32);
}
bool is_valid_index_type(const py::object& obj) {
return py::isinstance<py::slice>(obj) || py::isinstance<py::int_>(obj) ||
py::isinstance<array>(obj) || obj.is_none() || py::ellipsis().is(obj);
}
array mlx_get_item_slice(const array& src, const py::slice& in_slice) {
// Check input and raise error if 0 dim for parity with np
if (src.ndim() == 0) {
throw std::invalid_argument(
"too many indices for array: array is 0-dimensional");
}
// Return a copy of the array if a none slice is requested
if (is_none_slice(in_slice)) {
return src;
}
std::vector<int> starts(src.ndim(), 0);
std::vector<int> ends = src.shape();
std::vector<int> strides(src.ndim(), 1);
// Check and update slice params
get_slice_params(starts[0], ends[0], strides[0], in_slice, ends[0]);
return slice(src, starts, ends, strides);
}
array mlx_get_item_array(const array& src, const array& indices) {
// Check input and raise error if 0 dim for parity with np
if (src.ndim() == 0) {
throw std::invalid_argument(
"too many indices for array: array is 0-dimensional");
}
if (indices.dtype() == bool_) {
throw std::invalid_argument("boolean indices are not yet supported");
}
// If only one input array is mentioned, we set axis=0 in take
// for parity with np
return take(src, indices, 0);
}
array mlx_get_item_int(const array& src, const py::int_& idx) {
// Check input and raise error if 0 dim for parity with np
if (src.ndim() == 0) {
throw std::invalid_argument(
"too many indices for array: array is 0-dimensional");
}
// If only one input idx is mentioned, we set axis=0 in take
// for parity with np
return take(src, get_int_index(idx, src.shape(0)), 0);
}
array mlx_gather_nd(
array src,
const std::vector<py::object>& indices,
bool gather_first,
int& max_dims) {
max_dims = 0;
std::vector<array> gather_indices;
std::vector<bool> is_slice(indices.size(), false);
int num_slices = 0;
// gather all the arrays
for (int i = 0; i < indices.size(); i++) {
auto& idx = indices[i];
if (py::isinstance<py::slice>(idx)) {
int start, end, stride;
get_slice_params(start, end, stride, idx, src.shape(i));
gather_indices.push_back(arange(start, end, stride, uint32));
num_slices++;
is_slice[i] = true;
} else if (py::isinstance<py::int_>(idx)) {
gather_indices.push_back(get_int_index(idx, src.shape(i)));
} else if (py::isinstance<array>(idx)) {
auto arr = py::cast<array>(idx);
max_dims = std::max(static_cast<int>(arr.ndim()), max_dims);
gather_indices.push_back(arr);
}
}
// reshape them so that the int/array indices are first
if (gather_first) {
int slice_index = 0;
for (int i = 0; i < gather_indices.size(); i++) {
if (is_slice[i]) {
std::vector<int> index_shape(max_dims + num_slices, 1);
index_shape[max_dims + slice_index] = gather_indices[i].shape(0);
gather_indices[i] = reshape(gather_indices[i], index_shape);
slice_index++;
} else {
std::vector<int> index_shape = gather_indices[i].shape();
index_shape.insert(index_shape.end(), num_slices, 1);
gather_indices[i] = reshape(gather_indices[i], index_shape);
}
}
} else {
// reshape them so that the int/array indices are last
for (int i = 0; i < gather_indices.size(); i++) {
if (i < num_slices) {
std::vector<int> index_shape(max_dims + num_slices, 1);
index_shape[i] = gather_indices[i].shape(0);
gather_indices[i] = reshape(gather_indices[i], index_shape);
}
}
}
// Do the gather
std::vector<int> axes(indices.size());
std::iota(axes.begin(), axes.end(), 0);
std::vector<int> slice_sizes = src.shape();
std::fill(slice_sizes.begin(), slice_sizes.begin() + indices.size(), 1);
src = gather(src, gather_indices, axes, slice_sizes);
// Squeeze the dims
std::vector<int> out_shape;
out_shape.insert(
out_shape.end(),
src.shape().begin(),
src.shape().begin() + max_dims + num_slices);
out_shape.insert(
out_shape.end(),
src.shape().begin() + max_dims + num_slices + indices.size(),
src.shape().end());
src = reshape(src, out_shape);
return src;
}
array mlx_get_item_nd(array src, const py::tuple& entries) {
// No indices make this a noop
if (entries.size() == 0) {
return src;
}
// The plan is as follows:
// 1. Replace the ellipsis with a series of slice(None)
// 2. Loop over the indices and calculate the gather indices
// 3. Calculate the remaining slices and reshapes
// Ellipsis handling
std::vector<py::object> indices;
{
int non_none_indices_before = 0;
int non_none_indices_after = 0;
std::vector<py::object> r_indices;
int i = 0;
for (; i < entries.size(); i++) {
auto idx = entries[i];
if (!is_valid_index_type(idx)) {
throw std::invalid_argument(
"Cannot index mlx array using the given type yet");
}
if (!py::ellipsis().is(idx)) {
indices.push_back(idx);
non_none_indices_before += !idx.is_none();
} else {
break;
}
}
for (int j = entries.size() - 1; j > i; j--) {
auto idx = entries[j];
if (!is_valid_index_type(idx)) {
throw std::invalid_argument(
"Cannot index mlx array using the given type yet");
}
if (py::ellipsis().is(idx)) {
throw std::invalid_argument(
"An index can only have a single ellipsis (...)");
}
r_indices.push_back(idx);
non_none_indices_after += !idx.is_none();
}
for (int axis = non_none_indices_before;
axis < src.ndim() - non_none_indices_after;
axis++) {
indices.push_back(py::slice(0, src.shape(axis), 1));
}
indices.insert(indices.end(), r_indices.rbegin(), r_indices.rend());
}
// Check for the number of indices passed
{
int cnt = src.ndim();
for (auto& idx : indices) {
if (!idx.is_none()) {
cnt--;
}
}
if (cnt < 0) {
std::ostringstream msg;
msg << "Too many indices for array with " << src.ndim() << " dimensions.";
throw std::invalid_argument(msg.str());
}
}
// Gather handling
//
// Check whether we have arrays or integer indices and delegate to gather_nd
// after removing the slices at the end and all Nones.
std::vector<py::object> remaining_indices;
bool have_array = false;
{
// First check whether the results of gather are going to be 1st or
// normally in between.
bool have_non_array = false;
bool gather_first = false;
for (auto& idx : indices) {
if (py::isinstance<array>(idx) || py::isinstance<py::int_>(idx)) {
if (have_array && have_non_array) {
gather_first = true;
break;
}
have_array = true;
} else {
have_non_array |= have_array;
}
}
if (have_array) {
int last_array;
// Then find the last array
for (last_array = indices.size() - 1; last_array >= 0; last_array--) {
auto& idx = indices[last_array];
if (py::isinstance<array>(idx) || py::isinstance<py::int_>(idx)) {
break;
}
}
std::vector<py::object> gather_indices;
for (int i = 0; i <= last_array; i++) {
auto& idx = indices[i];
if (!idx.is_none()) {
gather_indices.push_back(idx);
}
}
int max_dims;
src = mlx_gather_nd(src, gather_indices, gather_first, max_dims);
// Reassemble the indices for the slicing or reshaping if there are any
if (gather_first) {
for (int i = 0; i < max_dims; i++) {
remaining_indices.push_back(
py::slice(py::none(), py::none(), py::none()));
}
for (int i = 0; i < last_array; i++) {
auto& idx = indices[i];
if (idx.is_none()) {
remaining_indices.push_back(indices[i]);
} else if (py::isinstance<py::slice>(idx)) {
remaining_indices.push_back(
py::slice(py::none(), py::none(), py::none()));
}
}
for (int i = last_array + 1; i < indices.size(); i++) {
remaining_indices.push_back(indices[i]);
}
} else {
for (int i = 0; i < indices.size(); i++) {
auto& idx = indices[i];
if (py::isinstance<array>(idx) || py::isinstance<py::int_>(idx)) {
break;
} else if (idx.is_none()) {
remaining_indices.push_back(idx);
} else {
remaining_indices.push_back(
py::slice(py::none(), py::none(), py::none()));
}
}
for (int i = 0; i < max_dims; i++) {
remaining_indices.push_back(
py::slice(py::none(), py::none(), py::none()));
}
for (int i = last_array + 1; i < indices.size(); i++) {
remaining_indices.push_back(indices[i]);
}
}
}
}
if (have_array && remaining_indices.empty()) {
return src;
}
if (remaining_indices.empty()) {
remaining_indices = indices;
}
// Slice handling
{
std::vector<int> starts(src.ndim(), 0);
std::vector<int> ends = src.shape();
std::vector<int> strides(src.ndim(), 1);
int axis = 0;
for (auto& idx : remaining_indices) {
if (!idx.is_none()) {
get_slice_params(
starts[axis], ends[axis], strides[axis], idx, ends[axis]);
axis++;
}
}
src = slice(src, starts, ends, strides);
}
// Unsqueeze handling
if (remaining_indices.size() > src.ndim()) {
std::vector<int> out_shape;
int axis = 0;
for (auto& idx : remaining_indices) {
if (idx.is_none()) {
out_shape.push_back(1);
} else {
out_shape.push_back(src.shape(axis++));
}
}
src = reshape(src, out_shape);
}
return src;
}
array mlx_get_item(const array& src, const py::object& obj) {
if (py::isinstance<py::slice>(obj)) {
return mlx_get_item_slice(src, obj);
} else if (py::isinstance<array>(obj)) {
return mlx_get_item_array(src, py::cast<array>(obj));
} else if (py::isinstance<py::int_>(obj)) {
return mlx_get_item_int(src, obj);
} else if (py::isinstance<py::tuple>(obj)) {
return mlx_get_item_nd(src, obj);
} else if (obj.is_none()) {
std::vector<int> s(1, 1);
s.insert(s.end(), src.shape().begin(), src.shape().end());
return reshape(src, s);
}
throw std::invalid_argument("Cannot index mlx array using the given type.");
}
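A minimal usage sketch, assuming ``mlx.core`` built from this commit is importable as ``mx`` (outputs elided), of the Python ``__getitem__`` forms that ``mlx_get_item`` dispatches on, including the gather path for array indices:

```
import mlx.core as mx

x = mx.reshape(mx.arange(24), [2, 3, 4])

x[1]                    # py::int_  -> mlx_get_item_int
x[0:2]                  # py::slice -> mlx_get_item_slice
x[mx.array([0, 1, 1])]  # array     -> mlx_get_item_array (gather along axis 0)
x[..., 0]               # py::tuple -> mlx_get_item_nd (ellipsis expanded to slices)
x[None]                 # None      -> reshape with a leading singleton dimension

# Array indices separated by a slice put the gathered dims first (numpy-style)
i = mx.array([0, 1])
x[i, :, i]  # gather_first path, expected shape (2, 3)
x[:, i, i]  # adjacent array indices stay in place, expected shape (2, 2)
```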
array mlx_set_item_int(
const array& src,
const py::int_& idx,
const array& update) {
if (src.ndim() == 0) {
throw std::invalid_argument(
"too many indices for array: array is 0-dimensional");
}
// Remove any leading singleton dimensions from the update
// and then broadcast update to shape of src[0, ...]
int s = 0;
for (; s < update.ndim() && update.shape(s) == 1; s++)
;
auto up_shape =
std::vector<int>(update.shape().begin() + s, update.shape().end());
auto shape = src.shape();
shape[0] = 1;
return scatter(
src,
get_int_index(idx, src.shape(0)),
broadcast_to(reshape(update, up_shape), shape),
0);
}
array mlx_set_item_array(
const array& src,
const array& indices,
const array& update) {
if (src.ndim() == 0) {
throw std::invalid_argument(
"too many indices for array: array is 0-dimensional");
}
// Remove any leading singleton dimensions from the update
int s = 0;
for (; s < update.ndim() && update.shape(s) == 1; s++)
;
auto up_shape =
std::vector<int>(update.shape().begin() + s, update.shape().end());
auto up = reshape(update, up_shape);
// The update shape must broadcast with indices.shape + [1] + src.shape[1:]
up_shape = indices.shape();
up_shape.insert(up_shape.end(), src.shape().begin() + 1, src.shape().end());
up = broadcast_to(up, up_shape);
up_shape.insert(up_shape.begin() + indices.ndim(), 1);
up = reshape(up, up_shape);
return scatter(src, indices, up, 0);
}
array mlx_set_item_slice(
const array& src,
const py::slice& in_slice,
const array& update) {
// Check input and raise error if 0 dim for parity with np
if (src.ndim() == 0) {
throw std::invalid_argument(
"too many indices for array: array is 0-dimensional");
}
// If a full (None) slice is requested, broadcast the update
// to the shape of src and return it.
if (is_none_slice(in_slice)) {
int s = 0;
for (; s < update.ndim() && update.shape(s) == 1; s++)
;
auto up_shape =
std::vector<int>(update.shape().begin() + s, update.shape().end());
return broadcast_to(reshape(update, up_shape), src.shape());
}
int start = 0;
int end = src.shape(0);
int stride = 1;
// Check and update slice params
get_slice_params(start, end, stride, in_slice, end);
return mlx_set_item_array(src, arange(start, end, stride, uint32), update);
}
array mlx_set_item_nd(
const array& src,
const py::tuple& entries,
const array& update) {
std::vector<py::object> indices;
int non_none_indices = 0;
// Expand ellipses into a series of ':' slices
{
int non_none_indices_before = 0;
int non_none_indices_after = 0;
bool has_ellipsis = false;
int indices_before = 0;
for (int i = 0; i < entries.size(); ++i) {
auto idx = entries[i];
if (!is_valid_index_type(idx)) {
throw std::invalid_argument(
"Cannot index mlx array using the given type yet");
} else if (!py::ellipsis().is(idx)) {
if (!has_ellipsis) {
indices_before++;
non_none_indices_before += !idx.is_none();
} else {
non_none_indices_after += !idx.is_none();
}
indices.push_back(idx);
} else if (has_ellipsis) {
throw std::invalid_argument(
"An index can only have a single ellipsis (...)");
} else {
has_ellipsis = true;
}
}
if (has_ellipsis) {
for (int axis = non_none_indices_before;
axis < src.ndim() - non_none_indices_after;
axis++) {
indices.insert(
indices.begin() + indices_before, py::slice(0, src.shape(axis), 1));
}
non_none_indices = src.ndim();
} else {
non_none_indices = non_none_indices_before + non_none_indices_after;
}
}
if (non_none_indices > src.ndim()) {
std::ostringstream msg;
msg << "Too many indices for array with " << src.ndim() << "dimensions.";
throw std::invalid_argument(msg.str());
}
// Remove leading singleton dimensions from the update
int s = 0;
for (; s < update.ndim() && update.shape(s) == 1; s++)
;
auto up_shape =
std::vector<int>(update.shape().begin() + s, update.shape().end());
auto up = reshape(update, up_shape);
// If no non-None indices return the broadcasted update
if (non_none_indices == 0) {
return broadcast_to(up, src.shape());
}
unsigned long max_dim = 0;
bool arrays_first = false;
int num_slices = 0;
int num_arrays = 0;
{
bool have_array = false;
bool have_non_array = false;
for (auto& idx : indices) {
if (py::isinstance<py::slice>(idx) || idx.is_none()) {
have_non_array = have_array;
num_slices++;
} else if (py::isinstance<array>(idx)) {
have_array = true;
if (have_array && have_non_array) {
arrays_first = true;
}
max_dim = std::max(py::cast<array>(idx).ndim(), max_dim);
num_arrays++;
}
}
}
std::vector<array> arr_indices;
int slice_num = 0;
int array_num = 0;
int ax = 0;
for (int i = 0; i < indices.size(); ++i) {
auto& pyidx = indices[i];
if (py::isinstance<py::slice>(pyidx)) {
int start, end, stride;
get_slice_params(start, end, stride, pyidx, src.shape(ax++));
auto idx = arange(start, end, stride, uint32);
std::vector<int> idx_shape(max_dim + num_slices, 1);
auto loc = slice_num + (arrays_first ? max_dim : 0);
slice_num++;
idx_shape[loc] = idx.size();
arr_indices.push_back(reshape(idx, idx_shape));
} else if (py::isinstance<py::int_>(pyidx)) {
arr_indices.push_back(get_int_index(pyidx, src.shape(ax++)));
} else if (pyidx.is_none()) {
slice_num++;
} else if (py::isinstance<array>(pyidx)) {
ax++;
auto idx = py::cast<array>(pyidx);
std::vector<int> idx_shape;
if (!arrays_first) {
idx_shape.insert(idx_shape.end(), slice_num, 1);
}
idx_shape.insert(idx_shape.end(), max_dim - idx.ndim(), 1);
idx_shape.insert(idx_shape.end(), idx.shape().begin(), idx.shape().end());
idx_shape.insert(
idx_shape.end(), num_slices - (arrays_first ? 0 : slice_num), 1);
arr_indices.push_back(reshape(idx, idx_shape));
if (!arrays_first && ++array_num == num_arrays) {
slice_num += max_dim;
}
} else {
throw std::invalid_argument(
"Cannot index mlx array using the given type yet");
}
}
arr_indices = broadcast_arrays(arr_indices);
up_shape = arr_indices[0].shape();
up_shape.insert(
up_shape.end(),
src.shape().begin() + non_none_indices,
src.shape().end());
up = broadcast_to(up, up_shape);
up_shape.insert(
up_shape.begin() + arr_indices[0].ndim(), non_none_indices, 1);
up = reshape(up, up_shape);
std::vector<int> axes(arr_indices.size(), 0);
std::iota(axes.begin(), axes.end(), 0);
return scatter(src, arr_indices, up, axes);
}
void mlx_set_item(array& src, const py::object& obj, const ScalarOrArray& v) {
auto vals = to_array(v, src.dtype());
auto impl = [&src, &obj, &vals]() {
if (py::isinstance<py::slice>(obj)) {
return mlx_set_item_slice(src, obj, vals);
} else if (py::isinstance<array>(obj)) {
return mlx_set_item_array(src, py::cast<array>(obj), vals);
} else if (py::isinstance<py::int_>(obj)) {
return mlx_set_item_int(src, obj, vals);
} else if (py::isinstance<py::tuple>(obj)) {
return mlx_set_item_nd(src, obj, vals);
} else if (obj.is_none()) {
return broadcast_to(vals, src.shape());
}
throw std::invalid_argument("Cannot index mlx array using the given type.");
};
auto out = impl();
src.overwrite_descriptor(out);
}
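A corresponding sketch of the ``__setitem__`` forms handled by ``mlx_set_item`` (again assuming ``mlx.core`` as ``mx``); the update is broadcast against the indexed sub-array and written back with ``scatter``:

```
import mlx.core as mx

x = mx.zeros((4, 3))

x[0] = 1.0                           # int index: scatter a broadcast row
x[1:3] = mx.ones((2, 3))             # slice: converted to an arange of row indices
x[mx.array([0, 3])] = 2.0            # array index: scalar broadcast to each selected row
x[0, :] = mx.array([1.0, 2.0, 3.0])  # tuple index: mlx_set_item_nd builds the scatter indices
x[None] = 5.0                        # None: the update is simply broadcast to x.shape
```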

12
python/src/indexing.h Normal file

@@ -0,0 +1,12 @@
#pragma once
#include <pybind11/pybind11.h>
#include "mlx/array.h"
#include "python/src/utils.h"
namespace py = pybind11;
using namespace mlx::core;
array mlx_get_item(const array& src, const py::object& obj);
void mlx_set_item(array& src, const py::object& obj, const ScalarOrArray& v);

290
python/src/load.cpp Normal file

@@ -0,0 +1,290 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <cstring>
#include <fstream>
#include <stdexcept>
#include <string>
#include <string_view>
#include <unordered_map>
#include <vector>
#include <iostream>
#include "mlx/load.h"
#include "mlx/ops.h"
#include "mlx/utils.h"
#include "python/src/load.h"
#include "python/src/utils.h"
namespace py = pybind11;
using namespace py::literals;
using namespace mlx::core;
///////////////////////////////////////////////////////////////////////////////
// Helpers
///////////////////////////////////////////////////////////////////////////////
bool is_istream_object(const py::object& file) {
return py::hasattr(file, "read") && py::hasattr(file, "seek") &&
py::hasattr(file, "tell") && py::hasattr(file, "closed");
}
bool is_ostream_object(const py::object& file) {
return py::hasattr(file, "write") && py::hasattr(file, "seek") &&
py::hasattr(file, "tell") && py::hasattr(file, "closed");
}
bool is_zip_file(const py::module_& zipfile, const py::object& file) {
if (is_istream_object(file)) {
auto st_pos = file.attr("tell")();
bool r = (zipfile.attr("is_zipfile")(file)).cast<bool>();
file.attr("seek")(st_pos, 0);
return r;
}
return zipfile.attr("is_zipfile")(file).cast<bool>();
}
class ZipFileWrapper {
public:
ZipFileWrapper(
const py::module_& zipfile,
const py::object& file,
char mode = 'r',
int compression = 0)
: zipfile_module_(zipfile),
zipfile_object_(zipfile.attr("ZipFile")(
file,
"mode"_a = mode,
"compression"_a = compression,
"allowZip64"_a = true)),
files_list_(zipfile_object_.attr("namelist")()),
open_func_(zipfile_object_.attr("open")),
read_func_(zipfile_object_.attr("read")),
close_func_(zipfile_object_.attr("close")) {}
std::vector<std::string> namelist() const {
return files_list_.cast<std::vector<std::string>>();
}
py::object open(const std::string& key, char mode = 'r') {
// Following numpy :
// https://github.com/numpy/numpy/blob/db4f43983cb938f12c311e1f5b7165e270c393b4/numpy/lib/npyio.py#L742C36-L742C47
if (mode == 'w') {
return open_func_(key, "mode"_a = mode, "force_zip64"_a = true);
}
return open_func_(key, "mode"_a = mode);
}
private:
py::module_ zipfile_module_;
py::object zipfile_object_;
py::list files_list_;
py::object open_func_;
py::object read_func_;
py::object close_func_;
};
///////////////////////////////////////////////////////////////////////////////
// Loading
///////////////////////////////////////////////////////////////////////////////
class PyFileReader : public io::Reader {
public:
PyFileReader(py::object file)
: pyistream_(file),
readinto_func_(file.attr("readinto")),
seek_func_(file.attr("seek")),
tell_func_(file.attr("tell")) {}
bool is_open() const override {
return !pyistream_.attr("closed").cast<bool>();
}
bool good() const override {
return !pyistream_.is_none();
}
size_t tell() const override {
return tell_func_().cast<size_t>();
}
void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg)
override {
seek_func_(off, (int)way);
}
void read(char* data, size_t n) override {
py::object bytes_read =
readinto_func_(py::memoryview::from_buffer(data, {n}, {sizeof(char)}));
if (bytes_read.is_none() || py::cast<size_t>(bytes_read) < n) {
throw std::runtime_error("[load] Failed to read from python stream");
}
}
std::string label() const override {
return "python file object";
}
private:
py::object pyistream_;
py::object readinto_func_;
py::object seek_func_;
py::object tell_func_;
};
DictOrArray mlx_load_helper(py::object file, StreamOrDevice s) {
py::module_ zipfile = py::module_::import("zipfile");
// Assume .npz file if it is zipped
if (is_zip_file(zipfile, file)) {
// Output dictionary filename in zip -> loaded array
std::unordered_map<std::string, array> array_dict;
// Create python ZipFile object
ZipFileWrapper zipfile_object(zipfile, file);
for (const std::string& st : zipfile_object.namelist()) {
// Open zip file as a python file stream
py::object sub_file = zipfile_object.open(st);
// Create array from python file stream
auto arr = load(std::make_shared<PyFileReader>(sub_file), s);
// Remove .npy from file if it is there
auto key = st;
if (st.length() > 4 && st.substr(st.length() - 4, 4) == ".npy")
key = st.substr(0, st.length() - 4);
// Add array to dict
array_dict.insert({key, arr});
}
// If we don't own the stream and it was passed to us, eval immediately
for (auto& [key, arr] : array_dict) {
arr.eval();
}
return {array_dict};
} else if (py::isinstance<py::str>(file)) { // Assume .npy file path string
return {load(py::cast<std::string>(file), s)};
} else if (is_istream_object(file)) {
// If we don't own the stream and it was passed to us, eval immediately
auto arr = load(std::make_shared<PyFileReader>(file), s);
arr.eval();
return {arr};
}
throw std::invalid_argument(
"[load] Input must be a file-like object, string, or pathlib.Path");
}
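A hedged sketch of the loading paths above (file names here are placeholders): a path string or readable file object is treated as a single ``.npy`` array, while a zip archive (``.npz``) is unpacked into a dict keyed by member name with the ``.npy`` suffix stripped:

```
import mlx.core as mx

a = mx.load("weights.npy")             # path string -> a single array
with open("weights.npy", "rb") as f:   # readable object -> PyFileReader
    b = mx.load(f)                     # evaluated before the stream is closed
arrs = mx.load("checkpoint.npz")       # zip archive -> {"name": array, ...}
```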
///////////////////////////////////////////////////////////////////////////////
// Saving
///////////////////////////////////////////////////////////////////////////////
class PyFileWriter : public io::Writer {
public:
PyFileWriter(py::object file)
: pyostream_(file),
write_func_(file.attr("write")),
seek_func_(file.attr("seek")),
tell_func_(file.attr("tell")) {}
bool is_open() const override {
return !pyostream_.attr("closed").cast<bool>();
}
bool good() const override {
return !pyostream_.is_none();
}
size_t tell() const override {
return tell_func_().cast<size_t>();
}
void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg)
override {
seek_func_(off, (int)way);
}
void write(const char* data, size_t n) override {
py::object bytes_written =
write_func_(py::memoryview::from_buffer(data, {n}, {sizeof(char)}));
if (bytes_written.is_none() || py::cast<size_t>(bytes_written) < n) {
throw std::runtime_error("[load] Failed to write to python stream");
}
}
std::string label() const override {
return "python file object";
}
private:
py::object pyostream_;
py::object write_func_;
py::object seek_func_;
py::object tell_func_;
};
void mlx_save_helper(py::object file, array a, bool retain_graph) {
if (py::isinstance<py::str>(file)) {
save(py::cast<std::string>(file), a, retain_graph);
return;
} else if (is_ostream_object(file)) {
save(std::make_shared<PyFileWriter>(file), a, retain_graph);
return;
}
throw std::invalid_argument(
"[save] Input must be a file-like object, string, or pathlib.Path");
}
void mlx_savez_helper(
py::object file_,
py::args args,
const py::kwargs& kwargs,
bool compressed) {
// Add .npz to the end of the filename if not already there
py::object file = file_;
if (py::isinstance<py::str>(file_)) {
std::string fname = file_.cast<std::string>();
// Add .npz to file name if it is not there
if (fname.length() < 4 || fname.substr(fname.length() - 4, 4) != ".npz")
fname += ".npz";
file = py::str(fname);
}
// Collect args and kwargs
auto arrays_dict = kwargs.cast<std::unordered_map<std::string, array>>();
auto arrays_list = args.cast<std::vector<array>>();
for (int i = 0; i < arrays_list.size(); i++) {
std::string arr_name = "arr_" + std::to_string(i);
if (arrays_dict.count(arr_name) > 0) {
throw std::invalid_argument(
"[savez] Cannot use un-named variables and keyword " + arr_name);
}
arrays_dict.insert({arr_name, arrays_list[i]});
}
// Create python ZipFile object depending on compression
py::module_ zipfile = py::module_::import("zipfile");
int compression = compressed ? zipfile.attr("ZIP_DEFLATED").cast<int>()
: zipfile.attr("ZIP_STORED").cast<int>();
char mode = 'w';
ZipFileWrapper zipfile_object(zipfile, file, mode, compression);
// Save each array
for (auto [k, a] : arrays_dict) {
std::string fname = k + ".npy";
auto py_ostream = zipfile_object.open(fname, 'w');
save(std::make_shared<PyFileWriter>(py_ostream), a);
}
return;
}
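The saving side mirrors this; a rough sketch assuming these helpers are exposed as ``mx.save``, ``mx.savez``, and ``mx.savez_compressed`` (file names are placeholders):

```
import mlx.core as mx

x = mx.ones((2, 2))
mx.save("x.npy", x)                   # single array -> .npy
mx.savez("ckpt", x, y=2 * x)          # ".npz" appended if missing; positional args become arr_0, arr_1, ...
mx.savez_compressed("ckpt_small", x)  # same, but with ZIP_DEFLATED compression
```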

2422
python/src/ops.cpp Normal file

File diff suppressed because it is too large

289
python/src/random.cpp Normal file

@@ -0,0 +1,289 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "python/src/utils.h"
#include "mlx/ops.h"
#include "mlx/random.h"
namespace py = pybind11;
using namespace py::literals;
using namespace mlx::core;
using namespace mlx::core::random;
void init_random(py::module_& parent_module) {
auto m = parent_module.def_submodule(
"random",
"mlx.core.random: functionality related to random number generation");
m.def(
"seed",
&seed,
"seed"_a,
R"pbdoc(
Seed the global PRNG.
Args:
seed (int): Seed for the global PRNG.
)pbdoc");
m.def(
"key",
&key,
"seed"_a,
R"pbdoc(
Get a PRNG key from a seed.
Args:
seed (int): Seed for the PRNG.
Returns:
array: The PRNG key array.
)pbdoc");
m.def(
"split",
py::overload_cast<const array&, int, StreamOrDevice>(&random::split),
"key"_a,
"num"_a = 2,
"stream"_a = none,
R"pbdoc(
Split a PRNG key into sub keys.
Args:
key (array): Input key to split.
num (int, optional): Number of sub keys. Default is 2.
Returns:
array: The array of sub keys with ``num`` as its first dimension.
)pbdoc");
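A small sketch of explicit key handling with the functions above (assuming ``mlx.core`` as ``mx``): seed the global PRNG, or derive and split explicit keys for independent, reproducible streams, used here with ``mx.random.uniform``:

```
import mlx.core as mx

mx.random.seed(0)                  # seed the global PRNG

key = mx.random.key(0)             # or manage keys explicitly
subkeys = mx.random.split(key, 2)  # stack of sub keys with num as the first dimension
a = mx.random.uniform(key=subkeys[0])
b = mx.random.uniform(key=subkeys[1])
```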
m.def(
"uniform",
[](const ScalarOrArray& low,
const ScalarOrArray& high,
const std::vector<int>& shape,
Dtype type,
const std::optional<array>& key,
StreamOrDevice s) {
return uniform(to_array(low), to_array(high), shape, type, key, s);
},
"low"_a = 0,
"high"_a = 1,
"shape"_a = std::vector<int>{},
"dtype"_a = float32,
"key"_a = none,
"stream"_a = none,
R"pbdoc(
Generate uniformly distributed random numbers.
The values are sampled uniformly in the half-open interval ``[low, high)``.
The lower and upper bound can be scalars or arrays and must be
broadcastable to ``shape``.
Args:
low (scalar or array, optional): Lower bound of the distribution. Default is ``0``.
high (scalar or array, optional): Upper bound of the distribution. Default is ``1``.
shape (list(int), optional): Shape of the output. Default is ``()``.
key (array, optional): A PRNG key. Default: None.
dtype (Dtype, optional): Type of the output. Default is ``float32``.
Returns:
array: The output array of random values.
)pbdoc");
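For instance, a short sketch of ``uniform`` with scalar bounds and with array bounds broadcast against ``shape`` (the per-column ranges are hypothetical):

```
import mlx.core as mx

u = mx.random.uniform(low=0.0, high=1.0, shape=(2, 3))

# Array bounds must broadcast against `shape`, e.g. a different range per column
lows = mx.array([0.0, 10.0, 100.0])
v = mx.random.uniform(low=lows, high=lows + 1.0, shape=(4, 3))
```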
m.def(
"normal",
[](const std::vector<int>& shape,
Dtype type,
const std::optional<array>& key,
StreamOrDevice s) { return normal(shape, type, key, s); },
"shape"_a = std::vector<int>{},
"dtype"_a = float32,
"key"_a = none,
"stream"_a = none,
R"pbdoc(
Generate normally distributed random numbers.
Args:
shape (list(int), optional): Shape of the output. Default is ``()``.
dtype (Dtype, optional): Type of the output. Default is ``float32``.
key (array, optional): A PRNG key. Default: None.
Returns:
array: The output array of random values.
)pbdoc");
m.def(
"randint",
[](const ScalarOrArray& low,
const ScalarOrArray& high,
const std::vector<int>& shape,
Dtype type,
const std::optional<array>& key,
StreamOrDevice s) {
return randint(to_array(low), to_array(high), shape, type, key, s);
},
"low"_a,
"high"_a,
"shape"_a = std::vector<int>{},
"dtype"_a = int32,
"key"_a = none,
"stream"_a = none,
R"pbdoc(
Generate random integers from the given interval.
The values are sampled with equal probability from the integers in the
half-open interval ``[low, high)``. The lower and upper bound can be
scalars or arrays and must be broadcastable to ``shape``.
Args:
low (scalar or array): Lower bound of the interval.
high (scalar or array): Upper bound of the interval.
shape (list(int), optional): Shape of the output. Defaults to ``()``.
dtype (Dtype, optional): Type of the output. Defaults to ``int32``.
key (array, optional): A PRNG key. Default: None.
Returns:
array: The array of random integers.
)pbdoc");
m.def(
"bernoulli",
[](const ScalarOrArray& p_,
const std::optional<std::vector<int>> shape,
const std::optional<array>& key,
StreamOrDevice s) {
auto p = to_array(p_);
if (shape.has_value()) {
return bernoulli(p, shape.value(), key, s);
} else {
return bernoulli(p, key, s);
}
},
"p"_a = 0.5,
"shape"_a = none,
"key"_a = none,
"stream"_a = none,
R"pbdoc(
Generate Bernoulli random values.
The values are sampled from the Bernoulli distribution with parameter
``p``. The parameter ``p`` can be a :obj:`float` or :obj:`array` and
must be broadcastable to ``shape``.
Args:
p (float or array, optional): Parameter of the Bernoulli
distribution. Default is 0.5.
shape (list(int), optional): Shape of the output. The default
shape is ``p.shape``.
key (array, optional): A PRNG key. Default: None.
Returns:
array: The array of random values.
)pbdoc");
m.def(
"truncated_normal",
[](const ScalarOrArray& lower_,
const ScalarOrArray& upper_,
const std::optional<std::vector<int>> shape_,
Dtype dtype,
const std::optional<array>& key,
StreamOrDevice s) {
auto lower = to_array(lower_);
auto upper = to_array(upper_);
if (shape_.has_value()) {
return truncated_normal(lower, upper, shape_.value(), dtype, key, s);
} else {
return truncated_normal(lower, upper, dtype, key, s);
}
},
"lower"_a,
"upper"_a,
"shape"_a = none,
"dtype"_a = float32,
"key"_a = none,
"stream"_a = none,
R"pbdoc(
Generate values from a truncated normal distribution.
The values are sampled from the truncated normal distribution
on the domain ``(lower, upper)``. The bounds ``lower`` and ``upper``
can be scalars or arrays and must be broadcastable to ``shape``.
Args:
lower (scalar or array): Lower bound of the domain.
upper (scalar or array): Upper bound of the domain.
shape (list(int), optional): The shape of the output.
Default is ``()``.
dtype (Dtype, optional): The data type of the output.
Default is ``float32``.
key (array, optional): A PRNG key. Default: None.
Returns:
array: The output array of random values.
)pbdoc");
m.def(
"gumbel",
&gumbel,
"shape"_a = std::vector<int>{},
"dtype"_a = float32,
"stream"_a = none,
"key"_a = none,
R"pbdoc(
Sample from the standard Gumbel distribution.
The values are sampled from a standard Gumbel distribution
whose CDF is ``exp(-exp(-x))``.
Args:
shape (list(int)): The shape of the output.
key (array, optional): A PRNG key. Default: None.
Returns:
array: The :class:`array` with shape ``shape`` and values
distributed according to the Gumbel distribution.
)pbdoc");
m.def(
"categorical",
[](const array& logits,
int axis,
const std::optional<std::vector<int>> shape,
const std::optional<int> num_samples,
const std::optional<array>& key,
StreamOrDevice s) {
if (shape.has_value() && num_samples.has_value()) {
throw std::invalid_argument(
"[categorical] At most one of shape or num_samples can be specified.");
} else if (shape.has_value()) {
return categorical(logits, axis, shape.value(), key, s);
} else if (num_samples.has_value()) {
return categorical(logits, axis, num_samples.value(), key, s);
} else {
return categorical(logits, axis, key, s);
}
},
"logits"_a,
"axis"_a = -1,
"shape"_a = none,
"num_samples"_a = none,
"key"_a = none,
"stream"_a = none,
R"pbdoc(
Sample from a categorical distribution.
The values are sampled from the categorical distribution specified by
the unnormalized values in ``logits``. Note, at most one of ``shape``
or ``num_samples`` can be specified. If both are ``None``, the output
has the same shape as ``logits`` with the ``axis`` dimension removed.
Args:
logits (array): The *unnormalized* categorical distribution(s).
axis (int, optional): The axis which specifies the distribution.
Default is ``-1``.
shape (list(int), optional): The shape of the output. This must
be broadcast compatible with ``logits.shape`` with the ``axis``
dimension removed. Default: ``None``
num_samples (int, optional): The number of samples to draw from each
of the categorical distributions in ``logits``. The output will have
``num_samples`` in the last dimension. Default: ``None``.
key (array, optional): A PRNG key. Default: None.
Returns:
array: The ``shape``-sized output array with type ``uint32``.
)pbdoc");
}
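A sketch of ``categorical`` showing the mutually exclusive ``shape`` / ``num_samples`` options described above (assuming ``mlx.core`` as ``mx``):

```
import mlx.core as mx

logits = mx.zeros((4, 10))                         # 4 uniform distributions over 10 classes
s0 = mx.random.categorical(logits)                 # shape (4,), dtype uint32
s1 = mx.random.categorical(logits, num_samples=3)  # shape (4, 3)
s2 = mx.random.categorical(logits, shape=(8, 4))   # explicit shape, broadcast-compatible with (4,)
# passing both `shape` and `num_samples` raises an invalid_argument error
```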

71
python/src/utils.h Normal file

@@ -0,0 +1,71 @@
#pragma once
#include <numeric>
#include <variant>
#include <pybind11/complex.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "mlx/array.h"
namespace py = pybind11;
using namespace mlx::core;
using IntOrVec = std::variant<std::monostate, int, std::vector<int>>;
using ScalarOrArray =
std::variant<py::bool_, py::int_, py::float_, std::complex<float>, array>;
static constexpr std::monostate none{};
inline std::vector<int> get_reduce_axes(const IntOrVec& v, int dims) {
std::vector<int> axes;
if (std::holds_alternative<std::monostate>(v)) {
axes.resize(dims);
std::iota(axes.begin(), axes.end(), 0);
} else if (auto pv = std::get_if<int>(&v); pv) {
axes.push_back(*pv);
} else {
axes = std::get<std::vector<int>>(v);
}
return axes;
}
inline array to_array(
const ScalarOrArray& v,
std::optional<Dtype> dtype = std::nullopt) {
if (auto pv = std::get_if<py::bool_>(&v); pv) {
return array(py::cast<bool>(*pv), dtype.value_or(bool_));
} else if (auto pv = std::get_if<py::int_>(&v); pv) {
auto out_t = dtype.value_or(int32);
// bool_ is an exception and is always promoted
return array(py::cast<int>(*pv), (out_t == bool_) ? int32 : out_t);
} else if (auto pv = std::get_if<py::float_>(&v); pv) {
auto out_t = dtype.value_or(float32);
return array(
py::cast<float>(*pv), is_floating_point(out_t) ? out_t : float32);
} else if (auto pv = std::get_if<std::complex<float>>(&v); pv) {
return array(static_cast<complex64_t>(*pv), complex64);
} else {
return std::get<array>(v);
}
}
inline std::pair<array, array> to_arrays(
const ScalarOrArray& a,
const ScalarOrArray& b) {
// Four cases:
// - If both a and b are arrays leave their types alone
// - If a is an array but b is not, treat b as a weak python type
// - If b is an array but a is not, treat a as a weak python type
// - If neither is an array convert to arrays but leave their types alone
if (auto pa = std::get_if<array>(&a); pa) {
if (auto pb = std::get_if<array>(&b); pb) {
return {*pa, *pb};
}
return {*pa, to_array(b, pa->dtype())};
} else if (auto pb = std::get_if<array>(&b); pb) {
return {to_array(a, pb->dtype()), *pb};
} else {
return {to_array(a), to_array(b)};
}
}
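A small sketch of how these promotion rules surface in Python, assuming the arithmetic bindings route scalar operands through ``to_arrays`` and that ``mx.ones`` accepts a ``dtype`` argument:

```
import mlx.core as mx

a = mx.ones((2, 2), dtype=mx.float16)

b = a + 2.0            # python float is a weak type: the result keeps a's float16 dtype
c = a + mx.array(2.0)  # both operands are arrays: dtypes are left alone, normal promotion applies
d = mx.array([True, False]) + 1  # bool_ is the exception: the int scalar is not downcast to bool
```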

1041
python/tests/test_array.py Normal file

File diff suppressed because it is too large

263
python/tests/test_autograd.py Normal file

@@ -0,0 +1,263 @@
import unittest
import mlx.core as mx
import mlx_tests
class TestAutograd(mlx_tests.MLXTestCase):
def test_jvp(self):
fun = lambda x: 2 * x
out, dout = mx.jvp(fun, [mx.array(1.0)], [mx.array(2.0)])
self.assertEqual(out[0].item(), 2.0)
self.assertEqual(dout[0].item(), 4.0)
fun = lambda x, y: x * y
_, out = mx.jvp(
fun, [mx.array(4.0), mx.array(2.0)], [mx.array(3.0), mx.array(2.0)]
)
self.assertEqual(out[0].item(), 4.0 * 2.0 + 2.0 * 3.0)
fun = lambda x, y, z: (x * y, y * z)
_, out = mx.jvp(
fun,
[mx.array(2.0), mx.array(4.0), mx.array(6.0)],
[mx.array(1.0), mx.array(3.0), mx.array(1.0)],
)
self.assertEqual(len(out), 2)
self.assertEqual(out[0].item(), 4.0 * 1.0 + 2.0 * 3.0)
self.assertEqual(out[1].item(), 4.0 * 1.0 + 6.0 * 3.0)
def test_vjp(self):
fun = lambda x: 2 * x
out, dout = mx.vjp(fun, [mx.array(1.0)], [mx.array(2.0)])
self.assertEqual(out[0].item(), 2.0)
self.assertEqual(dout[0].item(), 4.0)
fun = lambda x, y: x * y
_, dout = mx.vjp(fun, [mx.array(4.0), mx.array(2.0)], [mx.array(3.0)])
self.assertEqual(dout[0].item(), 6.0)
self.assertEqual(dout[1].item(), 12.0)
fun = lambda x, y, z: (x * y, y * z)
_, out = mx.vjp(
fun,
[mx.array(2.0), mx.array(4.0), mx.array(6.0)],
[mx.array(1.0), mx.array(3.0)],
)
self.assertEqual(len(out), 3)
self.assertEqual(out[0].item(), 4.0 * 1.0)
self.assertEqual(out[1].item(), 2.0 * 1.0 + 6.0 * 3.0)
self.assertEqual(out[2].item(), 4.0 * 3.0)
def test_grad(self):
fun = lambda x: x * x
value, dfdx = mx.value_and_grad(fun)(mx.array(0.5))
self.assertEqual(value.item(), 0.25)
self.assertEqual(dfdx.item(), 1.0)
dfdx = mx.grad(fun)(mx.array(0.5))
self.assertEqual(dfdx.item(), 1.0)
df2dx2 = mx.grad(mx.grad(fun))(mx.array(0.5))
self.assertEqual(df2dx2.item(), 2.0)
df3dx3 = mx.grad(mx.grad(mx.grad(fun)))(mx.array(0.5))
self.assertEqual(df3dx3.item(), 0.0)
fun = lambda x, y: x * y
x = mx.array(2.0)
y = mx.array(3.0)
dfdx = mx.grad(fun, argnums=0)(x, y)
self.assertEqual(dfdx.item(), 3.0)
dfdx = mx.grad(fun, argnums=1)(x, y)
self.assertEqual(dfdx.item(), 2.0)
# Passing non-array args to functions works
fun = lambda x, y: x
value, dfdx = mx.value_and_grad(fun)(mx.array(2.0), "hello")
self.assertEqual(value.item(), 2.0)
self.assertEqual(dfdx.item(), 1.0)
dfdx = mx.grad(fun)(mx.array(2.0), "hello")
self.assertEqual(dfdx.item(), 1.0)
# Raises when the function does not return an array
fun = lambda x: "hello"
with self.assertRaises(ValueError):
mx.grad(fun)(mx.array(2.0))
# Raises for invalid argument number or argument type
fun = lambda x: x
with self.assertRaises(ValueError):
mx.grad(fun, argnums=2)(mx.array(2.0))
with self.assertRaises(ValueError):
mx.grad(fun, argnums=-2)(mx.array(2.0))
with self.assertRaises(ValueError):
mx.grad(fun)("hello")
# Raises when output is not a scalar array
fun = lambda x: mx.sum(x, keepdims=True)
with self.assertRaises(ValueError):
mx.grad(fun)(mx.ones((2, 2)))
def test_grad_trees(self):
fun = lambda x, y: x * y
value, dfdx = mx.value_and_grad(fun, (0, 1))(mx.array(0.5), mx.array(2.0))
self.assertEqual(value.item(), 1.0)
self.assertTrue(isinstance(dfdx, tuple))
self.assertEqual(dfdx[0].item(), 2.0)
self.assertEqual(dfdx[1].item(), 0.5)
fun = lambda x, y: x * y
value, dfdx = mx.value_and_grad(fun, 1)(mx.array(0.5), mx.array(2.0))
self.assertEqual(value.item(), 1.0)
self.assertEqual(dfdx.item(), 0.5)
fun = lambda p: p["x"] * p["y"]
value, dfdx = mx.value_and_grad(fun)({"x": mx.array(0.5), "y": mx.array(2.0)})
self.assertEqual(value.item(), 1.0)
self.assertEqual(dfdx["x"].item(), 2.0)
self.assertEqual(dfdx["y"].item(), 0.5)
fun = lambda p: p["x"] * p["y"]
with self.assertRaises(ValueError):
mx.value_and_grad(fun)({"x": 0.5, "y": mx.array(2.0)})
with self.assertRaises(ValueError):
mx.value_and_grad(fun, (0, 1))({"x": mx.array(0.5), "y": mx.array(2.0)})
fun = lambda p, b: mx.square(p[0]["foo"][2]) * b
value, dfdx = mx.value_and_grad(fun)(
[{"foo": [[], [], mx.array(2.0)]}], mx.array(0.5)
)
self.assertEqual(value.item(), 2.0)
self.assertEqual(dfdx[0]["foo"][2].item(), 2.0)
fun = lambda x: x
with self.assertRaises(TypeError):
mx.value_and_grad(fun, (None, None))
with self.assertRaises(ValueError):
mx.value_and_grad(fun, tuple())
def test_auxiliary_values(self):
def fun(x, y):
l = (x * y).sum()
extra = {"loss": l, "foo": y.square() + x.square(), "bar": [1, 2, 3, y, x]}
return l, extra
fun_value_grad = mx.value_and_grad(fun)
fun_grad = mx.grad(fun)
(loss, a), b = fun_value_grad(mx.ones((2, 2)), mx.ones((2, 2)))
self.assertEqual(a["loss"].item(), 4)
self.assertTrue(mx.array_equal(b, mx.ones((2, 2))))
self.assertTrue(mx.array_equal(a["foo"], 2 * mx.ones((2, 2))))
self.assertEqual(a["bar"][:3], [1, 2, 3])
self.assertTrue(mx.array_equal(a["bar"][3], mx.ones((2, 2))))
self.assertTrue(mx.array_equal(a["bar"][4], mx.ones((2, 2))))
with self.assertRaises(ValueError):
_ = fun_grad(mx.ones((2, 2)), mx.ones((2, 2)))
def test_grad_kwargs(self):
fun = lambda x, y: x * y
a, b = mx.array(0.5), mx.array(2.0)
dfdx = mx.grad(fun)
self.assertEqual(dfdx(a, b).item(), 2.0)
self.assertEqual(dfdx(a, y=b).item(), 2.0)
with self.assertRaises(ValueError):
dfdx(x=a, y=b).item()
dfdy = mx.grad(fun, argnums=[], argnames=["y"])
with self.assertRaises(ValueError):
dfdy(a, b)
grads = dfdy(a, y=b)
self.assertTrue(isinstance(grads, tuple))
self.assertTrue(grads[0] is None)
self.assertTrue(isinstance(grads[1], dict))
self.assertEqual(grads[1]["y"].item(), 0.5)
grads = dfdy(x=a, y=b)
self.assertEqual(grads[1]["y"].item(), 0.5)
self.assertEqual(len(grads[1]), 1)
dfdxy = mx.grad(fun, argnums=[0], argnames=["y"])
with self.assertRaises(ValueError):
dfdxy(a, b)
with self.assertRaises(ValueError):
dfdxy(x=a, y=b)
grads = dfdxy(a, y=b)
self.assertTrue(isinstance(grads, tuple))
self.assertEqual(grads[0].item(), 2.0)
self.assertTrue(isinstance(grads[1], dict))
self.assertEqual(grads[1]["y"].item(), 0.5)
fun = lambda x, y, z: x * y * z
dfdxyz = mx.grad(fun, argnums=[0, 1], argnames=["z"])
c = mx.array(4.0)
grads = dfdxyz(a, b, z=c)
self.assertTrue(isinstance(grads, tuple))
self.assertTrue(isinstance(grads[0], tuple))
self.assertEqual(grads[0][0].item(), 8.0)
self.assertEqual(grads[0][1].item(), 2.0)
self.assertTrue(isinstance(grads[1], dict))
self.assertEqual(grads[1]["z"].item(), 1.0)
fun = lambda x, y: x * y
dfdy = mx.grad(fun, argnames=["y"])
grads = dfdy(a, y=b)
self.assertTrue(isinstance(grads, tuple))
self.assertTrue(grads[0] is None)
self.assertTrue(isinstance(grads[1], dict))
self.assertEqual(grads[1]["y"].item(), 0.5)
def test_captured(self):
a = mx.array(5.0)
f = lambda x: a + x
g = lambda x: a + a
h = lambda x: x + x
dfdx = mx.grad(f)
self.assertEqual(dfdx(a).item(), 1.0)
dgdx = mx.grad(g)
self.assertEqual(dgdx(a).item(), 0.0)
dhdx = mx.grad(h)
self.assertEqual(dhdx(a).item(), 2.0)
d2fdx2 = mx.grad(dfdx)
self.assertEqual(d2fdx2(a).item(), 0.0)
d2gdx2 = mx.grad(dgdx)
self.assertEqual(d2gdx2(a).item(), 0.0)
d2hdx2 = mx.grad(dhdx)
self.assertEqual(d2hdx2(a).item(), 0.0)
def test_stop_gradient(self):
shape_in = (4, 4)
w_in = mx.ones(shape_in)
x_in = mx.ones(shape_in)
cotan = mx.ones(shape_in)
def h(w, x):
x1 = 2 * x
y = mx.stop_gradient(x1)
y1 = 3 * y
return w @ y1
vals, vjps = mx.vjp(h, [w_in, x_in], [cotan])
mx.eval(vjps)
self.assertTrue(mx.allclose(vjps[0], 24.0 * mx.ones(shape_in)))
self.assertTrue(mx.allclose(vjps[1], mx.zeros(shape_in)))
g = lambda x: h(w_in, x)
vals, vjps = mx.vjp(g, [x_in], [cotan])
mx.eval(vjps)
self.assertTrue(mx.allclose(vjps[0], mx.zeros(shape_in)))
if __name__ == "__main__":
unittest.main()

105
python/tests/test_device.py Normal file

@@ -0,0 +1,105 @@
import unittest
import mlx.core as mx
import mlx_tests
# Don't inherit from MLXTestCase to avoid call to setUp
class TestDefaultDevice(unittest.TestCase):
def test_mlx_default_device(self):
device = mx.default_device()
if mx.metal.is_available():
self.assertEqual(device, mx.Device(mx.gpu))
self.assertEqual(str(device), "Device(gpu, 0)")
self.assertEqual(device, mx.gpu)
self.assertEqual(mx.gpu, device)
else:
self.assertEqual(device.type, mx.Device(mx.cpu))
with self.assertRaises(ValueError):
mx.set_default_device(mx.gpu)
class TestDevice(mlx_tests.MLXTestCase):
def test_device(self):
device = mx.default_device()
cpu = mx.Device(mx.cpu)
mx.set_default_device(cpu)
self.assertEqual(mx.default_device(), cpu)
self.assertEqual(str(cpu), "Device(cpu, 0)")
mx.set_default_device(mx.cpu)
self.assertEqual(mx.default_device(), mx.cpu)
self.assertEqual(cpu, mx.cpu)
self.assertEqual(mx.cpu, cpu)
# Restore device
mx.set_default_device(device)
def test_op_on_device(self):
x = mx.array(1.0)
y = mx.array(1.0)
a = mx.add(x, y, stream=None)
b = mx.add(x, y, stream=mx.default_device())
self.assertEqual(a.item(), b.item())
b = mx.add(x, y, stream=mx.cpu)
self.assertEqual(a.item(), b.item())
if mx.metal.is_available():
b = mx.add(x, y, stream=mx.gpu)
self.assertEqual(a.item(), b.item())
class TestStream(mlx_tests.MLXTestCase):
def test_stream(self):
s1 = mx.default_stream(mx.default_device())
self.assertEqual(s1.device, mx.default_device())
s2 = mx.new_stream(mx.default_device())
self.assertEqual(s2.device, mx.default_device())
self.assertNotEqual(s1, s2)
if mx.metal.is_available():
s_gpu = mx.default_stream(mx.gpu)
self.assertEqual(s_gpu.device, mx.gpu)
else:
with self.assertRaises(ValueError):
mx.default_stream(mx.gpu)
s_cpu = mx.default_stream(mx.cpu)
self.assertEqual(s_cpu.device, mx.cpu)
s_cpu = mx.new_stream(mx.cpu)
self.assertEqual(s_cpu.device, mx.cpu)
if mx.metal.is_available():
s_gpu = mx.new_stream(mx.gpu)
self.assertEqual(s_gpu.device, mx.gpu)
else:
with self.assertRaises(ValueError):
mx.new_stream(mx.gpu)
def test_op_on_stream(self):
x = mx.array(1.0)
y = mx.array(1.0)
a = mx.add(x, y, stream=mx.default_stream(mx.default_device()))
if mx.metal.is_available():
b = mx.add(x, y, stream=mx.default_stream(mx.gpu))
self.assertEqual(a.item(), b.item())
s_gpu = mx.new_stream(mx.gpu)
b = mx.add(x, y, stream=s_gpu)
self.assertEqual(a.item(), b.item())
b = mx.add(x, y, stream=mx.default_stream(mx.cpu))
self.assertEqual(a.item(), b.item())
s_cpu = mx.new_stream(mx.cpu)
b = mx.add(x, y, stream=s_cpu)
self.assertEqual(a.item(), b.item())
if __name__ == "__main__":
unittest.main()

34
python/tests/test_eval.py Normal file

@@ -0,0 +1,34 @@
from functools import partial
import unittest
import mlx.core as mx
import mlx_tests
class TestEval(mlx_tests.MLXTestCase):
def test_eval(self):
arrs = [mx.ones((2, 2)) for _ in range(4)]
mx.eval(*arrs)
for x in arrs:
self.assertEqual(x.tolist(), [[1, 1], [1, 1]])
def test_retain_graph(self):
def fun(x, retain_graph):
y = 3 * x
mx.eval(y, retain_graph=retain_graph)
return 2 * y
dfun_dx_1 = mx.grad(partial(fun, retain_graph=False))
dfun_dx_2 = mx.grad(partial(fun, retain_graph=True))
with self.assertRaises(ValueError):
dfun_dx_1(mx.array(1.0))
y = dfun_dx_2(mx.array(1.0))
self.assertEqual(y.item(), 6.0)
if __name__ == "__main__":
unittest.main()

90
python/tests/test_fft.py Normal file

@@ -0,0 +1,90 @@
import unittest
import itertools
import mlx.core as mx
import numpy as np
import mlx_tests
class TestFFT(mlx_tests.MLXTestCase):
def check_mx_np(self, op, a_np, axes, s):
with self.subTest(op=op, axes=axes, s=s):
op_np = getattr(np.fft, op)
op_mx = getattr(mx.fft, op)
out_np = op_np(a_np, s=s, axes=axes)
a_mx = mx.array(a_np)
out_mx = op_mx(a_mx, s=s, axes=axes)
self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6))
def test_fft(self):
default = mx.default_device()
mx.set_default_device(mx.cpu)
def check_mx_np(op_mx, op_np, a_np, **kwargs):
out_np = op_np(a_np, **kwargs)
a_mx = mx.array(a_np)
out_mx = op_mx(a_mx, **kwargs)
self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6))
r = np.random.rand(100).astype(np.float32)
i = np.random.rand(100).astype(np.float32)
a_np = r + 1j * i
check_mx_np(mx.fft.fft, np.fft.fft, a_np)
# Check with slicing and padding
r = np.random.rand(100).astype(np.float32)
i = np.random.rand(100).astype(np.float32)
a_np = r + 1j * i
check_mx_np(mx.fft.fft, np.fft.fft, a_np, n=80)
check_mx_np(mx.fft.fft, np.fft.fft, a_np, n=120)
# Check different axes
r = np.random.rand(100, 100).astype(np.float32)
i = np.random.rand(100, 100).astype(np.float32)
a_np = r + 1j * i
check_mx_np(mx.fft.fft, np.fft.fft, a_np, axis=0)
check_mx_np(mx.fft.fft, np.fft.fft, a_np, axis=1)
# Check real fft
a_np = np.random.rand(100).astype(np.float32)
check_mx_np(mx.fft.rfft, np.fft.rfft, a_np)
check_mx_np(mx.fft.rfft, np.fft.rfft, a_np, n=80)
check_mx_np(mx.fft.rfft, np.fft.rfft, a_np, n=120)
# Check real inverse
r = np.random.rand(100, 100).astype(np.float32)
i = np.random.rand(100, 100).astype(np.float32)
a_np = r + 1j * i
check_mx_np(mx.fft.ifft, np.fft.ifft, a_np)
check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, n=80)
check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, n=120)
check_mx_np(mx.fft.irfft, np.fft.irfft, a_np)
check_mx_np(mx.fft.irfft, np.fft.irfft, a_np, n=80)
check_mx_np(mx.fft.irfft, np.fft.irfft, a_np, n=120)
mx.set_default_device(default)
def test_fftn(self):
default = mx.default_device()
mx.set_default_device(mx.cpu)
r = np.random.randn(8, 8, 8).astype(np.float32)
i = np.random.randn(8, 8, 8).astype(np.float32)
a = r + 1j * i
axes = [None, (1, 2), (2, 1), (0, 2)]
shapes = [None, (10, 5), (5, 10)]
ops = ["fft2", "ifft2", "rfft2", "irfft2", "fftn", "ifftn", "rfftn", "irfftn"]
for op, ax, s in itertools.product(ops, axes, shapes):
x = a
if op in ["rfft2", "rfftn"]:
x = r
self.check_mx_np(op, x, axes=ax, s=s)
mx.set_default_device(default)
if __name__ == "__main__":
unittest.main()

1283
python/tests/test_ops.py Normal file

File diff suppressed because it is too large

118
python/tests/test_reduce.py Normal file

@@ -0,0 +1,118 @@
import unittest
from itertools import permutations, combinations
import mlx.core as mx
import numpy as np
import mlx_tests
class TestReduce(mlx_tests.MLXTestCase):
def test_axis_permutation_sums(self):
x_npy = np.random.randn(5, 5, 5, 5, 5).astype(np.float32)
x_mlx = mx.array(x_npy)
for t in permutations(range(5)):
with self.subTest(t=t):
y_npy = np.transpose(x_npy, t)
y_mlx = mx.transpose(x_mlx, t)
for n in range(1, 6):
for a in combinations(range(5), n):
with self.subTest(a=a):
z_npy = np.sum(y_npy, axis=a)
z_mlx = mx.sum(y_mlx, axis=a)
mx.eval(z_mlx)
self.assertTrue(
np.allclose(z_npy, np.array(z_mlx), atol=1e-4)
)
def test_expand_sums(self):
x_npy = np.random.randn(5, 1, 5, 1, 5, 1).astype(np.float32)
x_mlx = mx.array(x_npy)
for m in range(1, 4):
for ax in combinations([1, 3, 5], m):
shape = np.array([5, 1, 5, 1, 5, 1])
shape[list(ax)] = 5
shape = shape.tolist()
with self.subTest(shape=shape):
y_npy = np.broadcast_to(x_npy, shape)
y_mlx = mx.broadcast_to(x_mlx, shape)
for n in range(1, 7):
for a in combinations(range(6), n):
with self.subTest(a=a):
z_npy = np.sum(y_npy, axis=a) / 1000
z_mlx = mx.sum(y_mlx, axis=a) / 1000
mx.eval(z_mlx)
self.assertTrue(
np.allclose(z_npy, np.array(z_mlx), atol=1e-4)
)
def test_dtypes(self):
int_dtypes = [
"int8",
"int16",
"int32",
"uint8",
"uint16",
"uint32",
]
float_dtypes = ["float32"]
for dtype in int_dtypes + float_dtypes:
with self.subTest(dtype=dtype):
x = np.random.uniform(0, 2, size=(3, 3, 3)).astype(getattr(np, dtype))
y = mx.array(x)
for op in ("sum", "prod", "min", "max"):
with self.subTest(op=op):
np_op = getattr(np, op)
mlx_op = getattr(mx, op)
for axes in (None, 0, 1, 2, (0, 1), (0, 2), (1, 2), (0, 1, 2)):
with self.subTest(axes=axes):
if op in ("sum", "prod"):
r_np = np_op(
x, axis=axes, dtype=(getattr(np, dtype))
)
else:
r_np = np_op(x, axis=axes)
r_mlx = mlx_op(y, axis=axes)
mx.eval(r_mlx)
self.assertTrue(np.allclose(r_np, r_mlx, atol=1e-4))
def test_arg_reduce(self):
dtypes = [
"uint8",
"uint16",
"uint32",
"uint64",
"int8",
"int16",
"int32",
"int64",
"float16",
"float32",
]
for dtype in dtypes:
with self.subTest(dtype=dtype):
data = np.random.rand(10, 12, 13).astype(getattr(np, dtype))
x = mx.array(data)
for op in ["argmin", "argmax"]:
for axis in range(3):
for kd in [True, False]:
a = getattr(mx, op)(x, axis, kd)
b = getattr(np, op)(data, axis, keepdims=kd)
self.assertEqual(a.tolist(), b.tolist())
for op in ["argmin", "argmax"]:
a = getattr(mx, op)(x, keepdims=True)
b = getattr(np, op)(data, keepdims=True)
self.assertEqual(a.tolist(), b.tolist())
a = getattr(mx, op)(x)
b = getattr(np, op)(data)
self.assertEqual(a.item(), b)
if __name__ == "__main__":
unittest.main(failfast=True)