mirror of https://github.com/ml-explore/mlx.git
synced 2025-10-30 15:28:10 +08:00
awni's commit files
python/README.md (new file, 37 lines)
@@ -0,0 +1,37 @@

### Packaging for PyPI

Install `build` and `twine`:

```
pip install --user --upgrade build
pip install --user --upgrade twine
```

Generate the source distribution and wheel:

```
python -m build
```
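
Optionally check the built artifacts before any upload (an extra step beyond the original instructions; `twine check` validates the distribution metadata):

```
python -m twine check dist/*
```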

*Warning*: use a test server first.

#### Test Upload

Upload to the test server:

```
python -m twine upload --repository testpypi dist/*
```

Install from the test server and check that it works:

```
python -m pip install --index-url https://test.pypi.org/simple/ --no-deps mlx
```

#### Upload

```
python -m twine upload dist/*
```

python/mlx/_reprlib_fix.py (new file, 18 lines)
@@ -0,0 +1,18 @@
import array
import reprlib


class FixedRepr(reprlib.Repr):
    """Only route python array instances to repr_array."""

    def repr_array(self, x, maxlevel):
        if isinstance(x, array.array):
            return super().repr_array(x, maxlevel)
        else:
            return self.repr_instance(x, maxlevel)


# We need to monkey-patch reprlib so that we can use the debugger without
# renaming the array to something else
fixed_repr = FixedRepr()
reprlib.repr = fixed_repr.repr

python/mlx/extension.py (new file, 94 lines)
@@ -0,0 +1,94 @@
import os
import re
import subprocess
import sys
from pathlib import Path

from setuptools import Extension, setup, find_namespace_packages
from setuptools.command.build_ext import build_ext

import mlx

_MLX_PATH = str(mlx.__path__[0])


# A CMakeExtension needs a sourcedir instead of a file list.
class CMakeExtension(Extension):
    def __init__(self, name: str, sourcedir: str = "") -> None:
        super().__init__(name, sources=[])
        self.sourcedir = os.fspath(Path(sourcedir).resolve())


class CMakeBuild(build_ext):
    def build_extension(self, ext: CMakeExtension) -> None:
        # Must be in this form due to bug in .resolve() only fixed in Python 3.10+
        ext_fullpath = Path.cwd() / self.get_ext_fullpath(ext.name)  # type: ignore[no-untyped-call]
        extdir = ext_fullpath.parent.resolve()

        debug = int(os.environ.get("DEBUG", 0)) if self.debug is None else self.debug
        cfg = "Debug" if debug else "Release"

        # CMake lets you override the generator - we need to check this.
        # Can be set with Conda-Build, for example.
        cmake_generator = os.environ.get("CMAKE_GENERATOR", "")

        # Set Python_EXECUTABLE instead if you use PYBIND11_FINDPYTHON
        # EXAMPLE_VERSION_INFO shows you how to pass a value into the C++ code
        # from Python.
        cmake_args = [
            f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={extdir}{os.sep}",
            f"-DCMAKE_BUILD_TYPE={cfg}",
            "-DBUILD_SHARED_LIBS=ON",
        ]
        build_args = []
        # Adding CMake arguments set as environment variable
        # (needed e.g. to build for ARM OSx on conda-forge)
        if "CMAKE_ARGS" in os.environ:
            cmake_args += [item for item in os.environ["CMAKE_ARGS"].split(" ") if item]

        if sys.platform.startswith("darwin"):
            # Cross-compile support for macOS - respect ARCHFLAGS if set
            archs = re.findall(r"-arch (\S+)", os.environ.get("ARCHFLAGS", ""))
            if archs:
                cmake_args += ["-DCMAKE_OSX_ARCHITECTURES={}".format(";".join(archs))]

        # Set CMAKE_BUILD_PARALLEL_LEVEL to control the parallel build level
        # across all generators.
        if "CMAKE_BUILD_PARALLEL_LEVEL" not in os.environ:
            # self.parallel is a Python 3 only way to set parallel jobs by hand
            # using -j in the build_ext call, not supported by pip or PyPA-build.
            if hasattr(self, "parallel") and self.parallel:
                # CMake 3.12+ only.
                build_args += [f"-j{self.parallel}"]

        build_temp = Path(self.build_temp) / ext.name
        if not build_temp.exists():
            build_temp.mkdir(parents=True)

        # Make sure cmake can find MLX
        os.environ["MLX_DIR"] = _MLX_PATH

        subprocess.run(
            ["cmake", ext.sourcedir, *cmake_args], cwd=build_temp, check=True
        )
        subprocess.run(
            ["cmake", "--build", ".", *build_args], cwd=build_temp, check=True
        )

    def run(self):
        super().run()

        # Based on https://github.com/pypa/setuptools/blob/main/setuptools/command/build_ext.py#L102
        if self.inplace:
            for ext in self.extensions:
                if isinstance(ext, CMakeExtension):
                    # Resolve inplace package dir
                    build_py = self.get_finalized_command("build_py")
                    inplace_file, regular_file = self._get_inplace_equivalent(
                        build_py, ext
                    )

                    inplace_dir = str(Path(inplace_file).parent.resolve())
                    regular_dir = str(Path(regular_file).parent.resolve())

                    self.copy_tree(regular_dir, inplace_dir)
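
A hypothetical downstream `setup.py` might use these helpers roughly as follows (a sketch only; the package name and source directory are made up for illustration, not part of this commit):

```
# Sketch: "mlx_sample_extensions" and its paths are hypothetical.
from setuptools import setup
from mlx.extension import CMakeExtension, CMakeBuild

setup(
    name="mlx_sample_extensions",
    version="0.0.1",
    # Point the extension at the directory containing the CMakeLists.txt.
    ext_modules=[CMakeExtension("mlx_sample_extensions.core", sourcedir=".")],
    cmdclass={"build_ext": CMakeBuild},
    packages=["mlx_sample_extensions"],
    zip_safe=False,
)
```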

python/mlx/nn/layers/base.py (new file, 401 lines)
@@ -0,0 +1,401 @@
import textwrap
from typing import Any, Callable, List, Union, Optional

import mlx.core as mx
from mlx.utils import tree_flatten, tree_unflatten


class Module(dict):
    """Base class for building neural networks with MLX.

    All the layers provided in :mod:`mlx.nn.layers` subclass this class and
    your models should do the same.

    A ``Module`` can contain other ``Module`` instances or :class:`mlx.core.array`
    instances in arbitrary nesting of python lists or dicts. The ``Module``
    then allows recursively extracting all the :class:`mlx.core.array` instances
    using :meth:`mlx.nn.Module.parameters`.

    In addition, the ``Module`` has the concept of trainable and non-trainable
    parameters (called "frozen"). When using :func:`mlx.nn.value_and_grad`
    the gradients are returned only with respect to the trainable parameters.
    All arrays in a module are trainable unless they are added in the "frozen"
    set by calling :meth:`freeze`.

    .. code-block:: python

        import mlx.core as mx
        import mlx.nn as nn

        class MyMLP(nn.Module):
            def __init__(self, in_dims: int, out_dims: int, hidden_dims: int = 16):
                super().__init__()

                self.in_proj = nn.Linear(in_dims, hidden_dims)
                self.out_proj = nn.Linear(hidden_dims, out_dims)

            def __call__(self, x):
                x = self.in_proj(x)
                x = mx.maximum(x, 0)
                return self.out_proj(x)

        model = MyMLP(2, 1)

        # All the model parameters are created but since MLX is lazy by
        # default, they are not evaluated yet. Calling `mx.eval` actually
        # allocates memory and initializes the parameters.
        mx.eval(model.parameters())

        # Setting a parameter to a new value is as simple as accessing that
        # parameter and assigning a new array to it.
        model.in_proj.weight = model.in_proj.weight * 2
        mx.eval(model.parameters())
    """

    def __init__(self):
        """Should be called by the subclasses of ``Module``."""
        self._no_grad = set()
        self._training = True

    @property
    def training(self):
        return self._training

    def _extra_repr(self):
        return ""

    def __repr__(self):
        children = tree_flatten(self.children(), is_leaf=self.is_module)
        value = f"{type(self).__name__}({self._extra_repr()}"
        for k, v in children:
            value += "\n"
            value += textwrap.indent(f"({k}): {repr(v)}", prefix="  ")
        if children:
            value += "\n"
        value += ")"

        return value

    def __getattr__(self, key: str):
        if key in self:
            return self[key]
        else:
            raise AttributeError(f"{type(self)!r} has no attribute {key!r}")

    def __setattr__(self, key: str, val: Any):
        self[key] = val

    def load_weights(self, file: str):
        """Load and update the model's weights from a ``.npz`` file."""
        self.update(tree_unflatten(list(mx.load(file).items())))

    def save_weights(self, file: str):
        """Save the model's weights to a ``.npz`` file."""
        mx.savez(file, **dict(tree_flatten(self.parameters())))

    @staticmethod
    def is_module(value):
        return isinstance(value, Module)

    @staticmethod
    def valid_child_filter(module, key, value):
        return isinstance(value, (dict, list))

    @staticmethod
    def valid_parameter_filter(module, key, value):
        return isinstance(value, (dict, list, mx.array)) and not key.startswith("_")

    @staticmethod
    def trainable_parameter_filter(module, key, value):
        return (
            Module.valid_parameter_filter(module, key, value)
            and key not in module._no_grad
        )

    def filter_and_map(
        self,
        filter_fn: Callable[["mlx.nn.Module", str, Any], bool],
        map_fn: Optional[Callable] = None,
        is_leaf_fn: Optional[Callable[["mlx.nn.Module", str, Any], bool]] = None,
    ):
        """Recursively filter the contents of the module using ``filter_fn``,
        namely only select keys and values where ``filter_fn`` returns true.

        This is used to implement :meth:`parameters` and :meth:`trainable_parameters`
        but it can also be used to extract any subset of the module's parameters.

        Args:
            filter_fn (Callable): Given a value, the key in which it is found
                and the containing module, decide whether to keep the value or
                drop it.
            map_fn (Callable, optional): Optionally transform the value before
                returning it.
            is_leaf_fn (Callable, optional): Given a value, the key in which it
                is found and the containing module, decide if it is a leaf.

        Returns:
            A dictionary containing the contents of the module recursively filtered.
        """

        map_fn = map_fn or (lambda x: x)
        is_leaf_fn = is_leaf_fn or (
            lambda m, k, v: not isinstance(v, (Module, dict, list))
        )

        def unwrap(vk, v):
            if is_leaf_fn(self, vk, v):
                return map_fn(v)

            if isinstance(v, Module):
                return v.filter_and_map(filter_fn, map_fn, is_leaf_fn)

            if isinstance(v, dict):
                nd = {}
                for k, v in v.items():
                    tk = f"{vk}.{k}"
                    nd[k] = unwrap(tk, v) if filter_fn(self, tk, v) else {}
                return nd

            if isinstance(v, list):
                nl = []
                for i, vi in enumerate(v):
                    tk = f"{vk}.{i}"
                    nl.append(unwrap(tk, vi) if filter_fn(self, tk, vi) else {})
                return nl

            raise RuntimeError("Unexpected leaf found while traversing the module")

        return {k: unwrap(k, v) for k, v in self.items() if filter_fn(self, k, v)}

    def parameters(self):
        """Recursively return all the :class:`mlx.core.array` members of this
        Module as a dict of dicts and lists."""
        return self.filter_and_map(self.valid_parameter_filter)

    def trainable_parameters(self):
        """Recursively return all the non-frozen :class:`mlx.core.array` members of
        this Module as a dict of dicts and lists."""
        return self.filter_and_map(self.trainable_parameter_filter)

    def children(self):
        """Return the direct descendants of this Module instance."""
        return self.filter_and_map(
            self.valid_child_filter, is_leaf_fn=lambda m, k, v: isinstance(v, Module)
        )

    def leaf_modules(self):
        """Return the submodules that do not contain other modules."""

        def _is_leaf_module(m, k, v):
            return isinstance(v, Module) and len(tree_flatten(v.children())) == 0

        return self.filter_and_map(self.valid_child_filter, is_leaf_fn=_is_leaf_module)

    def update(self, parameters: dict):
        """Replace the parameters of this Module with the provided ones in the
        dict of dicts and lists.

        Commonly used by the optimizer to change the model to the updated
        (optimized) parameters. Also used by :meth:`mlx.nn.value_and_grad` to set the
        tracers in the model in order to compute gradients.

        The passed in parameters dictionary need not be a full dictionary
        similar to :meth:`parameters`. Only the provided locations will be
        updated.

        Args:
            parameters (dict): A complete or partial dictionary of the module's
                parameters.
        """

        def apply(dst, parameters):
            if isinstance(parameters, dict):
                for k in parameters:
                    if k in dst:
                        current_value = dst[k]
                        new_value = parameters[k]
                        if isinstance(current_value, mx.array):
                            dst[k] = new_value
                        elif isinstance(current_value, Module):
                            current_value.update(new_value)
                        elif isinstance(current_value, (dict, list)):
                            apply(current_value, new_value)
            elif isinstance(parameters, list):
                for i in range(len(dst)):
                    current_value = dst[i]
                    new_value = parameters[i]
                    if isinstance(current_value, mx.array):
                        dst[i] = new_value
                    elif isinstance(current_value, Module):
                        current_value.update(new_value)
                    elif isinstance(current_value, (dict, list)):
                        apply(current_value, new_value)

        apply(self, parameters)

    def apply(
        self,
        map_fn: Callable[[mx.array], mx.array],
        filter_fn: Optional[Callable[["mlx.nn.Module", str, Any], bool]] = None,
    ):
        """Map all the parameters using the provided ``map_fn`` and immediately
        update the module with the mapped parameters.

        For instance running ``model.apply(lambda x: x.astype(mx.float16))``
        casts all parameters to 16 bit floats.

        Args:
            map_fn (Callable): Maps an array to another array
            filter_fn (Callable, optional): Filter to select which arrays to
                map (default: :meth:`Module.valid_parameter_filter`).
        """
        filter_fn = filter_fn or Module.valid_parameter_filter
        self.update(self.filter_and_map(filter_fn, map_fn))

    def apply_to_modules(self, apply_fn: Callable[[str, "mlx.nn.Module"], Any]):
        """Apply a function to all the modules in this instance (including this
        instance).

        Args:
            apply_fn (Callable): The function to apply to the modules.
        """
        module_stack = [("", self)]
        while module_stack:
            prefix, mod = module_stack.pop()
            apply_fn(prefix, mod)
            prefix = "." + prefix if prefix else ""
            module_stack.extend(
                tree_flatten(mod.children(), prefix=prefix, is_leaf=self.is_module)
            )

    def modules(self):
        """Return a list with all the modules in this instance.

        Returns:
            A list of :class:`mlx.nn.Module` instances.
        """
        modulelist = []
        self.apply_to_modules(lambda k, m: modulelist.append(m))
        return modulelist

    def named_modules(self):
        """Return a list with all the modules in this instance and their name
        with dot notation.

        Returns:
            A list of tuples (str, :class:`mlx.nn.Module`).
        """
        modulelist = []
        self.apply_to_modules(lambda k, m: modulelist.append((k, m)))
        return modulelist

    def _validate_keys(self, keys, strict):
        keys = keys if isinstance(keys, list) else [keys]
        if strict:
            for k in keys:
                if k not in self:
                    raise KeyError(f"Module doesn't contain member {k}.")
        return keys

    def freeze(
        self,
        *,
        recurse: bool = True,
        keys: Optional[Union[str, List[str]]] = None,
        strict: bool = False,
    ):
        """Freeze the Module's parameters or some of them. Freezing a parameter
        means not computing gradients for it.

        This function is idempotent, i.e. freezing a frozen model is a no-op.

        For instance to only train the attention parameters from a transformer:

            model = ...
            model.freeze()
            model.apply_to_modules(lambda k, v: v.unfreeze() if k.endswith("attention") else None)

        Args:
            recurse (bool, optional): If True then freeze the parameters of the
                submodules as well (default: True).
            keys (str or list[str], optional): If provided then only these
                parameters will be frozen, otherwise all the parameters of the
                module. For instance freeze all biases by calling
                ``module.freeze(keys="bias")``.
            strict (bool, optional): If set to True validate that the passed keys exist
                (default: False).
        """

        def _freeze_impl(_, m):
            local_keys = keys
            if local_keys is None:
                local_keys = tree_flatten(
                    m.filter_and_map(
                        lambda m, k, v: (not isinstance(v, Module))
                        and m.valid_parameter_filter(m, k, v)
                    )
                )
                local_keys = [k for (k, v) in local_keys]

            local_keys = m._validate_keys(local_keys, strict)
            m._no_grad.update(local_keys)

        if recurse:
            self.apply_to_modules(_freeze_impl)
        else:
            _freeze_impl("", self)

    def unfreeze(
        self,
        *,
        recurse: bool = True,
        keys: Optional[Union[str, List[str]]] = None,
        strict: bool = False,
    ):
        """Unfreeze the Module's parameters or some of them.

        This function is idempotent, i.e. unfreezing a model that is not frozen
        is a no-op.

        For instance to only train the biases one can do:

            model = ...
            model.freeze()
            model.unfreeze(keys="bias")

        Args:
            recurse (bool, optional): If True then unfreeze the parameters of the
                submodules as well (default: True).
            keys (str or list[str], optional): If provided then only these
                parameters will be unfrozen, otherwise all the parameters of the
                module. For instance unfreeze all biases by calling
                ``module.unfreeze(keys="bias")``.
            strict (bool, optional): If set to True validate that the passed keys exist
                (default: False).
        """

        def _unfreeze_impl(_, m):
            if keys is None:
                m._no_grad.clear()
            else:
                local_keys = m._validate_keys(keys, strict)
                m._no_grad.difference_update(local_keys)

        if recurse:
            self.apply_to_modules(_unfreeze_impl)
        else:
            _unfreeze_impl("", self)

    def train(self, mode: bool = True):
        def _set_train(_, m):
            m._training = mode

        self.apply_to_modules(_set_train)

    def eval(self):
        self.train(False)
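
As a usage sketch (assuming the ``Linear`` and ``Sequential`` layers from this same commit are exported under ``mlx.nn``): freeze everything, then unfreeze only the biases, so only those remain trainable.

```
import mlx.nn as nn

model = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))
model.freeze()               # every parameter joins the "frozen" set
model.unfreeze(keys="bias")  # biases become trainable again

# Only the "bias" arrays are returned; the frozen weights are left out.
print(model.trainable_parameters())
```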

python/mlx/nn/layers/containers.py (new file, 22 lines)
@@ -0,0 +1,22 @@
from mlx.nn.layers.base import Module


class Sequential(Module):
    """A layer that calls the passed callables in order.

    We can pass either modules or plain callables to the Sequential module. If
    our functions have learnable parameters they should be implemented as
    ``nn.Module`` instances.

    Args:
        modules (tuple of Callables): The modules to call in order
    """

    def __init__(self, *modules):
        super().__init__()
        self.layers = list(modules)

    def __call__(self, x):
        for m in self.layers:
            x = m(x)
        return x
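
A short sketch (assuming ``Linear`` is exported alongside ``Sequential``): plain callables mix freely with modules.

```
import mlx.core as mx
import mlx.nn as nn

mlp = nn.Sequential(
    nn.Linear(2, 16),
    lambda x: mx.maximum(x, 0),  # a plain callable works as an activation
    nn.Linear(16, 1),
)
y = mlp(mx.zeros((8, 2)))  # shape (8, 1)
```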

python/mlx/nn/layers/dropout.py (new file, 33 lines)
@@ -0,0 +1,33 @@
import mlx.core as mx
from mlx.nn.layers.base import Module


class Dropout(Module):
    r"""Randomly zero a portion of the elements during training.

    The remaining elements are multiplied with :math:`\frac{1}{1-p}` where
    :math:`p` is the probability of zeroing an element. This is done so the
    expected value of a given element will remain the same.

    Args:
        p (float): The probability to zero an element
    """

    def __init__(self, p: float = 0.5):
        super().__init__()

        if p < 0 or p >= 1:
            raise ValueError("The dropout probability should be in [0, 1)")

        self._p_1 = 1 - p

    def _extra_repr(self):
        return f"p={1-self._p_1}"

    def __call__(self, x):
        if self._p_1 == 1 or not self.training:
            return x

        mask = mx.random.bernoulli(self._p_1, x.shape)

        return (1 / self._p_1) * mask.astype(x.dtype) * x
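
A sketch of the training/eval interplay: dropout only fires while ``self.training`` is True, which ``Module.eval()`` flips off.

```
import mlx.core as mx
import mlx.nn as nn

drop = nn.Dropout(p=0.2)
x = mx.ones((4, 4))

y_train = drop(x)  # ~20% of entries zeroed, the rest scaled by 1/(1-p)
drop.eval()        # sets training=False recursively
y_eval = drop(x)   # identity: x is returned unchanged
```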

python/mlx/nn/layers/normalization.py (new file, 178 lines)
@@ -0,0 +1,178 @@
import mlx.core as mx
from mlx.nn.layers.base import Module


class LayerNorm(Module):
    r"""Applies layer normalization [1] on the inputs.

    Computes

    .. math::

        y = \frac{x - E[x]}{\sqrt{Var[x] + \epsilon}} \gamma + \beta,

    where :math:`\gamma` and :math:`\beta` are learned per feature dimension
    parameters initialized at 1 and 0 respectively.

    [1]: https://arxiv.org/abs/1607.06450

    Args:
        dims (int): The feature dimension of the input to normalize over
        eps (float): A small additive constant for numerical stability
        affine (bool): If True learn an affine transform to apply after the
            normalization
    """

    def __init__(self, dims: int, eps: float = 1e-5, affine: bool = True):
        super().__init__()
        if affine:
            self.bias = mx.zeros((dims,))
            self.weight = mx.ones((dims,))
        self.eps = eps
        self.dims = dims

    def _extra_repr(self):
        return f"{self.dims}, eps={self.eps}, affine={'weight' in self}"

    def __call__(self, x):
        means = mx.mean(x, axis=-1, keepdims=True)
        var = mx.var(x, axis=-1, keepdims=True)
        x = (x - means) * mx.rsqrt(var + self.eps)
        return (self.weight * x + self.bias) if "weight" in self else x


class RMSNorm(Module):
    r"""Applies Root Mean Square normalization [1] to the inputs.

    Computes

    .. math::

        y = \frac{x}{\sqrt{E[x^2] + \epsilon}} \gamma

    where :math:`\gamma` is a learned per feature dimension parameter
    initialized at 1.

    [1]: https://arxiv.org/abs/1910.07467

    Args:
        dims (int): The feature dimension of the input to normalize over
        eps (float): A small additive constant for numerical stability
    """

    def __init__(self, dims: int, eps: float = 1e-5):
        super().__init__()
        self.weight = mx.ones((dims,))
        self.eps = eps

    def _extra_repr(self):
        return f"{self.weight.shape[0]}, eps={self.eps}"

    def __call__(self, x):
        # S is 1/sqrt(N) where N is the size of the features of x and is used
        # to compute a numerically more stable RMS of x by multiplying with S
        # first and summing.
        #
        # This way we prefer underflow over overflow which is controlled with
        # the parameter epsilon anyway.
        S = 1 / x.shape[-1] ** 0.5

        n = (x * S).square().sum(axis=-1, keepdims=True)
        n = mx.rsqrt(n + self.eps)

        return self.weight * x * n


class GroupNorm(Module):
    r"""Applies Group Normalization [1] to the inputs.

    Computes the same normalization as layer norm, namely

    .. math::

        y = \frac{x - E[x]}{\sqrt{Var[x] + \epsilon}} \gamma + \beta,

    where :math:`\gamma` and :math:`\beta` are learned per feature dimension
    parameters initialized at 1 and 0 respectively. However, the mean and
    variance are computed over the spatial dimensions and each group of
    features. In particular, the input is split into ``num_groups`` across the
    feature dimension.

    The feature dimension is assumed to be the last dimension and the
    dimensions that precede it (except the first) are considered the spatial
    dimensions.

    [1]: https://arxiv.org/abs/1803.08494

    Args:
        num_groups (int): Number of groups to separate the features into
        dims (int): The feature dimensions of the input to normalize over
        eps (float): A small additive constant for numerical stability
        affine (bool): If True learn an affine transform to apply after the
            normalization.
        pytorch_compatible (bool): If True perform the group normalization in
            the same order/grouping as PyTorch.
    """

    def __init__(
        self,
        num_groups: int,
        dims: int,
        eps: float = 1e-5,
        affine: bool = True,
        pytorch_compatible: bool = False,
    ):
        super().__init__()
        if affine:
            self.bias = mx.zeros((dims,))
            self.weight = mx.ones((dims,))
        self.num_groups = num_groups
        self.dims = dims
        self.eps = eps
        self.pytorch_compatible = pytorch_compatible

    def _extra_repr(self):
        return (
            f"{self.num_groups}, {self.dims}, eps={self.eps}, "
            f"affine={'weight' in self}, pytorch_compatible={self.pytorch_compatible}"
        )

    def _pytorch_compatible_group_norm(self, x):
        num_groups = self.num_groups
        batch, *rest, dims = x.shape

        # Split into groups
        x = x.reshape(batch, -1, num_groups, dims // num_groups)
        x = x.transpose(0, 1, 3, 2).reshape(batch, -1, num_groups)

        # Normalize
        means = mx.mean(x, axis=1, keepdims=True)
        var = mx.var(x, axis=1, keepdims=True)
        x = (x - means) * mx.rsqrt(var + self.eps)
        x = x.reshape(batch, -1, dims // num_groups, num_groups)
        x = x.transpose(0, 1, 3, 2).reshape(batch, *rest, dims)

        return x

    def _group_norm(self, x):
        num_groups = self.num_groups
        batch, *rest, dims = x.shape

        # Split into groups
        x = x.reshape(batch, -1, num_groups)

        # Normalize
        means = mx.mean(x, axis=1, keepdims=True)
        var = mx.var(x, axis=1, keepdims=True)
        x = (x - means) * mx.rsqrt(var + self.eps)
        x = x.reshape(batch, *rest, dims)

        return x

    def __call__(self, x):
        group_norm = (
            self._pytorch_compatible_group_norm
            if self.pytorch_compatible
            else self._group_norm
        )
        x = group_norm(x)
        return (self.weight * x + self.bias) if "weight" in self else x
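
A quick numerical sanity check (a sketch, assuming ``mx.random.normal`` takes a shape as in the random bindings added by this commit): before the affine transform, layer norm output has roughly zero mean and unit variance along the feature axis.

```
import mlx.core as mx
import mlx.nn as nn

ln = nn.LayerNorm(dims=8, affine=False)
x = mx.random.normal((2, 8))
y = ln(x)

print(mx.mean(y, axis=-1))  # ~0
print(mx.var(y, axis=-1))   # ~1
```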

python/mlx/nn/layers/positional_encoding.py (new file, 142 lines)
@@ -0,0 +1,142 @@
import math
from typing import Optional

import mlx.core as mx
from mlx.nn.layers.base import Module


class RoPE(Module):
    """Implements the rotary positional encoding [1].

    The traditional implementation rotates consecutive pairs of elements in the
    feature dimension while the default implementation rotates pairs with
    stride half the feature dimensions for efficiency.

    [1]: https://arxiv.org/abs/2104.09864

    Args:
        dims (int): The feature dimensions to be rotated. If the input feature
            is larger than dims then the rest is left unchanged.
        traditional (bool): If set to True choose the traditional
            implementation which is slightly less efficient.
    """

    def __init__(self, dims: int, traditional: bool = False):
        super().__init__()
        self.dims = dims
        self.traditional = traditional

    def _extra_repr(self):
        return f"{self.dims}, traditional={self.traditional}"

    def _compute_rope(self, costheta, sintheta, x):
        x1 = x[..., : self.dims // 2]
        x2 = x[..., self.dims // 2 : self.dims]
        rx1 = x1 * costheta - x2 * sintheta
        rx2 = x1 * sintheta + x2 * costheta

        if self.dims < x.shape[-1]:
            rx = mx.concatenate([rx1, rx2, x[..., self.dims :]], axis=-1)
        else:
            rx = mx.concatenate([rx1, rx2], axis=-1)

        return rx

    def _compute_traditional_rope(self, costheta, sintheta, x):
        x1 = x[..., ::2]
        x2 = x[..., 1::2]
        rx1 = x1 * costheta - x2 * sintheta
        rx2 = x1 * sintheta + x2 * costheta

        if self.dims < x.shape[-1]:
            raise NotImplementedError(
                "RoPE doesn't implement partial traditional application"
            )

        rx = mx.concatenate([rx1[..., None], rx2[..., None]], axis=-1)

        return rx

    def __call__(self, x, offset: int = 0):
        shape = x.shape
        x = mx.reshape(x, (-1, shape[-2], shape[-1]))
        N = x.shape[1] + offset
        costheta, sintheta = RoPE.create_cos_sin_theta(
            N, self.dims, offset=offset, dtype=x.dtype
        )

        rope = (
            self._compute_traditional_rope if self.traditional else self._compute_rope
        )
        rx = rope(costheta, sintheta, x)

        return mx.reshape(rx, shape)

    @staticmethod
    def create_cos_sin_theta(
        N: int, D: int, offset: int = 0, base: float = 10000, dtype=mx.float32
    ):
        D = D // 2
        positions = mx.arange(offset, N, dtype=dtype)
        freqs = mx.exp(-mx.arange(0.0, D, dtype=dtype) * (math.log(base) / D))
        theta = mx.reshape(positions, (-1, 1)) * mx.reshape(freqs, (1, -1))
        costheta = mx.cos(theta)
        sintheta = mx.sin(theta)

        return costheta, sintheta


class SinusoidalPositionalEncoding(Module):
    """Implements sinusoidal positional encoding similar to [1].

    [1]: https://arxiv.org/abs/1706.03762

    Args:
        dims (int): The dimensionality of the resulting positional embeddings.
        min_freq (float): The minimum frequency expected (default: 0.0001)
        max_freq (float): The maximum frequency expected (default: 1)
        scale (float): Scale the embeddings by that number (default: sqrt(2/dims))
        cos_first (bool): If set to True embed using ``[cos(x); sin(x)]``
            instead of the other way around (default: False)
        full_turns (bool): If set to True multiply the frequencies
            with ``2 pi`` (default: False)
    """

    def __init__(
        self,
        dims: int,
        min_freq: float = 0.0001,
        max_freq: float = 1,
        scale: Optional[float] = None,
        cos_first: bool = False,
        full_turns: bool = False,
    ):
        super().__init__()

        one_zero = 1 - mx.arange(0, dims // 2) / (dims // 2 - 1)
        min_freq = math.log(min_freq)
        max_freq = math.log(max_freq)

        # Start with underscore so it is not included in the parameters
        self._sigmas = mx.exp(one_zero * (max_freq - min_freq) + min_freq)
        if full_turns:
            self._sigmas = self._sigmas * (2 * math.pi)

        # Save some constants that define the implementation
        self.scale = scale or (2 / dims) ** 0.5
        self.cos_first = cos_first

    def __call__(self, x):
        y = x[..., None] * self._sigmas
        cosy = mx.cos(y)
        siny = mx.sin(y)

        if self.cos_first:
            y = mx.concatenate([cosy, siny], axis=-1)
        else:
            y = mx.concatenate([siny, cosy], axis=-1)

        if self.scale != 1:
            y = y * self.scale

        return y
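
A sketch of the ``offset`` argument (assuming ``RoPE`` is exported like the other layers): it lets cached autoregressive decoding rotate a new position consistently with the already-processed prefix.

```
import mlx.core as mx
import mlx.nn as nn

rope = nn.RoPE(dims=64)
x = mx.random.normal((1, 10, 64))  # (batch, sequence, features)

full = rope(x)                      # rotates positions 0..9
last = rope(x[:, 9:, :], offset=9)  # rotates only position 9
# `last` matches full[:, 9:, :]
```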

python/mlx/nn/layers/transformer.py (new file, 136 lines)
@@ -0,0 +1,136 @@
import math
from typing import Optional

import mlx.core as mx
from mlx.nn.layers.base import Module
from mlx.nn.layers.linear import Linear
from mlx.nn.layers.normalization import LayerNorm


class MultiHeadAttention(Module):
    """Implements the scaled dot product attention with multiple heads.

    Given inputs for queries, keys and values the ``MultiHeadAttention`` produces
    new values by aggregating information from the input values according to
    the similarities of the input queries and keys.

    All inputs as well as the output are linearly projected without biases.

    MultiHeadAttention also expects an additive attention mask that should be
    broadcastable with (batch, num_heads, # queries, # keys). The mask should
    have ``-inf`` or very negative numbers at the positions that should *not* be
    attended to.

    Args:
        dims (int): The model dimensions. If no other dims are provided then
            dims is used for queries, keys, values and the output.
        num_heads (int): How many attention heads to use
        query_input_dims (int, optional): The input dimensions of the queries (default: dims).
        key_input_dims (int, optional): The input dimensions of the keys (default: dims).
        value_input_dims (int, optional): The input dimensions of the values (default: key_input_dims).
        value_dims (int, optional): The dimensions of the values after the projection (default: dims).
        value_output_dims (int, optional): The dimensions the new values will be projected to (default: dims).
    """

    def __init__(
        self,
        dims: int,
        num_heads: int,
        query_input_dims: Optional[int] = None,
        key_input_dims: Optional[int] = None,
        value_input_dims: Optional[int] = None,
        value_dims: Optional[int] = None,
        value_output_dims: Optional[int] = None,
    ):
        super().__init__()

        if (dims % num_heads) != 0:
            raise ValueError(
                f"The input feature dimensions should be divisible by the number of heads ({dims} % {num_heads}) != 0"
            )

        query_input_dims = query_input_dims or dims
        key_input_dims = key_input_dims or dims
        value_input_dims = value_input_dims or key_input_dims
        value_dims = value_dims or dims
        value_output_dims = value_output_dims or dims

        self.num_heads = num_heads
        self.query_proj = Linear(query_input_dims, dims, False)
        self.key_proj = Linear(key_input_dims, dims, False)
        self.value_proj = Linear(value_input_dims, value_dims, False)
        self.out_proj = Linear(value_dims, value_output_dims, False)

    def __call__(self, queries, keys, values, mask=None):
        queries = self.query_proj(queries)
        keys = self.key_proj(keys)
        values = self.value_proj(values)

        num_heads = self.num_heads
        B, L, D = queries.shape
        _, S, _ = keys.shape
        queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 3, 1)
        values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)

        # Dimensions are [batch x num heads x sequence x hidden dim]
        scale = math.sqrt(1 / queries.shape[-1])
        scores = (queries * scale) @ keys
        if mask is not None:
            scores = scores + mask.astype(scores.dtype)
        scores = mx.softmax(scores, axis=-1)
        values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)

        return self.out_proj(values_hat)

    @staticmethod
    def create_additive_causal_mask(N: int, dtype: mx.Dtype = mx.float32):
        indices = mx.arange(N)
        mask = indices[:, None] < indices[None]
        # usually inf but 1e9 is as good and softmax(full(1e9)) != nan
        # TODO: Should replace this with finfo(dtype).min
        mask = mask.astype(dtype) * -1e9
        return mask


class TransformerEncoderLayer(Module):
    def __init__(self, dims: int, num_heads: int, mlp_dims: Optional[int] = None):
        super().__init__()
        mlp_dims = mlp_dims or dims * 4
        self.attention = MultiHeadAttention(dims, num_heads)
        self.ln1 = LayerNorm(dims)
        self.ln2 = LayerNorm(dims)
        self.linear1 = Linear(dims, mlp_dims)
        self.linear2 = Linear(mlp_dims, dims)

    def __call__(self, x, mask):
        y = self.ln1(x)
        y = self.attention(y, y, y, mask)
        x = x + y

        y = self.ln2(x)
        y = self.linear1(y)
        y = mx.maximum(y, 0)
        y = self.linear2(y)
        x = x + y

        return x


class TransformerEncoder(Module):
    def __init__(
        self, num_layers: int, dims: int, num_heads: int, mlp_dims: Optional[int] = None
    ):
        super().__init__()
        self.layers = [
            TransformerEncoderLayer(dims, num_heads, mlp_dims)
            for i in range(num_layers)
        ]
        self.ln = LayerNorm(dims)

    def __call__(self, x, mask):
        for l in self.layers:
            x = l(x, mask)
        x = self.ln(x)

        return x
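
A sketch running the encoder with the causal mask helper (assuming these classes are exported under ``mlx.nn``):

```
import mlx.core as mx
import mlx.nn as nn

enc = nn.TransformerEncoder(num_layers=2, dims=64, num_heads=4)
x = mx.random.normal((1, 16, 64))
mask = nn.MultiHeadAttention.create_additive_causal_mask(16)

y = enc(x, mask)  # shape (1, 16, 64)
```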

python/mlx/nn/utils.py (new file, 31 lines)
@@ -0,0 +1,31 @@
from typing import Callable

import mlx.core as mx


def value_and_grad(model: "mlx.nn.Module", fn: Callable):
    """Transform the passed function ``fn`` to a function that computes the
    gradients of ``fn`` wrt the model's trainable parameters and also its
    value.

    Args:
        model (mlx.nn.Module): The model whose trainable parameters to compute
            gradients for
        fn (Callable): The scalar function to compute gradients for

    Returns:
        A callable that returns the value of ``fn`` and the gradients wrt the
        trainable parameters of ``model``
    """

    def inner_fn(params, *args, **kwargs):
        model.update(params)
        return fn(*args, **kwargs)

    value_grad_fn = mx.value_and_grad(inner_fn)

    def wrapped_value_grad_fn(*args, **kwargs):
        value, grad = value_grad_fn(model.trainable_parameters(), *args, **kwargs)
        return value, grad

    return wrapped_value_grad_fn
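
A sketch of how the wrapper is used: ``fn`` closes over the model, and the returned callable yields the loss together with a gradient tree mirroring ``model.trainable_parameters()``.

```
import mlx.core as mx
import mlx.nn as nn

model = nn.Sequential(nn.Linear(4, 1))

def loss_fn(x, y):
    return mx.mean(mx.square(model(x) - y))

loss_and_grad = nn.value_and_grad(model, loss_fn)

x = mx.random.normal((8, 4))
y = mx.random.normal((8, 1))
loss, grads = loss_and_grad(x, y)
```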

python/mlx/optimizers.py (new file, 152 lines)
@@ -0,0 +1,152 @@
import math
from typing import List

import mlx.core as mx
from mlx.utils import tree_map


class OptimizerState(dict):
    """The optimizer state implements a recursively defined
    :class:`collections.defaultdict`, namely a missing key in an optimizer
    state is an :class:`OptimizerState`.

    .. note::
       :meth:`OptimizerState.get` in contrast to a normal dictionary also sets
       the key to the ``default`` value if the ``key`` was not present in the
       dictionary.
    """

    def __getitem__(self, key):
        if key not in self:
            self[key] = OptimizerState()
        return super().__getitem__(key)

    def get(self, key, default):
        """If ``key`` doesn't exist set its value to ``default`` and then return it."""
        if key not in self:
            self[key] = default
        return super().__getitem__(key)


class Optimizer:
    """The base class for all optimizers. It allows us to implement an
    optimizer on a per-parameter basis and apply it to a parameter tree.

    Attributes:
        state (OptimizerState): It holds the optimizer's state dictionary.
    """

    def __init__(self):
        self.state = OptimizerState()

    def update(self, model: "mlx.nn.Module", gradients: dict):
        """Apply the gradients to the parameters of the model and update the
        model with the new parameters.

        Args:
            model (mlx.nn.Module): An mlx module to be updated.
            gradients (dict): A Python tree of gradients, most likely computed
                via :func:`mlx.nn.value_and_grad`.
        """
        model.update(self.apply_gradients(gradients, model))

    def apply_gradients(self, gradients: dict, model: dict):
        """Apply the gradients to the parameters and return the updated parameters.

        Can be used to update a model via
        ``model.update(opt.apply_gradients(grads, model))`` which is precisely
        how :meth:`Optimizer.update` is implemented.

        Args:
            gradients (dict): A Python tree of gradients.
            model (dict): A Python tree of parameters. It can be a superset of
                the gradients. In that case the returned python tree
                will be of the same structure as the gradients.
        """
        return tree_map(self.apply_single, gradients, model, self.state)

    def apply_single(
        self, gradient: mx.array, parameter: mx.array, state: OptimizerState
    ):
        """To be extended by the children classes to implement each optimizer's
        update."""
        raise NotImplementedError()


class SGD(Optimizer):
    r"""Stochastic gradient descent optimizer.

    Updates a parameter :math:`w` with a gradient :math:`g` as follows

    .. math::

        v_{t+1} &= \mu v_t + (1 - \mu) g_t \\
        w_{t+1} &= w_t - \lambda v_{t+1}

    Args:
        learning_rate (float): The learning rate :math:`\lambda` for the update
        momentum (float): The momentum strength :math:`\mu`
    """

    def __init__(self, learning_rate: float, momentum: float = 0.0):
        super().__init__()

        self.learning_rate = learning_rate
        self.momentum = momentum

    def apply_single(
        self, gradient: mx.array, parameter: mx.array, state: OptimizerState
    ):
        """Performs the SGD parameter update and stores :math:`v` in the
        optimizer state."""
        if self.momentum <= 0:
            return parameter - self.learning_rate * gradient

        v = state.get("v", mx.zeros_like(gradient))
        v = self.momentum * v + (1 - self.momentum) * gradient
        state["v"] = v
        return parameter - self.learning_rate * v


class Adam(Optimizer):
    r"""Implementation of the Adam optimizer [1].

    Our Adam implementation follows the original paper and omits the bias
    correction in the first and second moment estimates. In detail,

    .. math::

        m_{t+1} &= \beta_1 m_t + (1 - \beta_1) g_t \\
        v_{t+1} &= \beta_2 v_t + (1 - \beta_2) g_t^2 \\
        w_{t+1} &= w_t - \lambda \frac{m_{t+1}}{\sqrt{v_{t+1}} + \epsilon}

    [1]: Kingma, D.P. and Ba, J., 2015. Adam: A method for stochastic
    optimization. ICLR 2015.
    """

    def __init__(
        self, learning_rate: float, betas: List[float] = [0.9, 0.999], eps: float = 1e-8
    ):
        super().__init__()

        self.learning_rate = learning_rate
        self.betas = betas
        self.eps = eps

    def apply_single(
        self, gradient: mx.array, parameter: mx.array, state: OptimizerState
    ):
        """Performs the Adam parameter update and stores :math:`v` and
        :math:`m` in the optimizer state."""
        lr = self.learning_rate
        b1, b2 = self.betas
        eps = self.eps

        m = state.get("m", gradient)
        v = state.get("v", mx.square(gradient))
        m = b1 * m + (1 - b1) * gradient
        v = b2 * v + (1 - b2) * mx.square(gradient)
        state["m"] = m
        state["v"] = v

        return parameter - lr * m / (mx.sqrt(v) + eps)
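
Putting the pieces together, a minimal training-step sketch (module paths as defined by the files in this commit):

```
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

model = nn.Sequential(nn.Linear(4, 1))
opt = optim.SGD(learning_rate=0.1, momentum=0.9)

def loss_fn(x, y):
    return mx.mean(mx.square(model(x) - y))

loss_and_grad = nn.value_and_grad(model, loss_fn)

x = mx.random.normal((8, 4))
y = mx.random.normal((8, 1))
loss, grads = loss_and_grad(x, y)
opt.update(model, grads)     # writes the new parameters into the model
mx.eval(model.parameters())  # force evaluation of the lazy graph
```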

python/src/CMakeLists.txt (new file, 32 lines)
@@ -0,0 +1,32 @@
pybind11_add_module(
  core
  ${CMAKE_CURRENT_SOURCE_DIR}/mlx.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/stream.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/random.cpp
)

if (NOT MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY)
  set(MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY ${CMAKE_LIBRARY_OUTPUT_DIRECTORY})
endif()

set_target_properties(
  core
  PROPERTIES
  LIBRARY_OUTPUT_DIRECTORY
  ${MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY}
)

target_link_libraries(core PRIVATE mlx)
target_compile_definitions(core PRIVATE _VERSION_=${MLX_VERSION})

if(BUILD_SHARED_LIBS)
  target_link_options(core PRIVATE -Wl,-rpath,@loader_path/lib)
endif()
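
For reference, the manual equivalent of the configure/build that `extension.py` above drives is roughly the following sketch (the source path is a placeholder; the build type and shared-libs flags mirror the ones `extension.py` sets):

```
cmake /path/to/python/src -DCMAKE_BUILD_TYPE=Release -DBUILD_SHARED_LIBS=ON
cmake --build .
```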

python/src/fft.cpp (new file, 468 lines)
@@ -0,0 +1,468 @@
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h>
|
||||
|
||||
#include "python/src/utils.h"
|
||||
|
||||
#include "mlx/fft.h"
|
||||
#include "mlx/ops.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
using namespace py::literals;
|
||||
|
||||
using namespace mlx::core;
|
||||
|
||||
void init_fft(py::module_& parent_module) {
|
||||
auto m = parent_module.def_submodule(
|
||||
"fft", "mlx.core.fft: Fast Fourier Transforms.");
|
||||
m.def(
|
||||
"fft",
|
||||
[](const array& a,
|
||||
const std::optional<int>& n,
|
||||
int axis,
|
||||
StreamOrDevice s) {
|
||||
if (n.has_value()) {
|
||||
return fft::fft(a, n.value(), axis, s);
|
||||
} else {
|
||||
return fft::fft(a, axis, s);
|
||||
}
|
||||
},
|
||||
"a"_a,
|
||||
"n"_a = none,
|
||||
"axis"_a = -1,
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
One dimensional discrete Fourier Transform.
|
||||
|
||||
Args:
|
||||
a (array): The input array.
|
||||
n (int, optional): Size of the transformed axis. The
|
||||
corresponding axis in the input is truncated or padded with
|
||||
zeros to match ``n``. The default value is ``a.shape[axis]``.
|
||||
axis (int, optional): Axis along which to perform the FFT. The
|
||||
default is ``-1``.
|
||||
|
||||
Returns:
|
||||
array: The DFT of the input along the given axis.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"ifft",
|
||||
[](const array& a,
|
||||
const std::optional<int>& n,
|
||||
int axis,
|
||||
StreamOrDevice s) {
|
||||
if (n.has_value()) {
|
||||
return fft::ifft(a, n.value(), axis, s);
|
||||
} else {
|
||||
return fft::ifft(a, axis, s);
|
||||
}
|
||||
},
|
||||
"a"_a,
|
||||
"n"_a = none,
|
||||
"axis"_a = -1,
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
One dimensional inverse discrete Fourier Transform.
|
||||
|
||||
Args:
|
||||
a (array): The input array.
|
||||
n (int, optional): Size of the transformed axis. The
|
||||
corresponding axis in the input is truncated or padded with
|
||||
zeros to match ``n``. The default value is ``a.shape[axis]``.
|
||||
axis (int, optional): Axis along which to perform the FFT. The
|
||||
default is ``-1``.
|
||||
|
||||
Returns:
|
||||
array: The inverse DFT of the input along the given axis.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"fft2",
|
||||
[](const array& a,
|
||||
const std::optional<std::vector<int>>& n,
|
||||
const std::optional<std::vector<int>>& axes,
|
||||
StreamOrDevice s) {
|
||||
if (axes.has_value() && n.has_value()) {
|
||||
return fft::fftn(a, n.value(), axes.value(), s);
|
||||
} else if (axes.has_value()) {
|
||||
return fft::fftn(a, axes.value(), s);
|
||||
} else if (n.has_value()) {
|
||||
std::vector<int> axes_(n.value().size());
|
||||
std::iota(axes_.begin(), axes_.end(), -n.value().size());
|
||||
return fft::fftn(a, n.value(), axes_, s);
|
||||
} else {
|
||||
return fft::fftn(a, s);
|
||||
}
|
||||
},
|
||||
"a"_a,
|
||||
"s"_a = none,
|
||||
"axes"_a = std::vector<int>{-2, -1},
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
Two dimensional discrete Fourier Transform.
|
||||
|
||||
Args:
|
||||
a (array): The input array.
|
||||
s (list(int), optional): Sizes of the transformed axes. The
|
||||
corresponding axes in the input are truncated or padded with
|
||||
zeros to match the sizes in ``s``. The default value is the
|
||||
sizes of ``a`` along ``axes``.
|
||||
axes (list(int), optional): Axes along which to perform the FFT.
|
||||
The default is ``[-2, -1]``.
|
||||
|
||||
Returns:
|
||||
array: The DFT of the input along the given axes.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"ifft2",
|
||||
[](const array& a,
|
||||
const std::optional<std::vector<int>>& n,
|
||||
const std::optional<std::vector<int>>& axes,
|
||||
StreamOrDevice s) {
|
||||
if (axes.has_value() && n.has_value()) {
|
||||
return fft::ifftn(a, n.value(), axes.value(), s);
|
||||
} else if (axes.has_value()) {
|
||||
return fft::ifftn(a, axes.value(), s);
|
||||
} else if (n.has_value()) {
|
||||
std::vector<int> axes_(n.value().size());
|
||||
std::iota(axes_.begin(), axes_.end(), -n.value().size());
|
||||
return fft::ifftn(a, n.value(), axes_, s);
|
||||
} else {
|
||||
return fft::ifftn(a, s);
|
||||
}
|
||||
},
|
||||
"a"_a,
|
||||
"s"_a = none,
|
||||
"axes"_a = std::vector<int>{-2, -1},
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
Two dimensional inverse discrete Fourier Transform.
|
||||
|
||||
Args:
|
||||
a (array): The input array.
|
||||
s (list(int), optional): Sizes of the transformed axes. The
|
||||
corresponding axes in the input are truncated or padded with
|
||||
zeros to match the sizes in ``s``. The default value is the
|
||||
sizes of ``a`` along ``axes``.
|
||||
axes (list(int), optional): Axes along which to perform the FFT.
|
||||
The default is ``[-2, -1]``.
|
||||
|
||||
Returns:
|
||||
array: The inverse DFT of the input along the given axes.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"fftn",
|
||||
[](const array& a,
|
||||
const std::optional<std::vector<int>>& n,
|
||||
const std::optional<std::vector<int>>& axes,
|
||||
StreamOrDevice s) {
|
||||
if (axes.has_value() && n.has_value()) {
|
||||
return fft::fftn(a, n.value(), axes.value(), s);
|
||||
} else if (axes.has_value()) {
|
||||
return fft::fftn(a, axes.value(), s);
|
||||
} else if (n.has_value()) {
|
||||
std::vector<int> axes_(n.value().size());
|
||||
std::iota(axes_.begin(), axes_.end(), -n.value().size());
|
||||
return fft::fftn(a, n.value(), axes_, s);
|
||||
} else {
|
||||
return fft::fftn(a, s);
|
||||
}
|
||||
},
|
||||
"a"_a,
|
||||
"s"_a = none,
|
||||
"axes"_a = none,
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
n-dimensional discrete Fourier Transform.
|
||||
|
||||
Args:
|
||||
a (array): The input array.
|
||||
s (list(int), optional): Sizes of the transformed axes. The
|
||||
corresponding axes in the input are truncated or padded with
|
||||
zeros to match the sizes in ``s``. The default value is the
|
||||
sizes of ``a`` along ``axes``.
|
||||
axes (list(int), optional): Axes along which to perform the FFT.
|
||||
The default is ``None`` in which case the FFT is over the last
|
||||
``len(s)`` axes are or all axes if ``s`` is also ``None``.
|
||||
|
||||
Returns:
|
||||
array: The DFT of the input along the given axes.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"ifftn",
|
||||
[](const array& a,
|
||||
const std::optional<std::vector<int>>& n,
|
||||
const std::optional<std::vector<int>>& axes,
|
||||
StreamOrDevice s) {
|
||||
if (axes.has_value() && n.has_value()) {
|
||||
return fft::ifftn(a, n.value(), axes.value(), s);
|
||||
} else if (axes.has_value()) {
|
||||
return fft::ifftn(a, axes.value(), s);
|
||||
} else if (n.has_value()) {
|
||||
std::vector<int> axes_(n.value().size());
|
||||
std::iota(axes_.begin(), axes_.end(), -n.value().size());
|
||||
return fft::ifftn(a, n.value(), axes_, s);
|
||||
} else {
|
||||
return fft::ifftn(a, s);
|
||||
}
|
||||
},
|
||||
"a"_a,
|
||||
"s"_a = none,
|
||||
"axes"_a = none,
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
n-dimensional inverse discrete Fourier Transform.
|
||||
|
||||
Args:
|
||||
a (array): The input array.
|
||||
s (list(int), optional): Sizes of the transformed axes. The
|
||||
corresponding axes in the input are truncated or padded with
|
||||
zeros to match the sizes in ``s``. The default value is the
|
||||
sizes of ``a`` along ``axes``.
|
||||
axes (list(int), optional): Axes along which to perform the FFT.
|
||||
The default is ``None`` in which case the FFT is over the last
|
||||
``len(s)`` axes or all axes if ``s`` is also ``None``.
|
||||
|
||||
Returns:
|
||||
array: The inverse DFT of the input along the given axes.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"rfft",
|
||||
[](const array& a,
|
||||
const std::optional<int>& n,
|
||||
int axis,
|
||||
StreamOrDevice s) {
|
||||
if (n.has_value()) {
|
||||
return fft::rfft(a, n.value(), axis, s);
|
||||
} else {
|
||||
return fft::rfft(a, axis, s);
|
||||
}
|
||||
},
|
||||
"a"_a,
|
||||
"n"_a = none,
|
||||
"axis"_a = -1,
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
One dimensional discrete Fourier Transform on a real input.
|
||||
|
||||
The output has the same shape as the input except along ``axis`` in
|
||||
which case it has size ``n // 2 + 1``.
|
||||
|
||||
Args:
|
||||
a (array): The input array. If the array is complex it will be silently
|
||||
cast to a real type.
|
||||
n (int, optional): Size of the transformed axis. The
|
||||
corresponding axis in the input is truncated or padded with
|
||||
zeros to match ``n``. The default value is ``a.shape[axis]``.
|
||||
axis (int, optional): Axis along which to perform the FFT. The
|
||||
default is ``-1``.
|
||||
|
||||
Returns:
|
||||
array: The DFT of the input along the given axis. The output
|
||||
data type will be complex.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"irfft",
|
||||
[](const array& a,
|
||||
const std::optional<int>& n,
|
||||
int axis,
|
||||
StreamOrDevice s) {
|
||||
if (n.has_value()) {
|
||||
return fft::irfft(a, n.value(), axis, s);
|
||||
} else {
|
||||
return fft::irfft(a, axis, s);
|
||||
}
|
||||
},
|
||||
"a"_a,
|
||||
"n"_a = none,
|
||||
"axis"_a = -1,
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
The inverse of :func:`rfft`.
|
||||
|
||||
The output has the same shape as the input except along ``axis`` in
|
||||
which case it has size ``n``.
|
||||
|
||||
Args:
|
||||
a (array): The input array.
|
||||
n (int, optional): Size of the transformed axis. The
|
||||
corresponding axis in the input is truncated or padded with
|
||||
zeros to match ``n // 2 + 1``. The default value is
|
||||
``a.shape[axis] // 2 + 1``.
|
||||
axis (int, optional): Axis along which to perform the FFT. The
|
||||
default is ``-1``.
|
||||
|
||||
Returns:
|
||||
array: The real array containing the inverse of :func:`rfft`.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"rfft2",
|
||||
[](const array& a,
|
||||
const std::optional<std::vector<int>>& n,
|
||||
const std::optional<std::vector<int>>& axes,
|
||||
StreamOrDevice s) {
|
||||
if (axes.has_value() && n.has_value()) {
|
||||
return fft::rfftn(a, n.value(), axes.value(), s);
|
||||
} else if (axes.has_value()) {
|
||||
return fft::rfftn(a, axes.value(), s);
|
||||
} else if (n.has_value()) {
|
||||
std::vector<int> axes_(n.value().size());
|
||||
std::iota(axes_.begin(), axes_.end(), -n.value().size());
|
||||
return fft::rfftn(a, n.value(), axes_, s);
|
||||
} else {
|
||||
return fft::rfftn(a, s);
|
||||
}
|
||||
},
|
||||
"a"_a,
|
||||
"s"_a = none,
|
||||
"axes"_a = std::vector<int>{-2, -1},
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
Two dimensional real discrete Fourier Transform.
|
||||
|
||||
The output has the same shape as the input except along the dimensions in
|
||||
``axes`` in which case it has sizes from ``s``. The last axis in ``axes`` is
|
||||
treated as the real axis and will have size ``s[-1] // 2 + 1``.
|
||||
|
||||
Args:
|
||||
a (array): The input array. If the array is complex it will be silently
|
||||
cast to a real type.
|
||||
s (list(int), optional): Sizes of the transformed axes. The
|
||||
corresponding axes in the input are truncated or padded with
|
||||
zeros to match the sizes in ``s``. The default value is the
|
||||
sizes of ``a`` along ``axes``.
|
||||
axes (list(int), optional): Axes along which to perform the FFT.
|
||||
The default is ``[-2, -1]``.
|
||||
|
||||
Returns:
|
||||
array: The real DFT of the input along the given axes. The output
|
||||
data type will be complex.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"irfft2",
|
||||
[](const array& a,
|
||||
const std::optional<std::vector<int>>& n,
|
||||
const std::optional<std::vector<int>>& axes,
|
||||
StreamOrDevice s) {
|
||||
if (axes.has_value() && n.has_value()) {
|
||||
return fft::irfftn(a, n.value(), axes.value(), s);
|
||||
} else if (axes.has_value()) {
|
||||
return fft::irfftn(a, axes.value(), s);
|
||||
} else if (n.has_value()) {
|
||||
std::vector<int> axes_(n.value().size());
|
||||
std::iota(axes_.begin(), axes_.end(), -n.value().size());
|
||||
return fft::irfftn(a, n.value(), axes_, s);
|
||||
} else {
|
||||
return fft::irfftn(a, s);
|
||||
}
|
||||
},
|
||||
"a"_a,
|
||||
"s"_a = none,
|
||||
"axes"_a = std::vector<int>{-2, -1},
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
The inverse of :func:`rfft2`.
|
||||
|
||||
Note the input is generally complex. The dimensions of the input
|
||||
specified in ``axes`` are padded or truncated to match the sizes
|
||||
from ``s``. The last axis in ``axes`` is treated as the real axis
|
||||
and will have size ``s[-1] // 2 + 1``.
|
||||
|
||||
Args:
|
||||
a (array): The input array.
|
||||
s (list(int), optional): Sizes of the transformed axes. The
|
||||
corresponding axes in the input are truncated or padded with
|
||||
zeros to match the sizes in ``s`` except for the last axis
|
||||
which has size ``s[-1] // 2 + 1``. The default value is the
|
||||
sizes of ``a`` along ``axes``.
|
||||
axes (list(int), optional): Axes along which to perform the FFT.
|
||||
The default is ``[-2, -1]``.
|
||||
|
||||
Returns:
|
||||
array: The real array containing the inverse of :func:`rfft2`.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"rfftn",
|
||||
[](const array& a,
|
||||
const std::optional<std::vector<int>>& n,
|
||||
const std::optional<std::vector<int>>& axes,
|
||||
StreamOrDevice s) {
|
||||
if (axes.has_value() && n.has_value()) {
|
||||
return fft::rfftn(a, n.value(), axes.value(), s);
|
||||
} else if (axes.has_value()) {
|
||||
return fft::rfftn(a, axes.value(), s);
|
||||
} else if (n.has_value()) {
|
||||
std::vector<int> axes_(n.value().size());
|
||||
std::iota(axes_.begin(), axes_.end(), -n.value().size());
|
||||
return fft::rfftn(a, n.value(), axes_, s);
|
||||
} else {
|
||||
return fft::rfftn(a, s);
|
||||
}
|
||||
},
|
||||
"a"_a,
|
||||
"s"_a = none,
|
||||
"axes"_a = none,
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
n-dimensional real discrete Fourier Transform.
|
||||
|
||||
The output has the same shape as the input except along the dimensions in
|
||||
``axes`` in which case it has sizes from ``s``. The last axis in ``axes`` is
|
||||
treated as the real axis and will have size ``s[-1] // 2 + 1``.
|
||||
|
||||
Args:
|
||||
a (array): The input array. If the array is complex it will be silently
|
||||
cast to a real type.
|
||||
s (list(int), optional): Sizes of the transformed axes. The
|
||||
corresponding axes in the input are truncated or padded with
|
||||
zeros to match the sizes in ``s``. The default value is the
|
||||
sizes of ``a`` along ``axes``.
|
||||
axes (list(int), optional): Axes along which to perform the FFT.
|
||||
The default is ``None`` in which case the FFT is over the last
|
||||
``len(s)`` axes or all axes if ``s`` is also ``None``.
|
||||
|
||||
Returns:
|
||||
array: The real DFT of the input along the given axes. The output
data type will be complex.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"irfftn",
|
||||
[](const array& a,
|
||||
const std::optional<std::vector<int>>& n,
|
||||
const std::optional<std::vector<int>>& axes,
|
||||
StreamOrDevice s) {
|
||||
if (axes.has_value() && n.has_value()) {
|
||||
return fft::irfftn(a, n.value(), axes.value(), s);
|
||||
} else if (axes.has_value()) {
|
||||
return fft::irfftn(a, axes.value(), s);
|
||||
} else if (n.has_value()) {
|
||||
std::vector<int> axes_(n.value().size());
|
||||
std::iota(axes_.begin(), axes_.end(), -n.value().size());
|
||||
return fft::irfftn(a, n.value(), axes_, s);
|
||||
} else {
|
||||
return fft::irfftn(a, s);
|
||||
}
|
||||
},
|
||||
"a"_a,
|
||||
"s"_a = none,
|
||||
"axes"_a = none,
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
The inverse of :func:`rfftn`.
|
||||
|
||||
Note the input is generally complex. The dimensions of the input
|
||||
specified in ``axes`` are padded or truncated to match the sizes
|
||||
from ``s``. The last axis in ``axes`` is treated as the real axis
|
||||
and will have size ``s[-1] // 2 + 1``.
|
||||
|
||||
Args:
|
||||
a (array): The input array.
|
||||
s (list(int), optional): Sizes of the transformed axes. The
|
||||
corresponding axes in the input are truncated or padded with
|
||||
zeros to match the sizes in ``s``. The default value is the
|
||||
sizes of ``a`` along ``axes``.
|
||||
axes (list(int), optional): Axes along which to perform the FFT.
|
||||
The default is ``None`` in which case the FFT is over the last
|
||||
``len(s)`` axes or all axes if ``s`` is also ``None``.
|
||||
|
||||
Returns:
|
||||
array: The real array containing the inverse of :func:`rfftn`.
|
||||
)pbdoc");
|
||||
}
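A hedged sketch of the real-FFT bindings above, assuming they are exposed under the `mx.fft` submodule; the shapes follow the ``n // 2 + 1`` rule from the docstrings:

```python
import mlx.core as mx

x = mx.ones((8,))
y = mx.fft.rfft(x)         # real input -> complex output of size 8 // 2 + 1 == 5
print(y.shape)             # (5,)

z = mx.fft.irfft(y, n=8)   # invert back to a real array of length 8
print(z.shape)             # (8,)
```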
|
||||
635
python/src/indexing.cpp
Normal file
@@ -0,0 +1,635 @@
|
||||
#include <numeric>
|
||||
#include <sstream>
|
||||
|
||||
#include "python/src/indexing.h"
|
||||
|
||||
#include "mlx/ops.h"
|
||||
|
||||
bool is_none_slice(const py::slice& in_slice) {
|
||||
return (
|
||||
py::getattr(in_slice, "start").is_none() &&
|
||||
py::getattr(in_slice, "stop").is_none() &&
|
||||
py::getattr(in_slice, "step").is_none());
|
||||
}
|
||||
|
||||
int get_slice_int(py::object obj, int default_val) {
|
||||
if (!obj.is_none()) {
|
||||
if (!py::isinstance<py::int_>(obj)) {
|
||||
throw std::invalid_argument("Slice indices must be integers or None.");
|
||||
}
|
||||
return py::cast<int>(py::cast<py::int_>(obj));
|
||||
}
|
||||
return default_val;
|
||||
}
|
||||
|
||||
void get_slice_params(
|
||||
int& starts,
|
||||
int& ends,
|
||||
int& strides,
|
||||
const py::slice& in_slice,
|
||||
int axis_size) {
|
||||
// Following numpy's convention
|
||||
// Assume n is the number of elements in the dimension being sliced.
|
||||
// Then, if i is not given it defaults to 0 for k > 0 and n - 1 for
|
||||
// k < 0 . If j is not given it defaults to n for k > 0 and -n-1 for
|
||||
// k < 0 . If k is not given it defaults to 1
|
||||
|
||||
strides = get_slice_int(py::getattr(in_slice, "step"), 1);
|
||||
starts = get_slice_int(
|
||||
py::getattr(in_slice, "start"), strides < 0 ? axis_size - 1 : 0);
|
||||
ends = get_slice_int(
|
||||
py::getattr(in_slice, "stop"), strides < 0 ? -axis_size - 1 : axis_size);
|
||||
|
||||
// starts = (starts < 0) ? starts + axis_size : starts;
|
||||
// ends = (ends < 0) ? ends + axis_size : ends;
|
||||
}
|
||||
|
||||
array get_int_index(py::object idx, int axis_size) {
|
||||
int idx_ = py::cast<int>(idx);
|
||||
idx_ = (idx_ < 0) ? idx_ + axis_size : idx_;
|
||||
|
||||
return array(idx_, uint32);
|
||||
}
|
||||
|
||||
bool is_valid_index_type(const py::object& obj) {
|
||||
return py::isinstance<py::slice>(obj) || py::isinstance<py::int_>(obj) ||
|
||||
py::isinstance<array>(obj) || obj.is_none() || py::ellipsis().is(obj);
|
||||
}
|
||||
|
||||
array mlx_get_item_slice(const array& src, const py::slice& in_slice) {
|
||||
// Check input and raise error if 0 dim for parity with np
|
||||
if (src.ndim() == 0) {
|
||||
throw std::invalid_argument(
|
||||
"too many indices for array: array is 0-dimensional");
|
||||
}
|
||||
|
||||
// Return a copy of the array if a none slice is requested
|
||||
if (is_none_slice(in_slice)) {
|
||||
return src;
|
||||
}
|
||||
|
||||
std::vector<int> starts(src.ndim(), 0);
|
||||
std::vector<int> ends = src.shape();
|
||||
std::vector<int> strides(src.ndim(), 1);
|
||||
|
||||
// Check and update slice params
|
||||
get_slice_params(starts[0], ends[0], strides[0], in_slice, ends[0]);
|
||||
return slice(src, starts, ends, strides);
|
||||
}
|
||||
|
||||
array mlx_get_item_array(const array& src, const array& indices) {
|
||||
// Check input and raise error if 0 dim for parity with np
|
||||
if (src.ndim() == 0) {
|
||||
throw std::invalid_argument(
|
||||
"too many indices for array: array is 0-dimensional");
|
||||
}
|
||||
|
||||
if (indices.dtype() == bool_) {
|
||||
throw std::invalid_argument("boolean indices are not yet supported");
|
||||
}
|
||||
|
||||
// If only one input array is mentioned, we set axis=0 in take
|
||||
// for parity with np
|
||||
return take(src, indices, 0);
|
||||
}
|
||||
|
||||
array mlx_get_item_int(const array& src, const py::int_& idx) {
|
||||
// Check input and raise error if 0 dim for parity with np
|
||||
if (src.ndim() == 0) {
|
||||
throw std::invalid_argument(
|
||||
"too many indices for array: array is 0-dimensional");
|
||||
}
|
||||
|
||||
// If only one input idx is mentioned, we set axis=0 in take
|
||||
// for parity with np
|
||||
return take(src, get_int_index(idx, src.shape(0)), 0);
|
||||
}
|
||||
|
||||
array mlx_gather_nd(
|
||||
array src,
|
||||
const std::vector<py::object>& indices,
|
||||
bool gather_first,
|
||||
int& max_dims) {
|
||||
max_dims = 0;
|
||||
std::vector<array> gather_indices;
|
||||
std::vector<bool> is_slice(indices.size(), false);
|
||||
int num_slices = 0;
|
||||
// gather all the arrays
|
||||
for (int i = 0; i < indices.size(); i++) {
|
||||
auto& idx = indices[i];
|
||||
|
||||
if (py::isinstance<py::slice>(idx)) {
|
||||
int start, end, stride;
|
||||
get_slice_params(start, end, stride, idx, src.shape(i));
|
||||
gather_indices.push_back(arange(start, end, stride, uint32));
|
||||
num_slices++;
|
||||
is_slice[i] = true;
|
||||
} else if (py::isinstance<py::int_>(idx)) {
|
||||
gather_indices.push_back(get_int_index(idx, src.shape(i)));
|
||||
} else if (py::isinstance<array>(idx)) {
|
||||
auto arr = py::cast<array>(idx);
|
||||
max_dims = std::max(static_cast<int>(arr.ndim()), max_dims);
|
||||
gather_indices.push_back(arr);
|
||||
}
|
||||
}
|
||||
|
||||
// reshape them so that the int/array indices are first
|
||||
if (gather_first) {
|
||||
int slice_index = 0;
|
||||
for (int i = 0; i < gather_indices.size(); i++) {
|
||||
if (is_slice[i]) {
|
||||
std::vector<int> index_shape(max_dims + num_slices, 1);
|
||||
index_shape[max_dims + slice_index] = gather_indices[i].shape(0);
|
||||
gather_indices[i] = reshape(gather_indices[i], index_shape);
|
||||
slice_index++;
|
||||
} else {
|
||||
std::vector<int> index_shape = gather_indices[i].shape();
|
||||
index_shape.insert(index_shape.end(), num_slices, 1);
|
||||
gather_indices[i] = reshape(gather_indices[i], index_shape);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// reshape them so that the int/array indices are last
|
||||
for (int i = 0; i < gather_indices.size(); i++) {
|
||||
if (i < num_slices) {
|
||||
std::vector<int> index_shape(max_dims + num_slices, 1);
|
||||
index_shape[i] = gather_indices[i].shape(0);
|
||||
gather_indices[i] = reshape(gather_indices[i], index_shape);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Do the gather
|
||||
std::vector<int> axes(indices.size());
|
||||
std::iota(axes.begin(), axes.end(), 0);
|
||||
std::vector<int> slice_sizes = src.shape();
|
||||
std::fill(slice_sizes.begin(), slice_sizes.begin() + indices.size(), 1);
|
||||
src = gather(src, gather_indices, axes, slice_sizes);
|
||||
|
||||
// Squeeze the dims
|
||||
std::vector<int> out_shape;
|
||||
out_shape.insert(
|
||||
out_shape.end(),
|
||||
src.shape().begin(),
|
||||
src.shape().begin() + max_dims + num_slices);
|
||||
out_shape.insert(
|
||||
out_shape.end(),
|
||||
src.shape().begin() + max_dims + num_slices + indices.size(),
|
||||
src.shape().end());
|
||||
src = reshape(src, out_shape);
|
||||
|
||||
return src;
|
||||
}
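// Illustration of the `gather_first` rule above (this code targets parity
// with numpy's advanced indexing): when array/int indices are separated by a
// slice, the broadcast index dimensions move to the front of the result;
// when they are contiguous they stay in place. A hypothetical numpy sketch:
//   a = np.zeros((5, 6, 7))
//   a[[0, 1], :, [0, 1]].shape  # (2, 6) -- separated, index dims go first
//   a[:, [0, 1], [0, 1]].shape  # (5, 2) -- contiguous, index dims in place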
|
||||
|
||||
array mlx_get_item_nd(array src, const py::tuple& entries) {
|
||||
// No indices make this a noop
|
||||
if (entries.size() == 0) {
|
||||
return src;
|
||||
}
|
||||
|
||||
// The plan is as follows:
|
||||
// 1. Replace the ellipsis with a series of slice(None)
|
||||
// 2. Loop over the indices and calculate the gather indices
|
||||
// 3. Calculate the remaining slices and reshapes
|
||||
|
||||
// Ellipsis handling
|
||||
std::vector<py::object> indices;
|
||||
{
|
||||
int non_none_indices_before = 0;
|
||||
int non_none_indices_after = 0;
|
||||
std::vector<py::object> r_indices;
|
||||
int i = 0;
|
||||
for (; i < entries.size(); i++) {
|
||||
auto idx = entries[i];
|
||||
if (!is_valid_index_type(idx)) {
|
||||
throw std::invalid_argument(
|
||||
"Cannot index mlx array using the given type yet");
|
||||
}
|
||||
if (!py::ellipsis().is(idx)) {
|
||||
indices.push_back(idx);
|
||||
non_none_indices_before += !idx.is_none();
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
for (int j = entries.size() - 1; j > i; j--) {
|
||||
auto idx = entries[j];
|
||||
if (!is_valid_index_type(idx)) {
|
||||
throw std::invalid_argument(
|
||||
"Cannot index mlx array using the given type yet");
|
||||
}
|
||||
if (py::ellipsis().is(idx)) {
|
||||
throw std::invalid_argument(
|
||||
"An index can only have a single ellipsis (...)");
|
||||
}
|
||||
r_indices.push_back(idx);
|
||||
non_none_indices_after += !idx.is_none();
|
||||
}
|
||||
for (int axis = non_none_indices_before;
|
||||
axis < src.ndim() - non_none_indices_after;
|
||||
axis++) {
|
||||
indices.push_back(py::slice(0, src.shape(axis), 1));
|
||||
}
|
||||
indices.insert(indices.end(), r_indices.rbegin(), r_indices.rend());
|
||||
}
|
||||
|
||||
// Check for the number of indices passed
|
||||
{
|
||||
int cnt = src.ndim();
|
||||
for (auto& idx : indices) {
|
||||
if (!idx.is_none()) {
|
||||
cnt--;
|
||||
}
|
||||
}
|
||||
if (cnt < 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "Too many indices for array with " << src.ndim() << "dimensions.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
}
|
||||
|
||||
// Gather handling
|
||||
//
|
||||
// Check whether we have arrays or integer indices and delegate to gather_nd
|
||||
// after removing the slices at the end and all Nones.
|
||||
std::vector<py::object> remaining_indices;
|
||||
bool have_array = false;
|
||||
{
|
||||
// First check whether the results of gather are going to be 1st or
|
||||
// normally in between.
|
||||
bool have_non_array = false;
|
||||
bool gather_first = false;
|
||||
for (auto& idx : indices) {
|
||||
if (py::isinstance<array>(idx) || py::isinstance<py::int_>(idx)) {
|
||||
if (have_array && have_non_array) {
|
||||
gather_first = true;
|
||||
break;
|
||||
}
|
||||
have_array = true;
|
||||
} else {
|
||||
have_non_array |= have_array;
|
||||
}
|
||||
}
|
||||
|
||||
if (have_array) {
|
||||
int last_array;
|
||||
// Then find the last array
|
||||
for (last_array = indices.size() - 1; last_array >= 0; last_array--) {
|
||||
auto& idx = indices[last_array];
|
||||
if (py::isinstance<array>(idx) || py::isinstance<py::int_>(idx)) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<py::object> gather_indices;
|
||||
for (int i = 0; i <= last_array; i++) {
|
||||
auto& idx = indices[i];
|
||||
if (!idx.is_none()) {
|
||||
gather_indices.push_back(idx);
|
||||
}
|
||||
}
|
||||
int max_dims;
|
||||
src = mlx_gather_nd(src, gather_indices, gather_first, max_dims);
|
||||
|
||||
// Reassemble the indices for the slicing or reshaping if there are any
|
||||
if (gather_first) {
|
||||
for (int i = 0; i < max_dims; i++) {
|
||||
remaining_indices.push_back(
|
||||
py::slice(py::none(), py::none(), py::none()));
|
||||
}
|
||||
for (int i = 0; i < last_array; i++) {
|
||||
auto& idx = indices[i];
|
||||
if (idx.is_none()) {
|
||||
remaining_indices.push_back(indices[i]);
|
||||
} else if (py::isinstance<py::slice>(idx)) {
|
||||
remaining_indices.push_back(
|
||||
py::slice(py::none(), py::none(), py::none()));
|
||||
}
|
||||
}
|
||||
for (int i = last_array + 1; i < indices.size(); i++) {
|
||||
remaining_indices.push_back(indices[i]);
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < indices.size(); i++) {
|
||||
auto& idx = indices[i];
|
||||
if (py::isinstance<array>(idx) || py::isinstance<py::int_>(idx)) {
|
||||
break;
|
||||
} else if (idx.is_none()) {
|
||||
remaining_indices.push_back(idx);
|
||||
} else {
|
||||
remaining_indices.push_back(
|
||||
py::slice(py::none(), py::none(), py::none()));
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < max_dims; i++) {
|
||||
remaining_indices.push_back(
|
||||
py::slice(py::none(), py::none(), py::none()));
|
||||
}
|
||||
for (int i = last_array + 1; i < indices.size(); i++) {
|
||||
remaining_indices.push_back(indices[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (have_array && remaining_indices.empty()) {
|
||||
return src;
|
||||
}
|
||||
if (remaining_indices.empty()) {
|
||||
remaining_indices = indices;
|
||||
}
|
||||
|
||||
// Slice handling
|
||||
{
|
||||
std::vector<int> starts(src.ndim(), 0);
|
||||
std::vector<int> ends = src.shape();
|
||||
std::vector<int> strides(src.ndim(), 1);
|
||||
int axis = 0;
|
||||
for (auto& idx : remaining_indices) {
|
||||
if (!idx.is_none()) {
|
||||
get_slice_params(
|
||||
starts[axis], ends[axis], strides[axis], idx, ends[axis]);
|
||||
axis++;
|
||||
}
|
||||
}
|
||||
src = slice(src, starts, ends, strides);
|
||||
}
|
||||
|
||||
// Unsqueeze handling
|
||||
if (remaining_indices.size() > src.ndim()) {
|
||||
std::vector<int> out_shape;
|
||||
int axis = 0;
|
||||
for (auto& idx : remaining_indices) {
|
||||
if (idx.is_none()) {
|
||||
out_shape.push_back(1);
|
||||
} else {
|
||||
out_shape.push_back(src.shape(axis++));
|
||||
}
|
||||
}
|
||||
src = reshape(src, out_shape);
|
||||
}
|
||||
|
||||
return src;
|
||||
}
|
||||
|
||||
array mlx_get_item(const array& src, const py::object& obj) {
|
||||
if (py::isinstance<py::slice>(obj)) {
|
||||
return mlx_get_item_slice(src, obj);
|
||||
} else if (py::isinstance<array>(obj)) {
|
||||
return mlx_get_item_array(src, py::cast<array>(obj));
|
||||
} else if (py::isinstance<py::int_>(obj)) {
|
||||
return mlx_get_item_int(src, obj);
|
||||
} else if (py::isinstance<py::tuple>(obj)) {
|
||||
return mlx_get_item_nd(src, obj);
|
||||
} else if (obj.is_none()) {
|
||||
std::vector<int> s(1, 1);
|
||||
s.insert(s.end(), src.shape().begin(), src.shape().end());
|
||||
return reshape(src, s);
|
||||
}
|
||||
throw std::invalid_argument("Cannot index mlx array using the given type.");
|
||||
}
|
||||
|
||||
array mlx_set_item_int(
|
||||
const array& src,
|
||||
const py::int_& idx,
|
||||
const array& update) {
|
||||
if (src.ndim() == 0) {
|
||||
throw std::invalid_argument(
|
||||
"too many indices for array: array is 0-dimensional");
|
||||
}
|
||||
|
||||
// Remove any leading singleton dimensions from the update
|
||||
// and then broadcast update to shape of src[0, ...]
|
||||
int s = 0;
|
||||
for (; s < update.ndim() && update.shape(s) == 1; s++)
|
||||
;
|
||||
auto up_shape =
|
||||
std::vector<int>(update.shape().begin() + s, update.shape().end());
|
||||
auto shape = src.shape();
|
||||
shape[0] = 1;
|
||||
return scatter(
|
||||
src,
|
||||
get_int_index(idx, src.shape(0)),
|
||||
broadcast_to(reshape(update, up_shape), shape),
|
||||
0);
|
||||
}
|
||||
|
||||
array mlx_set_item_array(
|
||||
const array& src,
|
||||
const array& indices,
|
||||
const array& update) {
|
||||
if (src.ndim() == 0) {
|
||||
throw std::invalid_argument(
|
||||
"too many indices for array: array is 0-dimensional");
|
||||
}
|
||||
|
||||
// Remove any leading singleton dimensions from the update
|
||||
int s = 0;
|
||||
for (; s < update.ndim() && update.shape(s) == 1; s++)
|
||||
;
|
||||
auto up_shape =
|
||||
std::vector<int>(update.shape().begin() + s, update.shape().end());
|
||||
auto up = reshape(update, up_shape);
|
||||
|
||||
// The update shape must broadcast with indices.shape + [1] + src.shape[1:]
|
||||
up_shape = indices.shape();
|
||||
up_shape.insert(up_shape.end(), src.shape().begin() + 1, src.shape().end());
|
||||
up = broadcast_to(up, up_shape);
|
||||
up_shape.insert(up_shape.begin() + indices.ndim(), 1);
|
||||
up = reshape(up, up_shape);
|
||||
|
||||
return scatter(src, indices, up, 0);
|
||||
}
|
||||
|
||||
array mlx_set_item_slice(
|
||||
const array& src,
|
||||
const py::slice& in_slice,
|
||||
const array& update) {
|
||||
// Check input and raise error if 0 dim for parity with np
|
||||
if (src.ndim() == 0) {
|
||||
throw std::invalid_argument(
|
||||
"too many indices for array: array is 0-dimensional");
|
||||
}
|
||||
|
||||
// If none slice is requested broadcast the update
|
||||
// to the src size and return it.
|
||||
if (is_none_slice(in_slice)) {
|
||||
int s = 0;
|
||||
for (; s < update.ndim() && update.shape(s) == 1; s++)
|
||||
;
|
||||
auto up_shape =
|
||||
std::vector<int>(update.shape().begin() + s, update.shape().end());
|
||||
return broadcast_to(reshape(update, up_shape), src.shape());
|
||||
}
|
||||
|
||||
int start = 0;
|
||||
int end = src.shape(0);
|
||||
int stride = 1;
|
||||
|
||||
// Check and update slice params
|
||||
get_slice_params(start, end, stride, in_slice, end);
|
||||
|
||||
return mlx_set_item_array(src, arange(start, end, stride, uint32), update);
|
||||
}
|
||||
|
||||
array mlx_set_item_nd(
|
||||
const array& src,
|
||||
const py::tuple& entries,
|
||||
const array& update) {
|
||||
std::vector<py::object> indices;
|
||||
int non_none_indices = 0;
|
||||
|
||||
// Expand ellipses into a series of ':' slices
|
||||
{
|
||||
int non_none_indices_before = 0;
|
||||
int non_none_indices_after = 0;
|
||||
bool has_ellipsis = false;
|
||||
int indices_before = 0;
|
||||
for (int i = 0; i < entries.size(); ++i) {
|
||||
auto idx = entries[i];
|
||||
if (!is_valid_index_type(idx)) {
|
||||
throw std::invalid_argument(
|
||||
"Cannot index mlx array using the given type yet");
|
||||
} else if (!py::ellipsis().is(idx)) {
|
||||
if (!has_ellipsis) {
|
||||
indices_before++;
|
||||
non_none_indices_before += !idx.is_none();
|
||||
} else {
|
||||
non_none_indices_after += !idx.is_none();
|
||||
}
|
||||
indices.push_back(idx);
|
||||
} else if (has_ellipsis) {
|
||||
throw std::invalid_argument(
|
||||
"An index can only have a single ellipsis (...)");
|
||||
} else {
|
||||
has_ellipsis = true;
|
||||
}
|
||||
}
|
||||
if (has_ellipsis) {
|
||||
for (int axis = non_none_indices_before;
|
||||
axis < src.ndim() - non_none_indices_after;
|
||||
axis++) {
|
||||
indices.insert(
|
||||
indices.begin() + indices_before, py::slice(0, src.shape(axis), 1));
|
||||
}
|
||||
non_none_indices = src.ndim();
|
||||
} else {
|
||||
non_none_indices = non_none_indices_before + non_none_indices_after;
|
||||
}
|
||||
}
|
||||
|
||||
if (non_none_indices > src.ndim()) {
|
||||
std::ostringstream msg;
|
||||
msg << "Too many indices for array with " << src.ndim() << "dimensions.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
// Remove leading singletons dimensions from the update
|
||||
int s = 0;
|
||||
for (; s < update.ndim() && update.shape(s) == 1; s++)
  ;
|
||||
auto up_shape =
|
||||
std::vector<int>(update.shape().begin() + s, update.shape().end());
|
||||
auto up = reshape(update, up_shape);
|
||||
|
||||
// If no non-None indices return the broadcasted update
|
||||
if (non_none_indices == 0) {
|
||||
return broadcast_to(up, src.shape());
|
||||
}
|
||||
|
||||
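// Mirror the getitem path: decide whether the broadcast index dimensions go
// first (array indices separated by slices/None) or stay in place, and count
// the slices and arrays so the scatter indices below can be shaped to match.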
unsigned long max_dim = 0;
|
||||
bool arrays_first = false;
|
||||
int num_slices = 0;
|
||||
int num_arrays = 0;
|
||||
{
|
||||
bool have_array = false;
|
||||
bool have_non_array = false;
|
||||
for (auto& idx : indices) {
|
||||
if (py::isinstance<py::slice>(idx) || idx.is_none()) {
|
||||
have_non_array = have_array;
|
||||
num_slices++;
|
||||
} else if (py::isinstance<array>(idx)) {
|
||||
have_array = true;
|
||||
if (have_array && have_non_array) {
|
||||
arrays_first = true;
|
||||
}
|
||||
max_dim = std::max(py::cast<array>(idx).ndim(), max_dim);
|
||||
num_arrays++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<array> arr_indices;
|
||||
int slice_num = 0;
|
||||
int array_num = 0;
|
||||
int ax = 0;
|
||||
for (int i = 0; i < indices.size(); ++i) {
|
||||
auto& pyidx = indices[i];
|
||||
if (py::isinstance<py::slice>(pyidx)) {
|
||||
int start, end, stride;
|
||||
get_slice_params(start, end, stride, pyidx, src.shape(ax++));
|
||||
auto idx = arange(start, end, stride, uint32);
|
||||
std::vector<int> idx_shape(max_dim + num_slices, 1);
|
||||
auto loc = slice_num + (arrays_first ? max_dim : 0);
|
||||
slice_num++;
|
||||
idx_shape[loc] = idx.size();
|
||||
arr_indices.push_back(reshape(idx, idx_shape));
|
||||
} else if (py::isinstance<py::int_>(pyidx)) {
|
||||
arr_indices.push_back(get_int_index(pyidx, src.shape(ax++)));
|
||||
} else if (pyidx.is_none()) {
|
||||
slice_num++;
|
||||
} else if (py::isinstance<array>(pyidx)) {
|
||||
ax++;
|
||||
auto idx = py::cast<array>(pyidx);
|
||||
std::vector<int> idx_shape;
|
||||
if (!arrays_first) {
|
||||
idx_shape.insert(idx_shape.end(), slice_num, 1);
|
||||
}
|
||||
idx_shape.insert(idx_shape.end(), max_dim - idx.ndim(), 1);
|
||||
idx_shape.insert(idx_shape.end(), idx.shape().begin(), idx.shape().end());
|
||||
idx_shape.insert(
|
||||
idx_shape.end(), num_slices - (arrays_first ? 0 : slice_num), 1);
|
||||
arr_indices.push_back(reshape(idx, idx_shape));
|
||||
if (!arrays_first && ++array_num == num_arrays) {
|
||||
slice_num += max_dim;
|
||||
}
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"Cannot index mlx array using the given type yet");
|
||||
}
|
||||
}
|
||||
|
||||
arr_indices = broadcast_arrays(arr_indices);
|
||||
up_shape = arr_indices[0].shape();
|
||||
up_shape.insert(
|
||||
up_shape.end(),
|
||||
src.shape().begin() + non_none_indices,
|
||||
src.shape().end());
|
||||
up = broadcast_to(up, up_shape);
|
||||
up_shape.insert(
|
||||
up_shape.begin() + arr_indices[0].ndim(), non_none_indices, 1);
|
||||
up = reshape(up, up_shape);
|
||||
|
||||
std::vector<int> axes(arr_indices.size(), 0);
|
||||
std::iota(axes.begin(), axes.end(), 0);
|
||||
return scatter(src, arr_indices, up, axes);
|
||||
}
|
||||
|
||||
void mlx_set_item(array& src, const py::object& obj, const ScalarOrArray& v) {
|
||||
auto vals = to_array(v, src.dtype());
|
||||
auto impl = [&src, &obj, &vals]() {
|
||||
if (py::isinstance<py::slice>(obj)) {
|
||||
return mlx_set_item_slice(src, obj, vals);
|
||||
} else if (py::isinstance<array>(obj)) {
|
||||
return mlx_set_item_array(src, py::cast<array>(obj), vals);
|
||||
} else if (py::isinstance<py::int_>(obj)) {
|
||||
return mlx_set_item_int(src, obj, vals);
|
||||
} else if (py::isinstance<py::tuple>(obj)) {
|
||||
return mlx_set_item_nd(src, obj, vals);
|
||||
} else if (obj.is_none()) {
|
||||
return broadcast_to(vals, src.shape());
|
||||
}
|
||||
throw std::invalid_argument("Cannot index mlx array using the given type.");
|
||||
};
|
||||
auto out = impl();
|
||||
src.overwrite_descriptor(out);
|
||||
}
|
||||
12
python/src/indexing.h
Normal file
@@ -0,0 +1,12 @@
|
||||
#pragma once
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "python/src/utils.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
using namespace mlx::core;
|
||||
|
||||
array mlx_get_item(const array& src, const py::object& obj);
|
||||
void mlx_set_item(array& src, const py::object& obj, const ScalarOrArray& v);
|
||||
290
python/src/load.cpp
Normal file
@@ -0,0 +1,290 @@
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h>
|
||||
|
||||
#include <cstring>
|
||||
#include <fstream>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <string_view>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "mlx/load.h"
|
||||
#include "mlx/ops.h"
|
||||
#include "mlx/utils.h"
|
||||
#include "python/src/load.h"
|
||||
#include "python/src/utils.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
using namespace py::literals;
|
||||
using namespace mlx::core;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Helpers
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
bool is_istream_object(const py::object& file) {
|
||||
return py::hasattr(file, "read") && py::hasattr(file, "seek") &&
|
||||
py::hasattr(file, "tell") && py::hasattr(file, "closed");
|
||||
}
|
||||
|
||||
bool is_ostream_object(const py::object& file) {
|
||||
return py::hasattr(file, "write") && py::hasattr(file, "seek") &&
|
||||
py::hasattr(file, "tell") && py::hasattr(file, "closed");
|
||||
}
|
||||
|
||||
bool is_zip_file(const py::module_& zipfile, const py::object& file) {
|
||||
if (is_istream_object(file)) {
|
||||
auto st_pos = file.attr("tell")();
|
||||
bool r = (zipfile.attr("is_zipfile")(file)).cast<bool>();
|
||||
file.attr("seek")(st_pos, 0);
|
||||
return r;
|
||||
}
|
||||
return zipfile.attr("is_zipfile")(file).cast<bool>();
|
||||
}
|
||||
|
||||
class ZipFileWrapper {
|
||||
public:
|
||||
ZipFileWrapper(
|
||||
const py::module_& zipfile,
|
||||
const py::object& file,
|
||||
char mode = 'r',
|
||||
int compression = 0)
|
||||
: zipfile_module_(zipfile),
|
||||
zipfile_object_(zipfile.attr("ZipFile")(
|
||||
file,
|
||||
"mode"_a = mode,
|
||||
"compression"_a = compression,
|
||||
"allowZip64"_a = true)),
|
||||
files_list_(zipfile_object_.attr("namelist")()),
|
||||
open_func_(zipfile_object_.attr("open")),
|
||||
read_func_(zipfile_object_.attr("read")),
|
||||
close_func_(zipfile_object_.attr("close")) {}
|
||||
|
||||
std::vector<std::string> namelist() const {
|
||||
return files_list_.cast<std::vector<std::string>>();
|
||||
}
|
||||
|
||||
py::object open(const std::string& key, char mode = 'r') {
|
||||
// Following numpy :
|
||||
// https://github.com/numpy/numpy/blob/db4f43983cb938f12c311e1f5b7165e270c393b4/numpy/lib/npyio.py#L742C36-L742C47
|
||||
if (mode == 'w') {
|
||||
return open_func_(key, "mode"_a = mode, "force_zip64"_a = true);
|
||||
}
|
||||
return open_func_(key, "mode"_a = mode);
|
||||
}
|
||||
|
||||
private:
|
||||
py::module_ zipfile_module_;
|
||||
py::object zipfile_object_;
|
||||
py::list files_list_;
|
||||
py::object open_func_;
|
||||
py::object read_func_;
|
||||
py::object close_func_;
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Loading
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
class PyFileReader : public io::Reader {
|
||||
public:
|
||||
PyFileReader(py::object file)
|
||||
: pyistream_(file),
|
||||
readinto_func_(file.attr("readinto")),
|
||||
seek_func_(file.attr("seek")),
|
||||
tell_func_(file.attr("tell")) {}
|
||||
|
||||
bool is_open() const override {
|
||||
return !pyistream_.attr("closed").cast<bool>();
|
||||
}
|
||||
|
||||
bool good() const override {
|
||||
return !pyistream_.is_none();
|
||||
}
|
||||
|
||||
size_t tell() const override {
|
||||
return tell_func_().cast<size_t>();
|
||||
}
|
||||
|
||||
void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg)
|
||||
override {
|
||||
seek_func_(off, (int)way);
|
||||
}
|
||||
|
||||
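// readinto fills a writable memoryview over the caller's buffer directly,
// avoiding an intermediate bytes copy.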
void read(char* data, size_t n) override {
|
||||
py::object bytes_read =
|
||||
readinto_func_(py::memoryview::from_buffer(data, {n}, {sizeof(char)}));
|
||||
if (bytes_read.is_none() || py::cast<size_t>(bytes_read) < n) {
|
||||
throw std::runtime_error("[load] Failed to read from python stream");
|
||||
}
|
||||
}
|
||||
|
||||
std::string label() const override {
|
||||
return "python file object";
|
||||
}
|
||||
|
||||
private:
|
||||
py::object pyistream_;
|
||||
py::object readinto_func_;
|
||||
py::object seek_func_;
|
||||
py::object tell_func_;
|
||||
};
|
||||
|
||||
DictOrArray mlx_load_helper(py::object file, StreamOrDevice s) {
|
||||
py::module_ zipfile = py::module_::import("zipfile");
|
||||
|
||||
// Assume .npz file if it is zipped
|
||||
if (is_zip_file(zipfile, file)) {
|
||||
// Output dictionary filename in zip -> loaded array
|
||||
std::unordered_map<std::string, array> array_dict;
|
||||
|
||||
// Create python ZipFile object
|
||||
ZipFileWrapper zipfile_object(zipfile, file);
|
||||
for (const std::string& st : zipfile_object.namelist()) {
|
||||
// Open zip file as a python file stream
|
||||
py::object sub_file = zipfile_object.open(st);
|
||||
|
||||
// Create array from python file stream
|
||||
auto arr = load(std::make_shared<PyFileReader>(sub_file), s);
|
||||
|
||||
// Remove .npy from file if it is there
|
||||
auto key = st;
|
||||
if (st.length() > 4 && st.substr(st.length() - 4, 4) == ".npy")
|
||||
key = st.substr(0, st.length() - 4);
|
||||
|
||||
// Add array to dict
|
||||
array_dict.insert({key, arr});
|
||||
}
|
||||
|
||||
// If we don't own the stream and it was passed to us, eval immediately
|
||||
for (auto& [key, arr] : array_dict) {
|
||||
arr.eval();
|
||||
}
|
||||
|
||||
return {array_dict};
|
||||
} else if (py::isinstance<py::str>(file)) { // Assume .npy file path string
|
||||
return {load(py::cast<std::string>(file), s)};
|
||||
} else if (is_istream_object(file)) {
|
||||
// If we don't own the stream and it was passed to us, eval immediately
|
||||
auto arr = load(std::make_shared<PyFileReader>(file), s);
|
||||
arr.eval();
|
||||
return {arr};
|
||||
}
|
||||
|
||||
throw std::invalid_argument(
|
||||
"[load] Input must be a file-like object, string, or pathlib.Path");
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Saving
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
class PyFileWriter : public io::Writer {
|
||||
public:
|
||||
PyFileWriter(py::object file)
|
||||
: pyostream_(file),
|
||||
write_func_(file.attr("write")),
|
||||
seek_func_(file.attr("seek")),
|
||||
tell_func_(file.attr("tell")) {}
|
||||
|
||||
bool is_open() const override {
|
||||
return !pyostream_.attr("closed").cast<bool>();
|
||||
}
|
||||
|
||||
bool good() const override {
|
||||
return !pyostream_.is_none();
|
||||
}
|
||||
|
||||
size_t tell() const override {
|
||||
return tell_func_().cast<size_t>();
|
||||
}
|
||||
|
||||
void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg)
|
||||
override {
|
||||
seek_func_(off, (int)way);
|
||||
}
|
||||
|
||||
void write(const char* data, size_t n) override {
|
||||
py::object bytes_written =
|
||||
write_func_(py::memoryview::from_buffer(data, {n}, {sizeof(char)}));
|
||||
if (bytes_written.is_none() || py::cast<size_t>(bytes_written) < n) {
|
||||
throw std::runtime_error("[load] Failed to write to python stream");
|
||||
}
|
||||
}
|
||||
|
||||
std::string label() const override {
|
||||
return "python file object";
|
||||
}
|
||||
|
||||
private:
|
||||
py::object pyostream_;
|
||||
py::object write_func_;
|
||||
py::object seek_func_;
|
||||
py::object tell_func_;
|
||||
};
|
||||
|
||||
void mlx_save_helper(py::object file, array a, bool retain_graph) {
|
||||
if (py::isinstance<py::str>(file)) {
|
||||
save(py::cast<std::string>(file), a, retain_graph);
|
||||
return;
|
||||
} else if (is_ostream_object(file)) {
|
||||
save(std::make_shared<PyFileWriter>(file), a, retain_graph);
|
||||
return;
|
||||
}
|
||||
|
||||
throw std::invalid_argument(
|
||||
"[save] Input must be a file-like object, string, or pathlib.Path");
|
||||
}
|
||||
|
||||
void mlx_savez_helper(
|
||||
py::object file_,
|
||||
py::args args,
|
||||
const py::kwargs& kwargs,
|
||||
bool compressed) {
|
||||
// Add .npz to the end of the filename if not already there
|
||||
py::object file = file_;
|
||||
|
||||
if (py::isinstance<py::str>(file_)) {
|
||||
std::string fname = file_.cast<std::string>();
|
||||
|
||||
// Add .npz to file name if it is not there
|
||||
if (fname.length() < 4 || fname.substr(fname.length() - 4, 4) != ".npz")
|
||||
fname += ".npz";
|
||||
|
||||
file = py::str(fname);
|
||||
}
|
||||
|
||||
// Collect args and kwargs
|
||||
auto arrays_dict = kwargs.cast<std::unordered_map<std::string, array>>();
|
||||
auto arrays_list = args.cast<std::vector<array>>();
|
||||
|
||||
for (int i = 0; i < arrays_list.size(); i++) {
|
||||
std::string arr_name = "arr_" + std::to_string(i);
|
||||
|
||||
if (arrays_dict.count(arr_name) > 0) {
|
||||
throw std::invalid_argument(
|
||||
"[savez] Cannot use un-named variables and keyword " + arr_name);
|
||||
}
|
||||
|
||||
arrays_dict.insert({arr_name, arrays_list[i]});
|
||||
}
|
||||
|
||||
// Create python ZipFile object depending on compression
|
||||
py::module_ zipfile = py::module_::import("zipfile");
|
||||
int compression = compressed ? zipfile.attr("ZIP_DEFLATED").cast<int>()
|
||||
: zipfile.attr("ZIP_STORED").cast<int>();
|
||||
char mode = 'w';
|
||||
ZipFileWrapper zipfile_object(zipfile, file, mode, compression);
|
||||
|
||||
// Save each array
|
||||
for (auto [k, a] : arrays_dict) {
|
||||
std::string fname = k + ".npy";
|
||||
auto py_ostream = zipfile_object.open(fname, 'w');
|
||||
save(std::make_shared<PyFileWriter>(py_ostream), a);
|
||||
}
|
||||
|
||||
return;
|
||||
}
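A hedged usage sketch of the savez helper above, assuming the standard `mx.savez`/`mx.load` entry points; per the code, positional arrays are stored as ``arr_0``, ``arr_1``, ... and ``.npz`` is appended to the file name when missing:

```python
import mlx.core as mx

a = mx.ones((2, 2))
mx.savez("bundle", a, weights=a)  # writes bundle.npz containing arr_0.npy and weights.npy

loaded = mx.load("bundle.npz")    # a dict mapping names to arrays
print(loaded["weights"].shape)    # (2, 2)
```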
|
||||
2422
python/src/ops.cpp
Normal file
File diff suppressed because it is too large
289
python/src/random.cpp
Normal file
@@ -0,0 +1,289 @@
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h>
|
||||
|
||||
#include "python/src/utils.h"
|
||||
|
||||
#include "mlx/ops.h"
|
||||
#include "mlx/random.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
using namespace py::literals;
|
||||
using namespace mlx::core;
|
||||
using namespace mlx::core::random;
|
||||
|
||||
void init_random(py::module_& parent_module) {
|
||||
auto m = parent_module.def_submodule(
|
||||
"random",
|
||||
"mlx.core.random: functionality related to random number generation");
|
||||
m.def(
|
||||
"seed",
|
||||
&seed,
|
||||
"seed"_a,
|
||||
R"pbdoc(
|
||||
Seed the global PRNG.
|
||||
|
||||
Args:
|
||||
seed (int): Seed for the global PRNG.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"key",
|
||||
&key,
|
||||
"seed"_a,
|
||||
R"pbdoc(
|
||||
Get a PRNG key from a seed.
|
||||
|
||||
Args:
|
||||
seed (int): Seed for the PRNG.
|
||||
|
||||
Returns:
|
||||
array: The PRNG key array.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"split",
|
||||
py::overload_cast<const array&, int, StreamOrDevice>(&random::split),
|
||||
"key"_a,
|
||||
"num"_a = 2,
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
Split a PRNG key into sub keys.
|
||||
|
||||
Args:
|
||||
key (array): Input key to split.
|
||||
num (int, optional): Number of sub keys. Default is 2.
|
||||
|
||||
Returns:
|
||||
array: The array of sub keys with ``num`` as its first dimension.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"uniform",
|
||||
[](const ScalarOrArray& low,
|
||||
const ScalarOrArray& high,
|
||||
const std::vector<int>& shape,
|
||||
Dtype type,
|
||||
const std::optional<array>& key,
|
||||
StreamOrDevice s) {
|
||||
return uniform(to_array(low), to_array(high), shape, type, key, s);
|
||||
},
|
||||
"low"_a = 0,
|
||||
"high"_a = 1,
|
||||
"shape"_a = std::vector<int>{},
|
||||
"dtype"_a = float32,
|
||||
"key"_a = none,
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
Generate uniformly distributed random numbers.
|
||||
|
||||
The values are sampled uniformly in the half-open interval ``[low, high)``.
|
||||
The lower and upper bound can be scalars or arrays and must be
|
||||
broadcastable to ``shape``.
|
||||
|
||||
Args:
|
||||
low (scalar or array, optional): Lower bound of the distribution. Default is ``0``.
|
||||
high (scalar or array, optional): Upper bound of the distribution. Default is ``1``.
|
||||
shape (list(int), optional): Shape of the output. Default is ``()``.
|
||||
key (array, optional): A PRNG key. Default: None.
|
||||
dtype (Dtype, optional): Type of the output. Default is ``float32``.
|
||||
|
||||
Returns:
|
||||
array: The output array of random values.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"normal",
|
||||
[](const std::vector<int>& shape,
|
||||
Dtype type,
|
||||
const std::optional<array>& key,
|
||||
StreamOrDevice s) { return normal(shape, type, key, s); },
|
||||
|
||||
"shape"_a = std::vector<int>{},
|
||||
"dtype"_a = float32,
|
||||
"key"_a = none,
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
Generate normally distributed random numbers.
|
||||
|
||||
Args:
|
||||
shape (list(int), optional): Shape of the output. Default is ``()``.
|
||||
dtype (Dtype, optional): Type of the output. Default is ``float32``.
|
||||
key (array, optional): A PRNG key. Default: None.
|
||||
|
||||
Returns:
|
||||
array: The output array of random values.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"randint",
|
||||
[](const ScalarOrArray& low,
|
||||
const ScalarOrArray& high,
|
||||
const std::vector<int>& shape,
|
||||
Dtype type,
|
||||
const std::optional<array>& key,
|
||||
StreamOrDevice s) {
|
||||
return randint(to_array(low), to_array(high), shape, type, key, s);
|
||||
},
|
||||
"low"_a,
|
||||
"high"_a,
|
||||
"shape"_a = std::vector<int>{},
|
||||
"dtype"_a = int32,
|
||||
"key"_a = none,
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
Generate random integers from the given interval.
|
||||
|
||||
The values are sampled with equal probability from the integers in
|
||||
half-open interval ``[low, high)``. The lower and upper bound can be
|
||||
scalars or arrays and must be broadcastable to ``shape``.
|
||||
|
||||
Args:
|
||||
low (scalar or array): Lower bound of the interval.
|
||||
high (scalar or array): Upper bound of the interval.
|
||||
shape (list(int), optional): Shape of the output. Defaults to ``()``.
|
||||
dtype (Dtype, optional): Type of the output. Defaults to ``int32``.
|
||||
key (array, optional): A PRNG key. Default: None.
|
||||
|
||||
Returns:
|
||||
array: The array of random integers.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"bernoulli",
|
||||
[](const ScalarOrArray& p_,
|
||||
const std::optional<std::vector<int>> shape,
|
||||
const std::optional<array>& key,
|
||||
StreamOrDevice s) {
|
||||
auto p = to_array(p_);
|
||||
if (shape.has_value()) {
|
||||
return bernoulli(p, shape.value(), key, s);
|
||||
} else {
|
||||
return bernoulli(p, key, s);
|
||||
}
|
||||
},
|
||||
"p"_a = 0.5,
|
||||
"shape"_a = none,
|
||||
"key"_a = none,
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
Generate Bernoulli random values.
|
||||
|
||||
The values are sampled from the Bernoulli distribution with parameter
|
||||
``p``. The parameter ``p`` can be a :obj:`float` or :obj:`array` and
|
||||
must be broadcastable to ``shape``.
|
||||
|
||||
Args:
|
||||
p (float or array, optional): Parameter of the Bernoulli
|
||||
distribution. Default is 0.5.
|
||||
shape (list(int), optional): Shape of the output. The default
|
||||
shape is ``p.shape``.
|
||||
key (array, optional): A PRNG key. Default: None.
|
||||
|
||||
Returns:
|
||||
array: The array of random values.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"truncated_normal",
|
||||
[](const ScalarOrArray& lower_,
|
||||
const ScalarOrArray& upper_,
|
||||
const std::optional<std::vector<int>> shape_,
|
||||
Dtype dtype,
|
||||
const std::optional<array>& key,
|
||||
StreamOrDevice s) {
|
||||
auto lower = to_array(lower_);
|
||||
auto upper = to_array(upper_);
|
||||
if (shape_.has_value()) {
|
||||
return truncated_normal(lower, upper, shape_.value(), dtype, key, s);
|
||||
} else {
|
||||
return truncated_normal(lower, upper, dtype, key, s);
|
||||
}
|
||||
},
|
||||
"lower"_a,
|
||||
"upper"_a,
|
||||
"shape"_a = none,
|
||||
"dtype"_a = float32,
|
||||
"key"_a = none,
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
Generate values from a truncated normal distribution.
|
||||
|
||||
The values are sampled from the truncated normal distribution
|
||||
on the domain ``(lower, upper)``. The bounds ``lower`` and ``upper``
|
||||
can be scalars or arrays and must be broadcastable to ``shape``.
|
||||
|
||||
Args:
|
||||
lower (scalar or array): Lower bound of the domain.
|
||||
upper (scalar or array): Upper bound of the domain.
|
||||
shape (list(int), optional): The shape of the output.
|
||||
Default is ``()``.
|
||||
dtype (Dtype, optional): The data type of the output.
|
||||
Default is ``float32``.
|
||||
key (array, optional): A PRNG key. Default: None.
|
||||
|
||||
Returns:
|
||||
array: The output array of random values.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"gumbel",
|
||||
&gumbel,
|
||||
"shape"_a = std::vector<int>{},
|
||||
"dtype"_a = float32,
|
||||
"stream"_a = none,
|
||||
"key"_a = none,
|
||||
R"pbdoc(
|
||||
Sample from the standard Gumbel distribution.
|
||||
|
||||
The values are sampled from a standard Gumbel distribution
|
||||
whose CDF is ``exp(-exp(-x))``.
|
||||
|
||||
Args:
|
||||
shape (list(int)): The shape of the output.
|
||||
key (array, optional): A PRNG key. Default: None.
|
||||
|
||||
Returns:
|
||||
array: The :class:`array` with shape ``shape`` and values
distributed according to the Gumbel distribution.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"categorical",
|
||||
[](const array& logits,
|
||||
int axis,
|
||||
const std::optional<std::vector<int>> shape,
|
||||
const std::optional<int> num_samples,
|
||||
const std::optional<array>& key,
|
||||
StreamOrDevice s) {
|
||||
if (shape.has_value() && num_samples.has_value()) {
|
||||
throw std::invalid_argument(
|
||||
"[categorical] At most one of shape or num_samples can be specified.");
|
||||
} else if (shape.has_value()) {
|
||||
return categorical(logits, axis, shape.value(), key, s);
|
||||
} else if (num_samples.has_value()) {
|
||||
return categorical(logits, axis, num_samples.value(), key, s);
|
||||
} else {
|
||||
return categorical(logits, axis, key, s);
|
||||
}
|
||||
},
|
||||
"logits"_a,
|
||||
"axis"_a = -1,
|
||||
"shape"_a = none,
|
||||
"num_samples"_a = none,
|
||||
"key"_a = none,
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
Sample from a categorical distribution.
|
||||
|
||||
The values are sampled from the categorical distribution specified by
|
||||
the unnormalized values in ``logits``. Note, at most one of ``shape``
|
||||
or ``num_samples`` can be specified. If both are ``None``, the output
|
||||
has the same shape as ``logits`` with the ``axis`` dimension removed.
|
||||
|
||||
Args:
|
||||
logits (array): The *unnormalized* categorical distribution(s).
|
||||
axis (int, optional): The axis which specifies the distribution.
|
||||
Default is ``-1``.
|
||||
shape (list(int), optional): The shape of the output. This must
|
||||
be broadcast compatible with ``logits.shape`` with the ``axis``
|
||||
dimension removed. Default: ``None``
|
||||
num_samples (int, optional): The number of samples to draw from each
|
||||
of the categorical distributions in ``logits``. The output will have
|
||||
``num_samples`` in the last dimension. Default: ``None``.
|
||||
key (array, optional): A PRNG key. Default: None.
|
||||
|
||||
Returns:
|
||||
array: The ``shape``-sized output array with type ``uint32``.
|
||||
)pbdoc");
|
||||
}
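A short sketch of `categorical` consistent with the docstring above (default ``axis=-1``, ``uint32`` output); the shapes shown are assumptions based on that docstring:

```python
import mlx.core as mx

logits = mx.zeros((4, 10))                          # four uniform distributions over 10 classes
s0 = mx.random.categorical(logits)                  # shape (4,)
s1 = mx.random.categorical(logits, num_samples=3)   # shape (4, 3)
```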
|
||||
71
python/src/utils.h
Normal file
@@ -0,0 +1,71 @@
|
||||
#pragma once
|
||||
#include <numeric>
|
||||
#include <variant>
|
||||
|
||||
#include <pybind11/complex.h>
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h>
|
||||
|
||||
#include "mlx/array.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
using namespace mlx::core;
|
||||
|
||||
using IntOrVec = std::variant<std::monostate, int, std::vector<int>>;
|
||||
using ScalarOrArray =
|
||||
std::variant<py::bool_, py::int_, py::float_, std::complex<float>, array>;
|
||||
static constexpr std::monostate none{};
|
||||
|
||||
inline std::vector<int> get_reduce_axes(const IntOrVec& v, int dims) {
|
||||
std::vector<int> axes;
|
||||
if (std::holds_alternative<std::monostate>(v)) {
|
||||
axes.resize(dims);
|
||||
std::iota(axes.begin(), axes.end(), 0);
|
||||
} else if (auto pv = std::get_if<int>(&v); pv) {
|
||||
axes.push_back(*pv);
|
||||
} else {
|
||||
axes = std::get<std::vector<int>>(v);
|
||||
}
|
||||
return axes;
|
||||
}
|
||||
|
||||
inline array to_array(
|
||||
const ScalarOrArray& v,
|
||||
std::optional<Dtype> dtype = std::nullopt) {
|
||||
if (auto pv = std::get_if<py::bool_>(&v); pv) {
|
||||
return array(py::cast<bool>(*pv), dtype.value_or(bool_));
|
||||
} else if (auto pv = std::get_if<py::int_>(&v); pv) {
|
||||
auto out_t = dtype.value_or(int32);
|
||||
// bool_ is an exception and is always promoted
|
||||
return array(py::cast<int>(*pv), (out_t == bool_) ? int32 : out_t);
|
||||
} else if (auto pv = std::get_if<py::float_>(&v); pv) {
|
||||
auto out_t = dtype.value_or(float32);
|
||||
return array(
|
||||
py::cast<float>(*pv), is_floating_point(out_t) ? out_t : float32);
|
||||
} else if (auto pv = std::get_if<std::complex<float>>(&v); pv) {
|
||||
return array(static_cast<complex64_t>(*pv), complex64);
|
||||
} else {
|
||||
return std::get<array>(v);
|
||||
}
|
||||
}
|
||||
|
||||
inline std::pair<array, array> to_arrays(
|
||||
const ScalarOrArray& a,
|
||||
const ScalarOrArray& b) {
|
||||
// Four cases:
|
||||
// - If both a and b are arrays leave their types alone
|
||||
// - If a is an array but b is not, treat b as a weak python type
|
||||
// - If b is an array but a is not, treat a as a weak python type
|
||||
// - If neither is an array convert to arrays but leave their types alone
|
||||
if (auto pa = std::get_if<array>(&a); pa) {
|
||||
if (auto pb = std::get_if<array>(&b); pb) {
|
||||
return {*pa, *pb};
|
||||
}
|
||||
return {*pa, to_array(b, pa->dtype())};
|
||||
} else if (auto pb = std::get_if<array>(&b); pb) {
|
||||
return {to_array(a, pb->dtype()), *pb};
|
||||
} else {
|
||||
return {to_array(a), to_array(b)};
|
||||
}
|
||||
}
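The four-case rule in ``to_arrays`` is what makes plain Python scalars "weak" against an existing array dtype. A hedged sketch of the visible effect, assuming standard binary-op promotion on top of these conversions:

```python
import mlx.core as mx

x = mx.array([1, 2], dtype=mx.int16)
print((x * 2).dtype)    # int16: the Python int adopts the array's dtype
print((x * 2.5).dtype)  # float32: per to_array, a float scalar with a non-float hint falls back to float32
```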
|
||||
1041
python/tests/test_array.py
Normal file
File diff suppressed because it is too large
263
python/tests/test_autograd.py
Normal file
@@ -0,0 +1,263 @@
|
||||
import unittest
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
import mlx_tests
|
||||
|
||||
|
||||
class TestAutograd(mlx_tests.MLXTestCase):
|
||||
def test_jvp(self):
|
||||
fun = lambda x: 2 * x
|
||||
out, dout = mx.jvp(fun, [mx.array(1.0)], [mx.array(2.0)])
|
||||
self.assertEqual(out[0].item(), 2.0)
|
||||
self.assertEqual(dout[0].item(), 4.0)
|
||||
|
||||
fun = lambda x, y: x * y
|
||||
_, out = mx.jvp(
|
||||
fun, [mx.array(4.0), mx.array(2.0)], [mx.array(3.0), mx.array(2.0)]
|
||||
)
|
||||
self.assertEqual(out[0].item(), 4.0 * 2.0 + 2.0 * 3.0)
|
||||
|
||||
fun = lambda x, y, z: (x * y, y * z)
|
||||
_, out = mx.jvp(
|
||||
fun,
|
||||
[mx.array(2.0), mx.array(4.0), mx.array(6.0)],
|
||||
[mx.array(1.0), mx.array(3.0), mx.array(1.0)],
|
||||
)
|
||||
self.assertEqual(len(out), 2)
|
||||
self.assertEqual(out[0].item(), 4.0 * 1.0 + 2.0 * 3.0)
|
||||
self.assertEqual(out[1].item(), 4.0 * 1.0 + 6.0 * 3.0)
|
||||
|
||||
def test_vjp(self):
|
||||
fun = lambda x: 2 * x
|
||||
out, dout = mx.vjp(fun, [mx.array(1.0)], [mx.array(2.0)])
|
||||
self.assertEqual(out[0].item(), 2.0)
|
||||
self.assertEqual(dout[0].item(), 4.0)
|
||||
|
||||
fun = lambda x, y: x * y
|
||||
_, dout = mx.vjp(fun, [mx.array(4.0), mx.array(2.0)], [mx.array(3.0)])
|
||||
self.assertEqual(dout[0].item(), 6.0)
|
||||
self.assertEqual(dout[1].item(), 12.0)
|
||||
|
||||
fun = lambda x, y, z: (x * y, y * z)
|
||||
_, out = mx.vjp(
|
||||
fun,
|
||||
[mx.array(2.0), mx.array(4.0), mx.array(6.0)],
|
||||
[mx.array(1.0), mx.array(3.0)],
|
||||
)
|
||||
self.assertEqual(len(out), 3)
|
||||
self.assertEqual(out[0].item(), 4.0 * 1.0)
|
||||
self.assertEqual(out[1].item(), 2.0 * 1.0 + 6.0 * 3.0)
|
||||
self.assertEqual(out[2].item(), 4.0 * 3.0)
|
||||
|
||||
def test_grad(self):
|
||||
fun = lambda x: x * x
|
||||
|
||||
value, dfdx = mx.value_and_grad(fun)(mx.array(0.5))
|
||||
self.assertEqual(value.item(), 0.25)
|
||||
self.assertEqual(dfdx.item(), 1.0)
|
||||
|
||||
dfdx = mx.grad(fun)(mx.array(0.5))
|
||||
self.assertEqual(dfdx.item(), 1.0)
|
||||
|
||||
df2dx2 = mx.grad(mx.grad(fun))(mx.array(0.5))
|
||||
self.assertEqual(df2dx2.item(), 2.0)
|
||||
df3dx3 = mx.grad(mx.grad(mx.grad(fun)))(mx.array(0.5))
|
||||
self.assertEqual(df3dx3.item(), 0.0)
|
||||
|
||||
fun = lambda x, y: x * y
|
||||
x = mx.array(2.0)
|
||||
y = mx.array(3.0)
|
||||
dfdx = mx.grad(fun, argnums=0)(x, y)
|
||||
self.assertEqual(dfdx.item(), 3.0)
|
||||
dfdx = mx.grad(fun, argnums=1)(x, y)
|
||||
self.assertEqual(dfdx.item(), 2.0)
|
||||
|
||||
# Passing non-array args to functions works
|
||||
fun = lambda x, y: x
|
||||
value, dfdx = mx.value_and_grad(fun)(mx.array(2.0), "hello")
|
||||
self.assertEqual(value.item(), 2.0)
|
||||
self.assertEqual(dfdx.item(), 1.0)
|
||||
|
||||
dfdx = mx.grad(fun)(mx.array(2.0), "hello")
|
||||
self.assertEqual(dfdx.item(), 1.0)
|
||||
|
||||
# Raises when function does not return array
|
||||
fun = lambda x: "hello"
|
||||
with self.assertRaises(ValueError):
|
||||
mx.grad(fun)(mx.array(2.0))
|
||||
|
||||
# Raises for invalid argument number or argument type
|
||||
fun = lambda x: x
|
||||
with self.assertRaises(ValueError):
|
||||
mx.grad(fun, argnums=2)(mx.array(2.0))
|
||||
with self.assertRaises(ValueError):
|
||||
mx.grad(fun, argnums=-2)(mx.array(2.0))
|
||||
with self.assertRaises(ValueError):
|
||||
mx.grad(fun)("hello")
|
||||
|
||||
# Raises when output is not a scalar array
|
||||
fun = lambda x: mx.sum(x, keepdims=True)
|
||||
with self.assertRaises(ValueError):
|
||||
mx.grad(fun)(mx.ones((2, 2)))
|
||||
|
||||
def test_grad_trees(self):
|
||||
fun = lambda x, y: x * y
|
||||
value, dfdx = mx.value_and_grad(fun, (0, 1))(mx.array(0.5), mx.array(2.0))
|
||||
self.assertEqual(value.item(), 1.0)
|
||||
self.assertTrue(isinstance(dfdx, tuple))
|
||||
self.assertEqual(dfdx[0].item(), 2.0)
|
||||
self.assertEqual(dfdx[1].item(), 0.5)
|
||||
|
||||
fun = lambda x, y: x * y
|
||||
value, dfdx = mx.value_and_grad(fun, 1)(mx.array(0.5), mx.array(2.0))
|
||||
self.assertEqual(value.item(), 1.0)
|
||||
self.assertEqual(dfdx.item(), 0.5)
|
||||
|
||||
fun = lambda p: p["x"] * p["y"]
|
||||
value, dfdx = mx.value_and_grad(fun)({"x": mx.array(0.5), "y": mx.array(2.0)})
|
||||
self.assertEqual(value.item(), 1.0)
|
||||
self.assertEqual(dfdx["x"].item(), 2.0)
|
||||
self.assertEqual(dfdx["y"].item(), 0.5)
|
||||
|
||||
fun = lambda p: p["x"] * p["y"]
|
||||
with self.assertRaises(ValueError):
|
||||
mx.value_and_grad(fun)({"x": 0.5, "y": mx.array(2.0)})
|
||||
with self.assertRaises(ValueError):
|
||||
mx.value_and_grad(fun, (0, 1))({"x": mx.array(0.5), "y": mx.array(2.0)})
|
||||
|
||||
fun = lambda p, b: mx.square(p[0]["foo"][2]) * b
|
||||
value, dfdx = mx.value_and_grad(fun)(
|
||||
[{"foo": [[], [], mx.array(2.0)]}], mx.array(0.5)
|
||||
)
|
||||
self.assertEqual(value.item(), 2.0)
|
||||
self.assertEqual(dfdx[0]["foo"][2].item(), 2.0)
|
||||
|
||||
fun = lambda x: x
|
||||
with self.assertRaises(TypeError):
|
||||
mx.value_and_grad(fun, (None, None))
|
||||
with self.assertRaises(ValueError):
|
||||
mx.value_and_grad(fun, tuple())
|
||||
|
||||
def test_auxiliary_values(self):
|
||||
def fun(x, y):
|
||||
l = (x * y).sum()
|
||||
extra = {"loss": l, "foo": y.square() + x.square(), "bar": [1, 2, 3, y, x]}
|
||||
return l, extra
|
||||
|
||||
fun_value_grad = mx.value_and_grad(fun)
|
||||
fun_grad = mx.grad(fun)
|
||||
|
||||
(loss, a), b = fun_value_grad(mx.ones((2, 2)), mx.ones((2, 2)))
|
||||
self.assertEqual(a["loss"].item(), 4)
|
||||
self.assertTrue(mx.array_equal(b, mx.ones((2, 2))))
|
||||
self.assertTrue(mx.array_equal(a["foo"], 2 * mx.ones((2, 2))))
|
||||
self.assertEqual(a["bar"][:3], [1, 2, 3])
|
||||
self.assertTrue(mx.array_equal(a["bar"][3], mx.ones((2, 2))))
|
||||
self.assertTrue(mx.array_equal(a["bar"][4], mx.ones((2, 2))))
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
_ = fun_grad(mx.ones((2, 2)), mx.ones((2, 2)))
|
||||
|
||||
def test_grad_kwargs(self):
|
||||
fun = lambda x, y: x * y
|
||||
a, b = mx.array(0.5), mx.array(2.0)
|
||||
dfdx = mx.grad(fun)
|
||||
self.assertEqual(dfdx(a, b).item(), 2.0)
|
||||
self.assertEqual(dfdx(a, y=b).item(), 2.0)
|
||||
with self.assertRaises(ValueError):
|
||||
dfdx(x=a, y=b).item()
|
||||
|
        dfdy = mx.grad(fun, argnums=[], argnames=["y"])
        with self.assertRaises(ValueError):
            dfdy(a, b)
        grads = dfdy(a, y=b)
        self.assertTrue(isinstance(grads, tuple))
        self.assertTrue(grads[0] is None)
        self.assertTrue(isinstance(grads[1], dict))
        self.assertEqual(grads[1]["y"].item(), 0.5)
        grads = dfdy(x=a, y=b)
        self.assertEqual(grads[1]["y"].item(), 0.5)
        self.assertEqual(len(grads[1]), 1)

        dfdxy = mx.grad(fun, argnums=[0], argnames=["y"])
        with self.assertRaises(ValueError):
            dfdxy(a, b)
        with self.assertRaises(ValueError):
            dfdxy(x=a, y=b)
        grads = dfdxy(a, y=b)
        self.assertTrue(isinstance(grads, tuple))
        self.assertEqual(grads[0].item(), 2.0)
        self.assertTrue(isinstance(grads[1], dict))
        self.assertEqual(grads[1]["y"].item(), 0.5)

        fun = lambda x, y, z: x * y * z
        dfdxyz = mx.grad(fun, argnums=[0, 1], argnames=["z"])
        c = mx.array(4.0)
        grads = dfdxyz(a, b, z=c)
        self.assertTrue(isinstance(grads, tuple))
        self.assertTrue(isinstance(grads[0], tuple))
        self.assertEqual(grads[0][0].item(), 8.0)
        self.assertEqual(grads[0][1].item(), 2.0)
        self.assertTrue(isinstance(grads[1], dict))
        self.assertEqual(grads[1]["z"].item(), 1.0)

        fun = lambda x, y: x * y
        dfdy = mx.grad(fun, argnames=["y"])
        grads = dfdy(a, y=b)
        self.assertTrue(isinstance(grads, tuple))
        self.assertTrue(grads[0] is None)
        self.assertTrue(isinstance(grads[1], dict))
        self.assertEqual(grads[1]["y"].item(), 0.5)

    def test_captured(self):
        a = mx.array(5.0)
        f = lambda x: a + x
        g = lambda x: a + a
        h = lambda x: x + x

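        # The captured array `a` is treated as a constant of the trace,
        # so only the explicit argument x carries gradient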
        dfdx = mx.grad(f)
        self.assertEqual(dfdx(a).item(), 1.0)

        dgdx = mx.grad(g)
        self.assertEqual(dgdx(a).item(), 0.0)

        dhdx = mx.grad(h)
        self.assertEqual(dhdx(a).item(), 2.0)

        d2fdx2 = mx.grad(dfdx)
        self.assertEqual(d2fdx2(a).item(), 0.0)

        d2gdx2 = mx.grad(dgdx)
        self.assertEqual(d2gdx2(a).item(), 0.0)

        d2hdx2 = mx.grad(dhdx)
        self.assertEqual(d2hdx2(a).item(), 0.0)

    def test_stop_gradient(self):
        shape_in = (4, 4)
        w_in = mx.ones(shape_in)
        x_in = mx.ones(shape_in)
        cotan = mx.ones(shape_in)

        def h(w, x):
            x1 = 2 * x
            y = mx.stop_gradient(x1)
            y1 = 3 * y
            return w @ y1

        vals, vjps = mx.vjp(h, [w_in, x_in], [cotan])
        mx.eval(vjps)

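        # y1 is 6 * ones((4, 4)), so each entry of the w-cotangent contracts
        # to 6 * 4 = 24; stop_gradient cuts the path through x, hence zeros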
        self.assertTrue(mx.allclose(vjps[0], 24.0 * mx.ones(shape_in)))
        self.assertTrue(mx.allclose(vjps[1], mx.zeros(shape_in)))

        g = lambda x: h(w_in, x)
        vals, vjps = mx.vjp(g, [x_in], [cotan])
        mx.eval(vjps)

        self.assertTrue(mx.allclose(vjps[0], mx.zeros(shape_in)))


if __name__ == "__main__":
    unittest.main()
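
For readers skimming the diff, here is a minimal sketch of the `stop_gradient` behavior the last test pins down, outside the test harness. It assumes only the `mlx.core` calls already exercised in these tests; the shapes and constants are chosen to reproduce the expected 24.0 and 0.0:

```
import mlx.core as mx

w = mx.ones((4, 4))
x = mx.ones((4, 4))

def h(w, x):
    # stop_gradient severs the 2 * x branch from the graph
    return w @ (3 * mx.stop_gradient(2 * x))

# Differentiate the scalar sum with respect to both arguments
dw, dx = mx.grad(lambda w, x: h(w, x).sum(), argnums=(0, 1))(w, x)
print(dw)  # 24.0 everywhere: each entry of w multiplies a column summing to 6 * 4
print(dx)  # 0.0 everywhere: no gradient flows through stop_gradient
```
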
105
python/tests/test_device.py
Normal file
@@ -0,0 +1,105 @@
import unittest

import mlx.core as mx

import mlx_tests


# Don't inherit from MLXTestCase to avoid call to setUp
class TestDefaultDevice(unittest.TestCase):
    def test_mlx_default_device(self):
        device = mx.default_device()
        if mx.metal.is_available():
            self.assertEqual(device, mx.Device(mx.gpu))
            self.assertEqual(str(device), "Device(gpu, 0)")
            self.assertEqual(device, mx.gpu)
            self.assertEqual(mx.gpu, device)
        else:
            self.assertEqual(device, mx.Device(mx.cpu))
            with self.assertRaises(ValueError):
                mx.set_default_device(mx.gpu)


class TestDevice(mlx_tests.MLXTestCase):
    def test_device(self):
        device = mx.default_device()

        cpu = mx.Device(mx.cpu)
        mx.set_default_device(cpu)
        self.assertEqual(mx.default_device(), cpu)
        self.assertEqual(str(cpu), "Device(cpu, 0)")

        mx.set_default_device(mx.cpu)
        self.assertEqual(mx.default_device(), mx.cpu)
        self.assertEqual(cpu, mx.cpu)
        self.assertEqual(mx.cpu, cpu)

        # Restore device
        mx.set_default_device(device)

    def test_op_on_device(self):
        x = mx.array(1.0)
        y = mx.array(1.0)

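        # stream=None resolves to the default stream of the default device,
        # so all of these adds should agree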
        a = mx.add(x, y, stream=None)
        b = mx.add(x, y, stream=mx.default_device())
        self.assertEqual(a.item(), b.item())
        b = mx.add(x, y, stream=mx.cpu)
        self.assertEqual(a.item(), b.item())

        if mx.metal.is_available():
            b = mx.add(x, y, stream=mx.gpu)
            self.assertEqual(a.item(), b.item())


class TestStream(mlx_tests.MLXTestCase):
    def test_stream(self):
        s1 = mx.default_stream(mx.default_device())
        self.assertEqual(s1.device, mx.default_device())

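        # A freshly created stream lives on the same device but is
        # distinct from the default stream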
        s2 = mx.new_stream(mx.default_device())
        self.assertEqual(s2.device, mx.default_device())
        self.assertNotEqual(s1, s2)

        if mx.metal.is_available():
            s_gpu = mx.default_stream(mx.gpu)
            self.assertEqual(s_gpu.device, mx.gpu)
        else:
            with self.assertRaises(ValueError):
                mx.default_stream(mx.gpu)

        s_cpu = mx.default_stream(mx.cpu)
        self.assertEqual(s_cpu.device, mx.cpu)

        s_cpu = mx.new_stream(mx.cpu)
        self.assertEqual(s_cpu.device, mx.cpu)

        if mx.metal.is_available():
            s_gpu = mx.new_stream(mx.gpu)
            self.assertEqual(s_gpu.device, mx.gpu)
        else:
            with self.assertRaises(ValueError):
                mx.new_stream(mx.gpu)

    def test_op_on_stream(self):
        x = mx.array(1.0)
        y = mx.array(1.0)

        a = mx.add(x, y, stream=mx.default_stream(mx.default_device()))

        if mx.metal.is_available():
            b = mx.add(x, y, stream=mx.default_stream(mx.gpu))
            self.assertEqual(a.item(), b.item())
            s_gpu = mx.new_stream(mx.gpu)
            b = mx.add(x, y, stream=s_gpu)
            self.assertEqual(a.item(), b.item())

        b = mx.add(x, y, stream=mx.default_stream(mx.cpu))
        self.assertEqual(a.item(), b.item())
        s_cpu = mx.new_stream(mx.cpu)
        b = mx.add(x, y, stream=s_cpu)
        self.assertEqual(a.item(), b.item())


if __name__ == "__main__":
    unittest.main()
34
python/tests/test_eval.py
Normal file
@@ -0,0 +1,34 @@
from functools import partial

import unittest

import mlx.core as mx

import mlx_tests


class TestEval(mlx_tests.MLXTestCase):
    def test_eval(self):
        arrs = [mx.ones((2, 2)) for _ in range(4)]
        mx.eval(*arrs)
        for x in arrs:
            self.assertEqual(x.tolist(), [[1, 1], [1, 1]])

    def test_retain_graph(self):
        def fun(x, retain_graph):
            y = 3 * x
            mx.eval(y, retain_graph=retain_graph)
            return 2 * y

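        # Evaluating inside a function being differentiated drops the traced
        # graph unless retain_graph=True, so the first gradient must fail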
        dfun_dx_1 = mx.grad(partial(fun, retain_graph=False))
        dfun_dx_2 = mx.grad(partial(fun, retain_graph=True))

        with self.assertRaises(ValueError):
            dfun_dx_1(mx.array(1.0))

        y = dfun_dx_2(mx.array(1.0))
        self.assertEqual(y.item(), 6.0)


if __name__ == "__main__":
    unittest.main()
90
python/tests/test_fft.py
Normal file
@@ -0,0 +1,90 @@
import unittest

import itertools
import mlx.core as mx
import numpy as np

import mlx_tests


class TestFFT(mlx_tests.MLXTestCase):
    def check_mx_np(self, op, a_np, axes, s):
        with self.subTest(op=op, axes=axes, s=s):
            op_np = getattr(np.fft, op)
            op_mx = getattr(mx.fft, op)
            out_np = op_np(a_np, s=s, axes=axes)
            a_mx = mx.array(a_np)
            out_mx = op_mx(a_mx, s=s, axes=axes)
            self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6))

    def test_fft(self):
        default = mx.default_device()
        mx.set_default_device(mx.cpu)

        def check_mx_np(op_mx, op_np, a_np, **kwargs):
            out_np = op_np(a_np, **kwargs)
            a_mx = mx.array(a_np)
            out_mx = op_mx(a_mx, **kwargs)
            self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6))

        r = np.random.rand(100).astype(np.float32)
        i = np.random.rand(100).astype(np.float32)
        a_np = r + 1j * i
        check_mx_np(mx.fft.fft, np.fft.fft, a_np)

        # Check with slicing and padding
        r = np.random.rand(100).astype(np.float32)
        i = np.random.rand(100).astype(np.float32)
        a_np = r + 1j * i
        check_mx_np(mx.fft.fft, np.fft.fft, a_np, n=80)
        check_mx_np(mx.fft.fft, np.fft.fft, a_np, n=120)

        # Check different axes
        r = np.random.rand(100, 100).astype(np.float32)
        i = np.random.rand(100, 100).astype(np.float32)
        a_np = r + 1j * i
        check_mx_np(mx.fft.fft, np.fft.fft, a_np, axis=0)
        check_mx_np(mx.fft.fft, np.fft.fft, a_np, axis=1)

        # Check real fft
        a_np = np.random.rand(100).astype(np.float32)
        check_mx_np(mx.fft.rfft, np.fft.rfft, a_np)
        check_mx_np(mx.fft.rfft, np.fft.rfft, a_np, n=80)
        check_mx_np(mx.fft.rfft, np.fft.rfft, a_np, n=120)

        # Check real inverse
        r = np.random.rand(100, 100).astype(np.float32)
        i = np.random.rand(100, 100).astype(np.float32)
        a_np = r + 1j * i
        check_mx_np(mx.fft.ifft, np.fft.ifft, a_np)
        check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, n=80)
        check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, n=120)
        check_mx_np(mx.fft.irfft, np.fft.irfft, a_np)
        check_mx_np(mx.fft.irfft, np.fft.irfft, a_np, n=80)
        check_mx_np(mx.fft.irfft, np.fft.irfft, a_np, n=120)

        mx.set_default_device(default)

    def test_fftn(self):
        default = mx.default_device()
        mx.set_default_device(mx.cpu)

        r = np.random.randn(8, 8, 8).astype(np.float32)
        i = np.random.randn(8, 8, 8).astype(np.float32)
        a = r + 1j * i

        axes = [None, (1, 2), (2, 1), (0, 2)]
        shapes = [None, (10, 5), (5, 10)]
        ops = ["fft2", "ifft2", "rfft2", "irfft2", "fftn", "ifftn", "rfftn", "irfftn"]

        for op, ax, s in itertools.product(ops, axes, shapes):
            x = a
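            # The real-input transforms need the real array, not the complex one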
            if op in ["rfft2", "rfftn"]:
                x = r
            self.check_mx_np(op, x, axes=ax, s=s)

        mx.set_default_device(default)


if __name__ == "__main__":
    unittest.main()
1283
python/tests/test_ops.py
Normal file
File diff suppressed because it is too large

118
python/tests/test_reduce.py
Normal file
@@ -0,0 +1,118 @@
import unittest
from itertools import permutations, combinations

import mlx.core as mx
import numpy as np

import mlx_tests


class TestReduce(mlx_tests.MLXTestCase):
    def test_axis_permutation_sums(self):
        x_npy = np.random.randn(5, 5, 5, 5, 5).astype(np.float32)
        x_mlx = mx.array(x_npy)
        for t in permutations(range(5)):
            with self.subTest(t=t):
                y_npy = np.transpose(x_npy, t)
                y_mlx = mx.transpose(x_mlx, t)
                for n in range(1, 6):
                    for a in combinations(range(5), n):
                        with self.subTest(a=a):
                            z_npy = np.sum(y_npy, axis=a)
                            z_mlx = mx.sum(y_mlx, axis=a)
                            mx.eval(z_mlx)
                            self.assertTrue(
                                np.allclose(z_npy, np.array(z_mlx), atol=1e-4)
                            )

    def test_expand_sums(self):
        x_npy = np.random.randn(5, 1, 5, 1, 5, 1).astype(np.float32)
        x_mlx = mx.array(x_npy)
        for m in range(1, 4):
            for ax in combinations([1, 3, 5], m):
                shape = np.array([5, 1, 5, 1, 5, 1])
                shape[list(ax)] = 5
                shape = shape.tolist()
                with self.subTest(shape=shape):
                    y_npy = np.broadcast_to(x_npy, shape)
                    y_mlx = mx.broadcast_to(x_mlx, shape)
                    for n in range(1, 7):
                        for a in combinations(range(6), n):
                            with self.subTest(a=a):
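                                # Scale down so the float32 sums stay within
                                # the absolute comparison tolerance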
                                z_npy = np.sum(y_npy, axis=a) / 1000
                                z_mlx = mx.sum(y_mlx, axis=a) / 1000
                                mx.eval(z_mlx)
                                self.assertTrue(
                                    np.allclose(z_npy, np.array(z_mlx), atol=1e-4)
                                )

    def test_dtypes(self):
        int_dtypes = [
            "int8",
            "int16",
            "int32",
            "uint8",
            "uint16",
            "uint32",
        ]
        float_dtypes = ["float32"]

        for dtype in int_dtypes + float_dtypes:
            with self.subTest(dtype=dtype):
                x = np.random.uniform(0, 2, size=(3, 3, 3)).astype(getattr(np, dtype))
                y = mx.array(x)

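                # For sum/prod, pin numpy's accumulator to the input dtype so
                # its result matches what mlx returns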
                for op in ("sum", "prod", "min", "max"):
                    with self.subTest(op=op):
                        np_op = getattr(np, op)
                        mlx_op = getattr(mx, op)

                        for axes in (None, 0, 1, 2, (0, 1), (0, 2), (1, 2), (0, 1, 2)):
                            with self.subTest(axes=axes):
                                if op in ("sum", "prod"):
                                    r_np = np_op(
                                        x, axis=axes, dtype=getattr(np, dtype)
                                    )
                                else:
                                    r_np = np_op(x, axis=axes)
                                r_mlx = mlx_op(y, axis=axes)
                                mx.eval(r_mlx)
                                self.assertTrue(np.allclose(r_np, r_mlx, atol=1e-4))

    def test_arg_reduce(self):
        dtypes = [
            "uint8",
            "uint16",
            "uint32",
            "uint64",
            "int8",
            "int16",
            "int32",
            "int64",
            "float16",
            "float32",
        ]
        for dtype in dtypes:
            with self.subTest(dtype=dtype):
                data = np.random.rand(10, 12, 13).astype(getattr(np, dtype))
                x = mx.array(data)
for op in ["argmin", "argmax"]:
|
||||
for axis in range(3):
|
||||
for kd in [True, False]:
|
||||
a = getattr(mx, op)(x, axis, kd)
|
||||
b = getattr(np, op)(data, axis, keepdims=kd)
|
||||
self.assertEqual(a.tolist(), b.tolist())
|
||||
|
||||
for op in ["argmin", "argmax"]:
|
||||
a = getattr(mx, op)(x, keepdims=True)
|
||||
b = getattr(np, op)(data, keepdims=True)
|
||||
self.assertEqual(a.tolist(), b.tolist())
|
||||
a = getattr(mx, op)(x)
|
||||
b = getattr(np, op)(data)
|
||||
self.assertEqual(a.item(), b)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(failfast=True)
|
||||