Mirror of https://github.com/ml-explore/mlx.git (synced 2025-10-31 07:58:14 +08:00)

awni's commit files
python/README.md  (new file, +37 lines)
@@ -0,0 +1,37 @@
### Packaging for PyPI

Install `build` and `twine`:

```
pip install --user --upgrade build
pip install --user --upgrade twine
```

Generate the source distribution and wheel:

```
python -m build
```

*Warning:* use a test server first.

#### Test Upload

Upload to the test server:

```
python -m twine upload --repository testpypi dist/*
```

Install from the test server and check that it works:

```
python -m pip install --index-url https://test.pypi.org/simple/ --no-deps mlx
```
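Optionally run a quick smoke test on the installed package (a minimal sketch; the exact ops used are only assumed to be available):

```
python -c "import mlx.core as mx; print(mx.sum(mx.ones((2, 2))))"
```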
#### Upload

```
python -m twine upload dist/*
```
python/mlx/_reprlib_fix.py  (new file, +18 lines)
@@ -0,0 +1,18 @@
import array
import reprlib


class FixedRepr(reprlib.Repr):
    """Only route Python array.array instances to repr_array."""

    def repr_array(self, x, maxlevel):
        if isinstance(x, array.array):
            return super().repr_array(x, maxlevel)
        else:
            return self.repr_instance(x, maxlevel)


# We need to monkey-patch reprlib so that we can use the debugger without
# renaming the array to something else.
fixed_repr = FixedRepr()
reprlib.repr = fixed_repr.repr
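Why this is needed: ``reprlib.Repr`` dispatches on the *type name*, so any object whose class is called ``array`` (including ``mlx.core.array``) gets routed to ``repr_array``, which assumes ``array.array``. A small sketch of the behavior with the fix installed (assuming an importable MLX build):

```python
import reprlib

import mlx.core as mx
import mlx._reprlib_fix  # noqa: F401 -- installs the patched reprlib.repr

# With the fix, objects that are not array.array fall back to repr_instance
# instead of crashing inside Repr.repr_array.
print(reprlib.repr(mx.array([1.0, 2.0, 3.0])))
```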
python/mlx/extension.py  (new file, +94 lines)
@@ -0,0 +1,94 @@
import os
import re
import subprocess
import sys
from pathlib import Path

from setuptools import Extension, setup, find_namespace_packages
from setuptools.command.build_ext import build_ext

import mlx

_MLX_PATH = str(mlx.__path__[0])


# A CMakeExtension needs a sourcedir instead of a file list.
class CMakeExtension(Extension):
    def __init__(self, name: str, sourcedir: str = "") -> None:
        super().__init__(name, sources=[])
        self.sourcedir = os.fspath(Path(sourcedir).resolve())


class CMakeBuild(build_ext):
    def build_extension(self, ext: CMakeExtension) -> None:
        # Must be in this form due to a bug in .resolve() only fixed in Python 3.10+
        ext_fullpath = Path.cwd() / self.get_ext_fullpath(ext.name)  # type: ignore[no-untyped-call]
        extdir = ext_fullpath.parent.resolve()

        debug = int(os.environ.get("DEBUG", 0)) if self.debug is None else self.debug
        cfg = "Debug" if debug else "Release"

        # CMake lets you override the generator - we need to check this.
        # Can be set with Conda-Build, for example.
        cmake_generator = os.environ.get("CMAKE_GENERATOR", "")

        # Set Python_EXECUTABLE instead if you use PYBIND11_FINDPYTHON.
        cmake_args = [
            f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={extdir}{os.sep}",
            f"-DCMAKE_BUILD_TYPE={cfg}",
            "-DBUILD_SHARED_LIBS=ON",
        ]
        build_args = []
        # Add CMake arguments set as an environment variable
        # (needed e.g. to build for ARM macOS on conda-forge).
        if "CMAKE_ARGS" in os.environ:
            cmake_args += [item for item in os.environ["CMAKE_ARGS"].split(" ") if item]

        if sys.platform.startswith("darwin"):
            # Cross-compile support for macOS - respect ARCHFLAGS if set
            archs = re.findall(r"-arch (\S+)", os.environ.get("ARCHFLAGS", ""))
            if archs:
                cmake_args += ["-DCMAKE_OSX_ARCHITECTURES={}".format(";".join(archs))]

        # Set CMAKE_BUILD_PARALLEL_LEVEL to control the parallel build level
        # across all generators.
        if "CMAKE_BUILD_PARALLEL_LEVEL" not in os.environ:
            # self.parallel is a Python 3 only way to set parallel jobs by hand
            # using -j in the build_ext call, not supported by pip or PyPA-build.
            if hasattr(self, "parallel") and self.parallel:
                # CMake 3.12+ only.
                build_args += [f"-j{self.parallel}"]

        build_temp = Path(self.build_temp) / ext.name
        if not build_temp.exists():
            build_temp.mkdir(parents=True)

        # Make sure cmake can find MLX
        os.environ["MLX_DIR"] = _MLX_PATH

        subprocess.run(
            ["cmake", ext.sourcedir, *cmake_args], cwd=build_temp, check=True
        )
        subprocess.run(
            ["cmake", "--build", ".", *build_args], cwd=build_temp, check=True
        )

    def run(self):
        super().run()

        # Based on https://github.com/pypa/setuptools/blob/main/setuptools/command/build_ext.py#L102
        if self.inplace:
            for ext in self.extensions:
                if isinstance(ext, CMakeExtension):
                    # Resolve the in-place package dir
                    build_py = self.get_finalized_command("build_py")
                    inplace_file, regular_file = self._get_inplace_equivalent(
                        build_py, ext
                    )

                    inplace_dir = str(Path(inplace_file).parent.resolve())
                    regular_dir = str(Path(regular_file).parent.resolve())

                    self.copy_tree(regular_dir, inplace_dir)
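A minimal sketch of how these helpers might be wired into a downstream extension's ``setup.py``. The package name ``my_mlx_ext`` and the source directory are hypothetical:

```python
from setuptools import setup

from mlx.extension import CMakeExtension, CMakeBuild

setup(
    name="my_mlx_ext",  # hypothetical package
    version="0.0.1",
    packages=["my_mlx_ext"],
    # The CMake project in "." is expected to produce the extension target.
    ext_modules=[CMakeExtension("my_mlx_ext.core", sourcedir=".")],
    cmdclass={"build_ext": CMakeBuild},
)
```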
python/mlx/nn/layers/base.py  (new file, +401 lines)
@@ -0,0 +1,401 @@
import textwrap
from typing import Any, Callable, List, Union, Optional

import mlx.core as mx
from mlx.utils import tree_flatten, tree_unflatten


class Module(dict):
    """Base class for building neural networks with MLX.

    All the layers provided in :mod:`mlx.nn.layers` subclass this class and
    your models should do the same.

    A ``Module`` can contain other ``Module`` instances or :class:`mlx.core.array`
    instances in arbitrary nesting of python lists or dicts. The ``Module``
    then allows recursively extracting all the :class:`mlx.core.array` instances
    using :meth:`mlx.nn.Module.parameters`.

    In addition, the ``Module`` has the concept of trainable and non-trainable
    parameters (called "frozen"). When using :func:`mlx.nn.value_and_grad`
    the gradients are returned only with respect to the trainable parameters.
    All arrays in a module are trainable unless they are added to the "frozen"
    set by calling :meth:`freeze`.

    .. code-block:: python

        import mlx.core as mx
        import mlx.nn as nn

        class MyMLP(nn.Module):
            def __init__(self, in_dims: int, out_dims: int, hidden_dims: int = 16):
                super().__init__()

                self.in_proj = nn.Linear(in_dims, hidden_dims)
                self.out_proj = nn.Linear(hidden_dims, out_dims)

            def __call__(self, x):
                x = self.in_proj(x)
                x = mx.maximum(x, 0)
                return self.out_proj(x)

        model = MyMLP(2, 1)

        # All the model parameters are created but since MLX is lazy by
        # default, they are not evaluated yet. Calling `mx.eval` actually
        # allocates memory and initializes the parameters.
        mx.eval(model.parameters())

        # Setting a parameter to a new value is as simple as accessing that
        # parameter and assigning a new array to it.
        model.in_proj.weight = model.in_proj.weight * 2
        mx.eval(model.parameters())
    """

    def __init__(self):
        """Should be called by the subclasses of ``Module``."""
        self._no_grad = set()
        self._training = True

    @property
    def training(self):
        return self._training

    def _extra_repr(self):
        return ""

    def __repr__(self):
        children = tree_flatten(self.children(), is_leaf=self.is_module)
        value = f"{type(self).__name__}({self._extra_repr()}"
        for k, v in children:
            value += "\n"
            value += textwrap.indent(f"({k}): {repr(v)}", prefix="  ")
        if children:
            value += "\n"
        value += ")"

        return value

    def __getattr__(self, key: str):
        if key in self:
            return self[key]
        else:
            raise AttributeError(f"{type(self)!r} has no attribute {key!r}")

    def __setattr__(self, key: str, val: Any):
        self[key] = val

    def load_weights(self, file: str):
        """Load and update the model's weights from a ``.npz`` file."""
        self.update(tree_unflatten(list(mx.load(file).items())))

    def save_weights(self, file: str):
        """Save the model's weights to a ``.npz`` file."""
        mx.savez(file, **dict(tree_flatten(self.parameters())))

    @staticmethod
    def is_module(value):
        return isinstance(value, Module)

    @staticmethod
    def valid_child_filter(module, key, value):
        return isinstance(value, (dict, list))

    @staticmethod
    def valid_parameter_filter(module, key, value):
        return isinstance(value, (dict, list, mx.array)) and not key.startswith("_")

    @staticmethod
    def trainable_parameter_filter(module, key, value):
        return (
            Module.valid_parameter_filter(module, key, value)
            and key not in module._no_grad
        )

    def filter_and_map(
        self,
        filter_fn: Callable[["mlx.nn.Module", str, Any], bool],
        map_fn: Optional[Callable] = None,
        is_leaf_fn: Optional[Callable[["mlx.nn.Module", str, Any], bool]] = None,
    ):
        """Recursively filter the contents of the module using ``filter_fn``,
        namely only select keys and values where ``filter_fn`` returns true.

        This is used to implement :meth:`parameters` and :meth:`trainable_parameters`
        but it can also be used to extract any subset of the module's parameters.

        Args:
            filter_fn (Callable): Given a value, the key in which it is found
                and the containing module, decide whether to keep the value or
                drop it.
            map_fn (Callable, optional): Optionally transform the value before
                returning it.
            is_leaf_fn (Callable, optional): Given a value, the key in which it
                is found and the containing module, decide if it is a leaf.

        Returns:
            A dictionary containing the contents of the module recursively filtered.
        """

        map_fn = map_fn or (lambda x: x)
        is_leaf_fn = is_leaf_fn or (
            lambda m, k, v: not isinstance(v, (Module, dict, list))
        )

        def unwrap(vk, v):
            if is_leaf_fn(self, vk, v):
                return map_fn(v)

            if isinstance(v, Module):
                return v.filter_and_map(filter_fn, map_fn, is_leaf_fn)

            if isinstance(v, dict):
                nd = {}
                for k, vv in v.items():
                    tk = f"{vk}.{k}"
                    nd[k] = unwrap(tk, vv) if filter_fn(self, tk, vv) else {}
                return nd

            if isinstance(v, list):
                nl = []
                for i, vi in enumerate(v):
                    tk = f"{vk}.{i}"
                    nl.append(unwrap(tk, vi) if filter_fn(self, tk, vi) else {})
                return nl

            raise RuntimeError("Unexpected leaf found while traversing the module")

        return {k: unwrap(k, v) for k, v in self.items() if filter_fn(self, k, v)}
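For instance, ``filter_and_map`` can extract an arbitrary subtree of parameters. A small sketch (the ``Sequential`` and ``Linear`` layers are assumed from elsewhere in the package) that collects only the biases:

```python
import mlx.nn as nn

model = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 1))

# Keep containers so the recursion can descend, and keep leaves named "bias".
biases = model.filter_and_map(
    lambda m, k, v: isinstance(v, (dict, list)) or k.endswith("bias")
)
# -> {"layers": [{"bias": array(...)}, {"bias": array(...)}]}
```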

    def parameters(self):
        """Recursively return all the :class:`mlx.core.array` members of this Module
        as a dict of dicts and lists."""
        return self.filter_and_map(self.valid_parameter_filter)

    def trainable_parameters(self):
        """Recursively return all the non-frozen :class:`mlx.core.array` members of
        this Module as a dict of dicts and lists."""
        return self.filter_and_map(self.trainable_parameter_filter)

    def children(self):
        """Return the direct descendants of this Module instance."""
        return self.filter_and_map(
            self.valid_child_filter, is_leaf_fn=lambda m, k, v: isinstance(v, Module)
        )

    def leaf_modules(self):
        """Return the submodules that do not contain other modules."""

        def _is_leaf_module(m, k, v):
            return isinstance(v, Module) and len(tree_flatten(v.children())) == 0

        return self.filter_and_map(self.valid_child_filter, is_leaf_fn=_is_leaf_module)

    def update(self, parameters: dict):
        """Replace the parameters of this Module with the provided ones in the
        dict of dicts and lists.

        Commonly used by the optimizer to change the model to the updated
        (optimized) parameters. Also used by :func:`mlx.nn.value_and_grad` to set the
        tracers in the model in order to compute gradients.

        The passed in parameters dictionary need not be a full dictionary
        similar to :meth:`parameters`. Only the provided locations will be
        updated.

        Args:
            parameters (dict): A complete or partial dictionary of the module's
                               parameters.
        """

        def apply(dst, parameters):
            if isinstance(parameters, dict):
                for k in parameters:
                    if k in dst:
                        current_value = dst[k]
                        new_value = parameters[k]
                        if isinstance(current_value, mx.array):
                            dst[k] = new_value
                        elif isinstance(current_value, Module):
                            current_value.update(new_value)
                        elif isinstance(current_value, (dict, list)):
                            apply(current_value, new_value)
            elif isinstance(parameters, list):
                for i in range(len(dst)):
                    current_value = dst[i]
                    new_value = parameters[i]
                    if isinstance(current_value, mx.array):
                        dst[i] = new_value
                    elif isinstance(current_value, Module):
                        current_value.update(new_value)
                    elif isinstance(current_value, (dict, list)):
                        apply(current_value, new_value)

        apply(self, parameters)
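A sketch of a partial update (hypothetical shapes). Note that when a list is updated, the provided list is indexed in lockstep with the destination, so empty dicts can serve as "leave unchanged" placeholders:

```python
import mlx.core as mx
import mlx.nn as nn

model = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 1))

# Overwrite only the first layer's bias; the second entry is left untouched.
model.update({"layers": [{"bias": mx.zeros((8,))}, {}]})
```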

    def apply(
        self,
        map_fn: Callable[[mx.array], mx.array],
        filter_fn: Optional[Callable[["mlx.nn.Module", str, Any], bool]] = None,
    ):
        """Map all the parameters using the provided ``map_fn`` and immediately
        update the module with the mapped parameters.

        For instance running ``model.apply(lambda x: x.astype(mx.float16))``
        casts all parameters to 16 bit floats.

        Args:
            map_fn (Callable): Maps an array to another array
            filter_fn (Callable, optional): Filter to select which arrays to
                map (default: :meth:`Module.valid_parameter_filter`).
        """
        filter_fn = filter_fn or Module.valid_parameter_filter
        self.update(self.filter_and_map(filter_fn, map_fn))

    def apply_to_modules(self, apply_fn: Callable[[str, "mlx.nn.Module"], Any]):
        """Apply a function to all the modules in this instance (including this
        instance).

        Args:
            apply_fn (Callable): The function to apply to the modules.
        """
        module_stack = [("", self)]
        while module_stack:
            prefix, mod = module_stack.pop()
            apply_fn(prefix, mod)
            prefix = "." + prefix if prefix else ""
            module_stack.extend(
                tree_flatten(mod.children(), prefix=prefix, is_leaf=self.is_module)
            )

    def modules(self):
        """Return a list with all the modules in this instance.

        Returns:
            A list of :class:`mlx.nn.Module` instances.
        """
        modulelist = []
        self.apply_to_modules(lambda k, m: modulelist.append(m))
        return modulelist

    def named_modules(self):
        """Return a list with all the modules in this instance and their name
        with dot notation.

        Returns:
            A list of tuples (str, :class:`mlx.nn.Module`).
        """
        modulelist = []
        self.apply_to_modules(lambda k, m: modulelist.append((k, m)))
        return modulelist

    def _validate_keys(self, keys, strict):
        keys = keys if isinstance(keys, list) else [keys]
        if strict:
            for k in keys:
                if k not in self:
                    raise KeyError(f"Module doesn't contain member {k}.")
        return keys
    def freeze(
        self,
        *,
        recurse: bool = True,
        keys: Optional[Union[str, List[str]]] = None,
        strict: bool = False,
    ):
        """Freeze the Module's parameters or some of them. Freezing a parameter
        means not computing gradients for it.

        This function is idempotent, i.e. freezing a frozen model is a no-op.

        For instance to only train the attention parameters from a transformer:

            model = ...
            model.freeze()
            model.apply_to_modules(lambda k, v: v.unfreeze() if k.endswith("attention") else None)

        Args:
            recurse (bool, optional): If True then freeze the parameters of the
                submodules as well (default: True).
            keys (str or list[str], optional): If provided then only these
                parameters will be frozen, otherwise all the parameters of a
                module. For instance freeze all biases by calling
                ``module.freeze(keys="bias")``.
            strict (bool, optional): If set to True validate that the passed keys exist
                (default: False).
        """

        def _freeze_impl(_, m):
            local_keys = keys
            if local_keys is None:
                local_keys = tree_flatten(
                    m.filter_and_map(
                        lambda m, k, v: (not isinstance(v, Module))
                        and m.valid_parameter_filter(m, k, v)
                    )
                )
                local_keys = [k for (k, v) in local_keys]

            local_keys = m._validate_keys(local_keys, strict)
            m._no_grad.update(local_keys)

        if recurse:
            self.apply_to_modules(_freeze_impl)
        else:
            _freeze_impl("", self)

    def unfreeze(
        self,
        *,
        recurse: bool = True,
        keys: Optional[Union[str, List[str]]] = None,
        strict: bool = False,
    ):
        """Unfreeze the Module's parameters or some of them.

        This function is idempotent, i.e. unfreezing a model that is not frozen
        is a no-op.

        For instance to only train the biases one can do:

            model = ...
            model.freeze()
            model.unfreeze(keys="bias")

        Args:
            recurse (bool, optional): If True then unfreeze the parameters of the
                submodules as well (default: True).
            keys (str or list[str], optional): If provided then only these
                parameters will be unfrozen, otherwise all the parameters of a
                module. For instance unfreeze all biases by calling
                ``module.unfreeze(keys="bias")``.
            strict (bool, optional): If set to True validate that the passed keys exist
                (default: False).
        """

        def _unfreeze_impl(_, m):
            if keys is None:
                m._no_grad.clear()
            else:
                local_keys = m._validate_keys(keys, strict)
                m._no_grad.difference_update(local_keys)

        if recurse:
            self.apply_to_modules(_unfreeze_impl)
        else:
            _unfreeze_impl("", self)

    def train(self, mode: bool = True):
        """Set the training mode of this module and all its submodules."""

        def _set_train(_, m):
            m._training = mode

        self.apply_to_modules(_set_train)

    def eval(self):
        """Put the module in evaluation mode, i.e. ``train(False)``."""
        self.train(False)
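Putting the freezing API together, a sketch of a bias-only fine-tuning setup (layer names and shapes are hypothetical):

```python
import mlx.nn as nn
from mlx.utils import tree_flatten

model = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 1))
model.freeze()               # nothing receives gradients...
model.unfreeze(keys="bias")  # ...except the biases

# Only the biases remain trainable.
print([k for k, _ in tree_flatten(model.trainable_parameters())])
# e.g. ['layers.0.bias', 'layers.1.bias']
```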
python/mlx/nn/layers/containers.py  (new file, +22 lines)
@@ -0,0 +1,22 @@
from mlx.nn.layers.base import Module


class Sequential(Module):
    """A layer that calls the passed callables in order.

    We can pass either modules or plain callables to the Sequential module. If
    our functions have learnable parameters they should be implemented as
    ``nn.Module`` instances.

    Args:
        modules (tuple of Callables): The modules to call in order.
    """

    def __init__(self, *modules):
        super().__init__()
        self.layers = list(modules)

    def __call__(self, x):
        for m in self.layers:
            x = m(x)
        return x
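A quick sketch of mixing modules and plain functions in a ``Sequential`` (``Linear`` is assumed from the package):

```python
import mlx.core as mx
import mlx.nn as nn

mlp = nn.Sequential(
    nn.Linear(2, 16),
    lambda x: mx.maximum(x, 0),  # a plain callable works as a layer
    nn.Linear(16, 1),
)
y = mlp(mx.zeros((8, 2)))  # shape (8, 1)
```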
python/mlx/nn/layers/dropout.py  (new file, +33 lines)
@@ -0,0 +1,33 @@
import mlx.core as mx
from mlx.nn.layers.base import Module


class Dropout(Module):
    r"""Randomly zero a portion of the elements during training.

    The remaining elements are multiplied with :math:`\frac{1}{1-p}` where
    :math:`p` is the probability of zeroing an element. This is done so the
    expected value of a given element will remain the same.

    Args:
        p (float): The probability to zero an element
    """

    def __init__(self, p: float = 0.5):
        super().__init__()

        if p < 0 or p >= 1:
            raise ValueError("The dropout probability should be in [0, 1)")

        self._p_1 = 1 - p

    def _extra_repr(self):
        return f"p={1-self._p_1}"

    def __call__(self, x):
        if self._p_1 == 1 or not self.training:
            return x

        mask = mx.random.bernoulli(self._p_1, x.shape)

        return (1 / self._p_1) * mask.astype(x.dtype) * x
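A sketch of the train/eval behavior (assuming ``Dropout`` is exported from ``mlx.nn``):

```python
import mlx.core as mx
import mlx.nn as nn

drop = nn.Dropout(p=0.5)
x = mx.ones((4, 4))

y_train = drop(x)  # about half the entries zeroed, survivors scaled by 1/(1-p) = 2
drop.eval()
y_eval = drop(x)   # identity: dropout is disabled outside training
```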
python/mlx/nn/layers/normalization.py  (new file, +178 lines)
@@ -0,0 +1,178 @@
import mlx.core as mx
from mlx.nn.layers.base import Module


class LayerNorm(Module):
    r"""Applies layer normalization [1] on the inputs.

    Computes

    .. math::

        y = \frac{x - E[x]}{\sqrt{Var[x] + \epsilon}} \gamma + \beta,

    where :math:`\gamma` and :math:`\beta` are learned per feature dimension
    parameters initialized at 1 and 0 respectively.

    [1]: https://arxiv.org/abs/1607.06450

    Args:
        dims (int): The feature dimension of the input to normalize over
        eps (float): A small additive constant for numerical stability
        affine (bool): If True learn an affine transform to apply after the
            normalization
    """

    def __init__(self, dims: int, eps: float = 1e-5, affine: bool = True):
        super().__init__()
        if affine:
            self.bias = mx.zeros((dims,))
            self.weight = mx.ones((dims,))
        self.eps = eps
        self.dims = dims

    def _extra_repr(self):
        return f"{self.dims}, eps={self.eps}, affine={'weight' in self}"

    def __call__(self, x):
        means = mx.mean(x, axis=-1, keepdims=True)
        var = mx.var(x, axis=-1, keepdims=True)
        x = (x - means) * mx.rsqrt(var + self.eps)
        return (self.weight * x + self.bias) if "weight" in self else x


class RMSNorm(Module):
    r"""Applies Root Mean Square normalization [1] to the inputs.

    Computes

    ..  math::

        y = \frac{x}{\sqrt{E[x^2] + \epsilon}} \gamma

    where :math:`\gamma` is a learned per feature dimension parameter
    initialized at 1.

    [1]: https://arxiv.org/abs/1910.07467

    Args:
        dims (int): The feature dimension of the input to normalize over
        eps (float): A small additive constant for numerical stability
    """

    def __init__(self, dims: int, eps: float = 1e-5):
        super().__init__()
        self.weight = mx.ones((dims,))
        self.eps = eps

    def _extra_repr(self):
        return f"{self.weight.shape[0]}, eps={self.eps}"

    def __call__(self, x):
        # S is 1/sqrt(N) where N is the size of the features of x and is used
        # to compute a numerically more stable RMS of x by multiplying with S
        # first and summing.
        #
        # This way we prefer underflow over overflow which is controlled with
        # the parameter epsilon anyway.
        S = 1 / x.shape[-1] ** 0.5

        n = (x * S).square().sum(axis=-1, keepdims=True)
        n = mx.rsqrt(n + self.eps)

        return self.weight * x * n


class GroupNorm(Module):
    r"""Applies Group Normalization [1] to the inputs.

    Computes the same normalization as layer norm, namely

    .. math::

        y = \frac{x - E[x]}{\sqrt{Var[x] + \epsilon}} \gamma + \beta,

    where :math:`\gamma` and :math:`\beta` are learned per feature dimension
    parameters initialized at 1 and 0 respectively. However, the mean and
    variance are computed over the spatial dimensions and each group of
    features. In particular, the input is split into ``num_groups`` groups
    across the feature dimension.

    The feature dimension is assumed to be the last dimension and the dimensions
    that precede it (except the first) are considered the spatial dimensions.

    [1]: https://arxiv.org/abs/1803.08494

    Args:
        num_groups (int): Number of groups to separate the features into
        dims (int): The feature dimension of the input to normalize over
        eps (float): A small additive constant for numerical stability
        affine (bool): If True learn an affine transform to apply after the
            normalization.
        pytorch_compatible (bool): If True perform the group normalization in
            the same order/grouping as PyTorch.
    """

    def __init__(
        self,
        num_groups: int,
        dims: int,
        eps: float = 1e-5,
        affine: bool = True,
        pytorch_compatible: bool = False,
    ):
        super().__init__()
        if affine:
            self.bias = mx.zeros((dims,))
            self.weight = mx.ones((dims,))
        self.num_groups = num_groups
        self.dims = dims
        self.eps = eps
        self.pytorch_compatible = pytorch_compatible

    def _extra_repr(self):
        return (
            f"{self.num_groups}, {self.dims}, eps={self.eps}, "
            f"affine={'weight' in self}, pytorch_compatible={self.pytorch_compatible}"
        )

    def _pytorch_compatible_group_norm(self, x):
        num_groups = self.num_groups
        batch, *rest, dims = x.shape

        # Split into groups
        x = x.reshape(batch, -1, num_groups, dims // num_groups)
        x = x.transpose(0, 1, 3, 2).reshape(batch, -1, num_groups)

        # Normalize
        means = mx.mean(x, axis=1, keepdims=True)
        var = mx.var(x, axis=1, keepdims=True)
        x = (x - means) * mx.rsqrt(var + self.eps)
        x = x.reshape(batch, -1, dims // num_groups, num_groups)
        x = x.transpose(0, 1, 3, 2).reshape(batch, *rest, dims)

        return x

    def _group_norm(self, x):
        num_groups = self.num_groups
        batch, *rest, dims = x.shape

        # Split into groups
        x = x.reshape(batch, -1, num_groups)

        # Normalize
        means = mx.mean(x, axis=1, keepdims=True)
        var = mx.var(x, axis=1, keepdims=True)
        x = (x - means) * mx.rsqrt(var + self.eps)
        x = x.reshape(batch, *rest, dims)

        return x

    def __call__(self, x):
        group_norm = (
            self._pytorch_compatible_group_norm
            if self.pytorch_compatible
            else self._group_norm
        )
        x = group_norm(x)
        return (self.weight * x + self.bias) if "weight" in self else x
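A short sketch of ``LayerNorm`` in use (shapes are hypothetical): the statistics are taken over the last axis, so every position is normalized independently.

```python
import mlx.core as mx
import mlx.nn as nn

ln = nn.LayerNorm(dims=32)
x = mx.random.normal((8, 10, 32))
y = ln(x)  # per-position mean ~0 and unit variance over the last axis
```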
python/mlx/nn/layers/positional_encoding.py  (new file, +142 lines)
@@ -0,0 +1,142 @@
import math
from typing import Optional

import mlx.core as mx
from mlx.nn.layers.base import Module


class RoPE(Module):
    """Implements the rotary positional encoding [1].

    The traditional implementation rotates consecutive pairs of elements in the
    feature dimension while the default implementation rotates pairs with
    stride half the feature dimensions for efficiency.

    [1]: https://arxiv.org/abs/2104.09864

    Args:
        dims (int): The feature dimensions to be rotated. If the input feature
                    is larger than dims then the rest is left unchanged.
        traditional (bool): If set to True choose the traditional
                            implementation which is slightly less efficient.
    """

    def __init__(self, dims: int, traditional: bool = False):
        super().__init__()
        self.dims = dims
        self.traditional = traditional

    def _extra_repr(self):
        return f"{self.dims}, traditional={self.traditional}"

    def _compute_rope(self, costheta, sintheta, x):
        x1 = x[..., : self.dims // 2]
        x2 = x[..., self.dims // 2 : self.dims]
        rx1 = x1 * costheta - x2 * sintheta
        rx2 = x1 * sintheta + x2 * costheta

        if self.dims < x.shape[-1]:
            rx = mx.concatenate([rx1, rx2, x[..., self.dims :]], axis=-1)
        else:
            rx = mx.concatenate([rx1, rx2], axis=-1)

        return rx

    def _compute_traditional_rope(self, costheta, sintheta, x):
        x1 = x[..., ::2]
        x2 = x[..., 1::2]
        rx1 = x1 * costheta - x2 * sintheta
        rx2 = x1 * sintheta + x2 * costheta

        if self.dims < x.shape[-1]:
            raise NotImplementedError(
                "RoPE doesn't implement partial traditional application"
            )

        rx = mx.concatenate([rx1[..., None], rx2[..., None]], axis=-1)

        return rx

    def __call__(self, x, offset: int = 0):
        shape = x.shape
        x = mx.reshape(x, (-1, shape[-2], shape[-1]))
        N = x.shape[1] + offset
        costheta, sintheta = RoPE.create_cos_sin_theta(
            N, self.dims, offset=offset, dtype=x.dtype
        )

        rope = (
            self._compute_traditional_rope if self.traditional else self._compute_rope
        )
        rx = rope(costheta, sintheta, x)

        return mx.reshape(rx, shape)

    @staticmethod
    def create_cos_sin_theta(
        N: int, D: int, offset: int = 0, base: float = 10000, dtype=mx.float32
    ):
        D = D // 2
        positions = mx.arange(offset, N, dtype=dtype)
        freqs = mx.exp(-mx.arange(0.0, D, dtype=dtype) * (math.log(base) / D))
        theta = mx.reshape(positions, (-1, 1)) * mx.reshape(freqs, (1, -1))
        costheta = mx.cos(theta)
        sintheta = mx.sin(theta)

        return costheta, sintheta


class SinusoidalPositionalEncoding(Module):
    """Implements sinusoidal positional encoding similar to [1].

    [1]: https://arxiv.org/abs/1706.03762

    Args:
        dims (int): The dimensionality of the resulting positional embeddings.
        min_freq (float): The minimum frequency expected (default: 0.0001)
        max_freq (float): The maximum frequency expected (default: 1)
        scale (float): Scale the embeddings by this value (default: sqrt(2/dims))
        cos_first (bool): If set to True embed using ``[cos(x); sin(x)]``
            instead of the other way around (default: False)
        full_turns (bool): If set to True multiply the frequencies
            with ``2 pi`` (default: False)
    """

    def __init__(
        self,
        dims: int,
        min_freq: float = 0.0001,
        max_freq: float = 1,
        scale: Optional[float] = None,
        cos_first: bool = False,
        full_turns: bool = False,
    ):
        super().__init__()

        one_zero = 1 - mx.arange(0, dims // 2) / (dims // 2 - 1)
        min_freq = math.log(min_freq)
        max_freq = math.log(max_freq)

        # Start with an underscore so it is not included in the parameters
        self._sigmas = mx.exp(one_zero * (max_freq - min_freq) + min_freq)
        if full_turns:
            self._sigmas = self._sigmas * (2 * math.pi)

        # Save some constants that define the implementation
        self.scale = scale or (2 / dims) ** 0.5
        self.cos_first = cos_first

    def __call__(self, x):
        y = x[..., None] * self._sigmas
        cosy = mx.cos(y)
        siny = mx.sin(y)

        if self.cos_first:
            y = mx.concatenate([cosy, siny], axis=-1)
        else:
            y = mx.concatenate([siny, cosy], axis=-1)

        if self.scale != 1:
            y = y * self.scale

        return y
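A sketch of RoPE applied to per-head queries, including the ``offset`` path that an autoregressive KV-cache decode step would use (shapes are hypothetical):

```python
import mlx.core as mx
import mlx.nn as nn

rope = nn.RoPE(dims=64)

q = mx.random.normal((1, 8, 16, 64))      # (batch, heads, seq, head_dim)
q_rot = rope(q)                           # rotates positions 0..15

q_step = mx.random.normal((1, 8, 1, 64))  # one new token
q_step_rot = rope(q_step, offset=16)      # rotated as position 16
```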
python/mlx/nn/layers/transformer.py  (new file, +136 lines)
@@ -0,0 +1,136 @@
import math
from typing import Optional

import mlx.core as mx
from mlx.nn.layers.base import Module
from mlx.nn.layers.linear import Linear
from mlx.nn.layers.normalization import LayerNorm


class MultiHeadAttention(Module):
    """Implements the scaled dot product attention with multiple heads.

    Given inputs for queries, keys and values the ``MultiHeadAttention`` produces
    new values by aggregating information from the input values according to
    the similarities of the input queries and keys.

    All inputs as well as the output are linearly projected without biases.

    MultiHeadAttention also expects an additive attention mask that should be
    broadcastable to ``(batch, num_heads, num_queries, num_keys)``. The mask
    should have ``-inf`` or very negative numbers at the positions that should
    *not* be attended to.

    Args:
        dims (int): The model dimensions. If no other dims are provided then
            dims is used for queries, keys, values and the output.
        num_heads (int): How many attention heads to use
        query_input_dims (int, optional): The input dimensions of the queries (default: dims).
        key_input_dims (int, optional): The input dimensions of the keys (default: dims).
        value_input_dims (int, optional): The input dimensions of the values (default: key_input_dims).
        value_dims (int, optional): The dimensions of the values after the projection (default: dims).
        value_output_dims (int, optional): The dimensions the new values will be projected to (default: dims).
    """

    def __init__(
        self,
        dims: int,
        num_heads: int,
        query_input_dims: Optional[int] = None,
        key_input_dims: Optional[int] = None,
        value_input_dims: Optional[int] = None,
        value_dims: Optional[int] = None,
        value_output_dims: Optional[int] = None,
    ):
        super().__init__()

        if (dims % num_heads) != 0:
            raise ValueError(
                "The input feature dimensions should be divisible by the "
                f"number of heads ({dims} % {num_heads}) != 0"
            )

        query_input_dims = query_input_dims or dims
        key_input_dims = key_input_dims or dims
        value_input_dims = value_input_dims or key_input_dims
        value_dims = value_dims or dims
        value_output_dims = value_output_dims or dims

        self.num_heads = num_heads
        self.query_proj = Linear(query_input_dims, dims, False)
        self.key_proj = Linear(key_input_dims, dims, False)
        self.value_proj = Linear(value_input_dims, value_dims, False)
        self.out_proj = Linear(value_dims, value_output_dims, False)

    def __call__(self, queries, keys, values, mask=None):
        queries = self.query_proj(queries)
        keys = self.key_proj(keys)
        values = self.value_proj(values)

        num_heads = self.num_heads
        B, L, D = queries.shape
        _, S, _ = keys.shape
        queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 3, 1)
        values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)

        # Dimensions are [batch x num heads x sequence x hidden dim]
        scale = math.sqrt(1 / queries.shape[-1])
        scores = (queries * scale) @ keys
        if mask is not None:
            scores = scores + mask.astype(scores.dtype)
        scores = mx.softmax(scores, axis=-1)
        values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)

        return self.out_proj(values_hat)

    @staticmethod
    def create_additive_causal_mask(N: int, dtype: mx.Dtype = mx.float32):
        indices = mx.arange(N)
        mask = indices[:, None] < indices[None]
        # usually -inf but -1e9 is as good and softmax(full(-1e9)) != nan
        # TODO: Should replace this with finfo(dtype).min
        mask = mask.astype(dtype) * -1e9
        return mask
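A small sketch of causal self-attention with this module (assuming it is exported as ``mlx.nn.MultiHeadAttention``):

```python
import mlx.core as mx
import mlx.nn as nn

attn = nn.MultiHeadAttention(dims=64, num_heads=8)
x = mx.random.normal((2, 10, 64))  # (batch, sequence, dims)

mask = nn.MultiHeadAttention.create_additive_causal_mask(10)
y = attn(x, x, x, mask)  # (2, 10, 64); position i attends only to positions <= i
```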

class TransformerEncoderLayer(Module):
    def __init__(self, dims: int, num_heads: int, mlp_dims: Optional[int] = None):
        super().__init__()
        mlp_dims = mlp_dims or dims * 4
        self.attention = MultiHeadAttention(dims, num_heads)
        self.ln1 = LayerNorm(dims)
        self.ln2 = LayerNorm(dims)
        self.linear1 = Linear(dims, mlp_dims)
        self.linear2 = Linear(mlp_dims, dims)

    def __call__(self, x, mask):
        # Pre-norm residual attention block
        y = self.ln1(x)
        y = self.attention(y, y, y, mask)
        x = x + y

        # Pre-norm residual MLP block with a ReLU
        y = self.ln2(x)
        y = self.linear1(y)
        y = mx.maximum(y, 0)
        y = self.linear2(y)
        x = x + y

        return x


class TransformerEncoder(Module):
    def __init__(
        self, num_layers: int, dims: int, num_heads: int, mlp_dims: Optional[int] = None
    ):
        super().__init__()
        self.layers = [
            TransformerEncoderLayer(dims, num_heads, mlp_dims)
            for _ in range(num_layers)
        ]
        self.ln = LayerNorm(dims)

    def __call__(self, x, mask):
        for layer in self.layers:
            x = layer(x, mask)
        x = self.ln(x)

        return x
python/mlx/nn/utils.py  (new file, +31 lines)
@@ -0,0 +1,31 @@
from typing import Callable

import mlx.core as mx


def value_and_grad(model: "mlx.nn.Module", fn: Callable):
    """Transform the passed function ``fn`` to a function that computes the
    gradients of ``fn`` with respect to the model's trainable parameters and
    also its value.

    Args:
        model (mlx.nn.Module): The model whose trainable parameters to compute
                               gradients for
        fn (Callable): The scalar function to compute gradients for

    Returns:
        A callable that returns the value of ``fn`` and the gradients with
        respect to the trainable parameters of ``model``
    """

    def inner_fn(params, *args, **kwargs):
        model.update(params)
        return fn(*args, **kwargs)

    value_grad_fn = mx.value_and_grad(inner_fn)

    def wrapped_value_grad_fn(*args, **kwargs):
        value, grad = value_grad_fn(model.trainable_parameters(), *args, **kwargs)
        return value, grad

    return wrapped_value_grad_fn
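A sketch of the wrapper in action (``Linear`` and the export path ``mlx.nn.value_and_grad`` are assumed):

```python
import mlx.core as mx
import mlx.nn as nn

model = nn.Linear(2, 1)
x = mx.random.normal((16, 2))

def loss_fn():
    return mx.mean(mx.square(model(x)))

# Gradients are taken with respect to model.trainable_parameters().
loss, grads = nn.value_and_grad(model, loss_fn)()
```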
python/mlx/optimizers.py  (new file, +152 lines)
							| @@ -0,0 +1,152 @@ | ||||
| import math | ||||
| from typing import List | ||||
|  | ||||
| import mlx.core as mx | ||||
| from mlx.utils import tree_map | ||||
|  | ||||
|  | ||||
| class OptimizerState(dict): | ||||
|     """The optimizer state implements a recursively defined | ||||
|     :class:`collections.defaultdict`, namely a missing key in an optimizer | ||||
|     state is an :class:`OptimizerState`. | ||||
|  | ||||
|     .. note:: | ||||
|        :meth:`OptimizerState.get` in contrast to a normal dictionary also sets | ||||
|        the key to the ``default`` value if the ``key`` was not present in the | ||||
|        dictionary. | ||||
|     """ | ||||
|  | ||||
|     def __getitem__(self, key): | ||||
|         if key not in self: | ||||
|             self[key] = OptimizerState() | ||||
|         return super().__getitem__(key) | ||||
|  | ||||
|     def get(self, key, default): | ||||
|         """If ``key`` doesn't exist set its value to ``default`` and then return it.""" | ||||
|         if key not in self: | ||||
|             self[key] = default | ||||
|         return super().__getitem__(key) | ||||
|  | ||||
|  | ||||
| class Optimizer: | ||||
|     """The base class for all optimizers. It allows us to implement an | ||||
|     optimizer on a per-parameter basis and apply it to a parameter tree. | ||||
|  | ||||
|     Attributes: | ||||
|         state (OptimizerState): It holds the optimizer's state dictionary. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self): | ||||
|         self.state = OptimizerState() | ||||
|  | ||||
|     def update(self, model: "mlx.nn.Module", gradients: dict): | ||||
|         """Apply the gradients to the parameters of the model and update the | ||||
|         model with the new parameters. | ||||
|  | ||||
|         Args: | ||||
|             model (mlx.nn.Module): An mlx module to be updated. | ||||
|             gradients (dict): A Python tree of gradients, most likely computed | ||||
|                               via :func:`mlx.nn.value_and_grad`. | ||||
|         """ | ||||
|         model.update(self.apply_gradients(gradients, model)) | ||||
|  | ||||
|     def apply_gradients(self, gradients: dict, model: dict): | ||||
|         """Apply the gradients to the parameters and return the updated parameters. | ||||
|  | ||||
|         Can be used to update a model via | ||||
|         ``model.update(opt.apply_gradients(grads, model))`` which is precisely | ||||
|         how :meth:`Optimizer.update` is implemented. | ||||
|  | ||||
|         Args: | ||||
|             gradients (dict): A Python tree of gradients. | ||||
|             model (dict): A Python tree of parameters. It can be a superset of | ||||
|                           the gradients. In that case the returned python tree | ||||
|                           will be of the same structure as the gradients. | ||||
|         """ | ||||
|         return tree_map(self.apply_single, gradients, model, self.state) | ||||
|  | ||||
|     def apply_single( | ||||
|         self, gradient: mx.array, parameter: mx.array, state: OptimizerState | ||||
|     ): | ||||
|         """To be extended by the children classes to implement each optimizer's | ||||
|         update.""" | ||||
|         raise NotImplementedError() | ||||
|  | ||||
|  | ||||
| class SGD(Optimizer): | ||||
|     r"""Stochastic gradient descent optimizer. | ||||
|  | ||||
|     Updates a parameter :math:`w` with a gradient :math:`g` as follows | ||||
|  | ||||
|     .. math:: | ||||
|  | ||||
|         v_{t+1} &= \mu v_t + (1 - \mu) g_t \\ | ||||
|         w_{t+1} &= w_t - \lambda v_{t+1} | ||||
|  | ||||
|     Args: | ||||
|         learning_rate (float): The learning :math:`\lambda` for the update | ||||
|         momentum (float): The momentum strength :math:`\mu` | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, learning_rate: float, momentum: float = 0.0): | ||||
|         super().__init__() | ||||
|  | ||||
|         self.learning_rate = learning_rate | ||||
|         self.momentum = momentum | ||||
|  | ||||
|     def apply_single( | ||||
|         self, gradient: mx.array, parameter: mx.array, state: OptimizerState | ||||
|     ): | ||||
|         """Performs the SGD parameter update and stores :math:`v` in the | ||||
|         optimizer state.""" | ||||
|         if self.momentum <= 0: | ||||
|             return parameter - self.learning_rate * gradient | ||||
|  | ||||
|         v = state.get("v", mx.zeros_like(gradient)) | ||||
|         v = self.momentum * v + (1 - self.momentum) * gradient | ||||
|         state["v"] = v | ||||
|         return parameter - self.learning_rate * v | ||||
|  | ||||
|  | ||||
| class Adam(Optimizer): | ||||
|     r"""Implementation of the Adam optimizer [1]. | ||||
|  | ||||
|     Our Adam implementation follows the original paper and omits the bias | ||||
|     correction in the first and second moment estimates. In detail, | ||||
|  | ||||
|     .. math:: | ||||
|  | ||||
|         m_{t+1} &= \beta_1 m_t + (1 - \beta_1) g_t \\ | ||||
|         v_{t+1} &= \beta_2 v_t + (1 - \beta_2) g_t^2 \\ | ||||
|         w_{t+1} &= w_t - \lambda \frac{m_{t+1}}{\sqrt{v_{t+1}} + \epsilon} | ||||
|  | ||||
|     [1]: Kingma, D.P. and Ba, J., 2015. Adam: A method for stochastic | ||||
|     optimization. ICLR 2015. | ||||
|     """ | ||||
|  | ||||
|     def __init__( | ||||
|         self, learning_rate: float, betas: List[float] = [0.9, 0.999], eps: float = 1e-8 | ||||
|     ): | ||||
|         super().__init__() | ||||
|  | ||||
|         self.learning_rate = learning_rate | ||||
|         self.betas = betas | ||||
|         self.eps = eps | ||||
|  | ||||
|     def apply_single( | ||||
|         self, gradient: mx.array, parameter: mx.array, state: OptimizerState | ||||
|     ): | ||||
|         """Performs the Adam parameter update and stores :math:`v` and | ||||
|         :math:`m` in the optimizer state.""" | ||||
|         lr = self.learning_rate | ||||
|         b1, b2 = self.betas | ||||
|         eps = self.eps | ||||
|  | ||||
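|         # Note: on the first call the state is seeded with the raw gradient | ||||
|         # and its square, so the uncorrected moments start from the first | ||||
|         # observed values rather than from zero. | ||||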
|         m = state.get("m", gradient) | ||||
|         v = state.get("v", mx.square(gradient)) | ||||
|         m = b1 * m + (1 - b1) * gradient | ||||
|         v = b2 * v + (1 - b2) * mx.square(gradient) | ||||
|         state["m"] = m | ||||
|         state["v"] = v | ||||
|  | ||||
|         return parameter - lr * m / (mx.sqrt(v) + eps) | ||||
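|  | ||||
|  | ||||
| # Both optimizers share the base ``Optimizer`` interface, e.g. (illustrative): | ||||
| #     optimizer = Adam(learning_rate=1e-3) | ||||
| #     optimizer.update(model, grads) | ||||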
							
								
								
									
32  python/src/CMakeLists.txt  Normal file
							| @@ -0,0 +1,32 @@ | ||||
| pybind11_add_module( | ||||
|   core | ||||
|   ${CMAKE_CURRENT_SOURCE_DIR}/mlx.cpp | ||||
|   ${CMAKE_CURRENT_SOURCE_DIR}/array.cpp | ||||
|   ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp | ||||
|   ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp | ||||
|   ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp | ||||
|   ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp | ||||
|   ${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp | ||||
|   ${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp | ||||
|   ${CMAKE_CURRENT_SOURCE_DIR}/stream.cpp | ||||
|   ${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp | ||||
|   ${CMAKE_CURRENT_SOURCE_DIR}/random.cpp | ||||
| ) | ||||
|  | ||||
| if (NOT MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY) | ||||
|   set(MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}) | ||||
| endif() | ||||
|  | ||||
| set_target_properties( | ||||
|   core  | ||||
|   PROPERTIES  | ||||
|   LIBRARY_OUTPUT_DIRECTORY  | ||||
|   ${MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY} | ||||
| ) | ||||
|  | ||||
| target_link_libraries(core PRIVATE mlx) | ||||
| target_compile_definitions(core PRIVATE _VERSION_=${MLX_VERSION}) | ||||
|  | ||||
| if(BUILD_SHARED_LIBS) | ||||
|   target_link_options(core PRIVATE -Wl,-rpath,@loader_path/lib) | ||||
| endif() | ||||
							
								
								
									
468  python/src/fft.cpp  Normal file
							| @@ -0,0 +1,468 @@ | ||||
| #include <pybind11/pybind11.h> | ||||
| #include <pybind11/stl.h> | ||||
|  | ||||
| #include "python/src/utils.h" | ||||
|  | ||||
| #include "mlx/fft.h" | ||||
| #include "mlx/ops.h" | ||||
|  | ||||
| namespace py = pybind11; | ||||
| using namespace py::literals; | ||||
|  | ||||
| using namespace mlx::core; | ||||
|  | ||||
| void init_fft(py::module_& parent_module) { | ||||
|   auto m = parent_module.def_submodule( | ||||
|       "fft", "mlx.core.fft: Fast Fourier Transforms."); | ||||
|   m.def( | ||||
|       "fft", | ||||
|       [](const array& a, | ||||
|          const std::optional<int>& n, | ||||
|          int axis, | ||||
|          StreamOrDevice s) { | ||||
|         if (n.has_value()) { | ||||
|           return fft::fft(a, n.value(), axis, s); | ||||
|         } else { | ||||
|           return fft::fft(a, axis, s); | ||||
|         } | ||||
|       }, | ||||
|       "a"_a, | ||||
|       "n"_a = none, | ||||
|       "axis"_a = -1, | ||||
|       "stream"_a = none, | ||||
|       R"pbdoc( | ||||
|         One dimensional discrete Fourier Transform. | ||||
|  | ||||
|         Args: | ||||
|             a (array): The input array. | ||||
|             n (int, optional): Size of the transformed axis. The | ||||
|                corresponding axis in the input is truncated or padded with | ||||
|                zeros to match ``n``. The default value is ``a.shape[axis]``. | ||||
|             axis (int, optional): Axis along which to perform the FFT. The | ||||
|                default is ``-1``. | ||||
|  | ||||
|         Returns: | ||||
|             array: The DFT of the input along the given axis. | ||||
|       )pbdoc"); | ||||
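|   // Illustrative Python-side call for the binding above (an example, not | ||||
|   // part of the binding itself): | ||||
|   //     >>> import mlx.core as mx | ||||
|   //     >>> mx.fft.fft(mx.array([1.0, 2.0, 3.0, 4.0])) | ||||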
|   m.def( | ||||
|       "ifft", | ||||
|       [](const array& a, | ||||
|          const std::optional<int>& n, | ||||
|          int axis, | ||||
|          StreamOrDevice s) { | ||||
|         if (n.has_value()) { | ||||
|           return fft::ifft(a, n.value(), axis, s); | ||||
|         } else { | ||||
|           return fft::ifft(a, axis, s); | ||||
|         } | ||||
|       }, | ||||
|       "a"_a, | ||||
|       "n"_a = none, | ||||
|       "axis"_a = -1, | ||||
|       "stream"_a = none, | ||||
|       R"pbdoc( | ||||
|         One dimensional inverse discrete Fourier Transform. | ||||
|  | ||||
|         Args: | ||||
|             a (array): The input array. | ||||
|             n (int, optional): Size of the transformed axis. The | ||||
|                corresponding axis in the input is truncated or padded with | ||||
|                zeros to match ``n``. The default value is ``a.shape[axis]``. | ||||
|             axis (int, optional): Axis along which to perform the FFT. The | ||||
|                default is ``-1``. | ||||
|  | ||||
|         Returns: | ||||
|             array: The inverse DFT of the input along the given axis. | ||||
|       )pbdoc"); | ||||
|   m.def( | ||||
|       "fft2", | ||||
|       [](const array& a, | ||||
|          const std::optional<std::vector<int>>& n, | ||||
|          const std::optional<std::vector<int>>& axes, | ||||
|          StreamOrDevice s) { | ||||
|         if (axes.has_value() && n.has_value()) { | ||||
|           return fft::fftn(a, n.value(), axes.value(), s); | ||||
|         } else if (axes.has_value()) { | ||||
|           return fft::fftn(a, axes.value(), s); | ||||
|         } else if (n.has_value()) { | ||||
|           std::vector<int> axes_(n.value().size()); | ||||
|           std::iota(axes_.begin(), axes_.end(), -n.value().size()); | ||||
|           return fft::fftn(a, n.value(), axes_, s); | ||||
|         } else { | ||||
|           return fft::fftn(a, s); | ||||
|         } | ||||
|       }, | ||||
|       "a"_a, | ||||
|       "s"_a = none, | ||||
|       "axes"_a = std::vector<int>{-2, -1}, | ||||
|       "stream"_a = none, | ||||
|       R"pbdoc( | ||||
|         Two dimensional discrete Fourier Transform. | ||||
|  | ||||
|         Args: | ||||
|             a (array): The input array. | ||||
|             s (list(int), optional): Sizes of the transformed axes. The | ||||
|                corresponding axes in the input are truncated or padded with | ||||
|                zeros to match the sizes in ``s``. The default value is the | ||||
|                sizes of ``a`` along ``axes``. | ||||
|             axes (list(int), optional): Axes along which to perform the FFT. | ||||
|                The default is ``[-2, -1]``. | ||||
|  | ||||
|         Returns: | ||||
|             array: The DFT of the input along the given axes. | ||||
|       )pbdoc"); | ||||
|   m.def( | ||||
|       "ifft2", | ||||
|       [](const array& a, | ||||
|          const std::optional<std::vector<int>>& n, | ||||
|          const std::optional<std::vector<int>>& axes, | ||||
|          StreamOrDevice s) { | ||||
|         if (axes.has_value() && n.has_value()) { | ||||
|           return fft::ifftn(a, n.value(), axes.value(), s); | ||||
|         } else if (axes.has_value()) { | ||||
|           return fft::ifftn(a, axes.value(), s); | ||||
|         } else if (n.has_value()) { | ||||
|           std::vector<int> axes_(n.value().size()); | ||||
|           std::iota(axes_.begin(), axes_.end(), -n.value().size()); | ||||
|           return fft::ifftn(a, n.value(), axes_, s); | ||||
|         } else { | ||||
|           return fft::ifftn(a, s); | ||||
|         } | ||||
|       }, | ||||
|       "a"_a, | ||||
|       "s"_a = none, | ||||
|       "axes"_a = std::vector<int>{-2, -1}, | ||||
|       "stream"_a = none, | ||||
|       R"pbdoc( | ||||
|         Two dimensional inverse discrete Fourier Transform. | ||||
|  | ||||
|         Args: | ||||
|             a (array): The input array. | ||||
|             s (list(int), optional): Sizes of the transformed axes. The | ||||
|                corresponding axes in the input are truncated or padded with | ||||
|                zeros to match the sizes in ``s``. The default value is the | ||||
|                sizes of ``a`` along ``axes``. | ||||
|             axes (list(int), optional): Axes along which to perform the FFT. | ||||
|                The default is ``[-2, -1]``. | ||||
|  | ||||
|         Returns: | ||||
|             array: The inverse DFT of the input along the given axes. | ||||
|       )pbdoc"); | ||||
|   m.def( | ||||
|       "fftn", | ||||
|       [](const array& a, | ||||
|          const std::optional<std::vector<int>>& n, | ||||
|          const std::optional<std::vector<int>>& axes, | ||||
|          StreamOrDevice s) { | ||||
|         if (axes.has_value() && n.has_value()) { | ||||
|           return fft::fftn(a, n.value(), axes.value(), s); | ||||
|         } else if (axes.has_value()) { | ||||
|           return fft::fftn(a, axes.value(), s); | ||||
|         } else if (n.has_value()) { | ||||
|           std::vector<int> axes_(n.value().size()); | ||||
|           std::iota(axes_.begin(), axes_.end(), -n.value().size()); | ||||
|           return fft::fftn(a, n.value(), axes_, s); | ||||
|         } else { | ||||
|           return fft::fftn(a, s); | ||||
|         } | ||||
|       }, | ||||
|       "a"_a, | ||||
|       "s"_a = none, | ||||
|       "axes"_a = none, | ||||
|       "stream"_a = none, | ||||
|       R"pbdoc( | ||||
|         n-dimensional discrete Fourier Transform. | ||||
|  | ||||
|         Args: | ||||
|             a (array): The input array. | ||||
|             s (list(int), optional): Sizes of the transformed axes. The | ||||
|                corresponding axes in the input are truncated or padded with | ||||
|                zeros to match the sizes in ``s``. The default value is the | ||||
|                sizes of ``a`` along ``axes``. | ||||
|             axes (list(int), optional): Axes along which to perform the FFT. | ||||
|                The default is ``None`` in which case the FFT is over the last | ||||
|                ``len(s)`` axes or all axes if ``s`` is also ``None``. | ||||
|  | ||||
|         Returns: | ||||
|             array: The DFT of the input along the given axes. | ||||
|       )pbdoc"); | ||||
|   m.def( | ||||
|       "ifftn", | ||||
|       [](const array& a, | ||||
|          const std::optional<std::vector<int>>& n, | ||||
|          const std::optional<std::vector<int>>& axes, | ||||
|          StreamOrDevice s) { | ||||
|         if (axes.has_value() && n.has_value()) { | ||||
|           return fft::ifftn(a, n.value(), axes.value(), s); | ||||
|         } else if (axes.has_value()) { | ||||
|           return fft::ifftn(a, axes.value(), s); | ||||
|         } else if (n.has_value()) { | ||||
|           std::vector<int> axes_(n.value().size()); | ||||
|           std::iota(axes_.begin(), axes_.end(), -n.value().size()); | ||||
|           return fft::ifftn(a, n.value(), axes_, s); | ||||
|         } else { | ||||
|           return fft::ifftn(a, s); | ||||
|         } | ||||
|       }, | ||||
|       "a"_a, | ||||
|       "s"_a = none, | ||||
|       "axes"_a = none, | ||||
|       "stream"_a = none, | ||||
|       R"pbdoc( | ||||
|         n-dimensional inverse discrete Fourier Transform. | ||||
|  | ||||
|         Args: | ||||
|             a (array): The input array. | ||||
|             s (list(int), optional): Sizes of the transformed axes. The | ||||
|                corresponding axes in the input are truncated or padded with | ||||
|                zeros to match the sizes in ``s``. The default value is the | ||||
|                sizes of ``a`` along ``axes``. | ||||
|             axes (list(int), optional): Axes along which to perform the FFT. | ||||
|                The default is ``None`` in which case the FFT is over the last | ||||
|                ``len(s)`` axes or all axes if ``s`` is also ``None``. | ||||
|  | ||||
|         Returns: | ||||
|             array: The inverse DFT of the input along the given axes. | ||||
|       )pbdoc"); | ||||
|   m.def( | ||||
|       "rfft", | ||||
|       [](const array& a, | ||||
|          const std::optional<int>& n, | ||||
|          int axis, | ||||
|          StreamOrDevice s) { | ||||
|         if (n.has_value()) { | ||||
|           return fft::rfft(a, n.value(), axis, s); | ||||
|         } else { | ||||
|           return fft::rfft(a, axis, s); | ||||
|         } | ||||
|       }, | ||||
|       "a"_a, | ||||
|       "n"_a = none, | ||||
|       "axis"_a = -1, | ||||
|       "stream"_a = none, | ||||
|       R"pbdoc( | ||||
|         One dimensional discrete Fourier Transform on a real input. | ||||
|  | ||||
|         The output has the same shape as the input except along ``axis`` in | ||||
|         which case it has size ``n // 2 + 1``. | ||||
|  | ||||
|         Args: | ||||
|             a (array): The input array. If the array is complex it will be silently | ||||
|                cast to a real type. | ||||
|             n (int, optional): Size of the transformed axis. The | ||||
|                corresponding axis in the input is truncated or padded with | ||||
|                zeros to match ``n``. The default value is ``a.shape[axis]``. | ||||
|             axis (int, optional): Axis along which to perform the FFT. The | ||||
|                default is ``-1``. | ||||
|  | ||||
|         Returns: | ||||
|             array: The DFT of the input along the given axis. The output | ||||
|             data type will be complex. | ||||
|       )pbdoc"); | ||||
|   m.def( | ||||
|       "irfft", | ||||
|       [](const array& a, | ||||
|          const std::optional<int>& n, | ||||
|          int axis, | ||||
|          StreamOrDevice s) { | ||||
|         if (n.has_value()) { | ||||
|           return fft::irfft(a, n.value(), axis, s); | ||||
|         } else { | ||||
|           return fft::irfft(a, axis, s); | ||||
|         } | ||||
|       }, | ||||
|       "a"_a, | ||||
|       "n"_a = none, | ||||
|       "axis"_a = -1, | ||||
|       "stream"_a = none, | ||||
|       R"pbdoc( | ||||
|         The inverse of :func:`rfft`. | ||||
|  | ||||
|         The output has the same shape as the input except along ``axis`` in | ||||
|         which case it has size ``n``. | ||||
|  | ||||
|         Args: | ||||
|             a (array): The input array. | ||||
|             n (int, optional): Size of the transformed axis. The | ||||
|                corresponding axis in the input is truncated or padded with | ||||
|                zeros to match ``n // 2 + 1``. The default value is | ||||
|                ``a.shape[axis] // 2 + 1``. | ||||
|             axis (int, optional): Axis along which to perform the FFT. The | ||||
|                default is ``-1``. | ||||
|  | ||||
|         Returns: | ||||
|             array: The real array containing the inverse of :func:`rfft`. | ||||
|       )pbdoc"); | ||||
|   m.def( | ||||
|       "rfft2", | ||||
|       [](const array& a, | ||||
|          const std::optional<std::vector<int>>& n, | ||||
|          const std::optional<std::vector<int>>& axes, | ||||
|          StreamOrDevice s) { | ||||
|         if (axes.has_value() && n.has_value()) { | ||||
|           return fft::rfftn(a, n.value(), axes.value(), s); | ||||
|         } else if (axes.has_value()) { | ||||
|           return fft::rfftn(a, axes.value(), s); | ||||
|         } else if (n.has_value()) { | ||||
|           std::vector<int> axes_(n.value().size()); | ||||
|           std::iota(axes_.begin(), axes_.end(), -n.value().size()); | ||||
|           return fft::rfftn(a, n.value(), axes_, s); | ||||
|         } else { | ||||
|           return fft::rfftn(a, s); | ||||
|         } | ||||
|       }, | ||||
|       "a"_a, | ||||
|       "s"_a = none, | ||||
|       "axes"_a = std::vector<int>{-2, -1}, | ||||
|       "stream"_a = none, | ||||
|       R"pbdoc( | ||||
|         Two dimensional real discrete Fourier Transform. | ||||
|  | ||||
|         The output has the same shape as the input except along the dimensions in | ||||
|         ``axes`` in which case it has sizes from ``s``. The last axis in ``axes`` is | ||||
|         treated as the real axis and will have size ``s[-1] // 2 + 1``. | ||||
|  | ||||
|         Args: | ||||
|             a (array): The input array. If the array is complex it will be silently | ||||
|                cast to a real type. | ||||
|             s (list(int), optional): Sizes of the transformed axes. The | ||||
|                corresponding axes in the input are truncated or padded with | ||||
|                zeros to match the sizes in ``s``. The default value is the | ||||
|                sizes of ``a`` along ``axes``. | ||||
|             axes (list(int), optional): Axes along which to perform the FFT. | ||||
|                The default is ``[-2, -1]``. | ||||
|  | ||||
|         Returns: | ||||
|             array: The real DFT of the input along the given axes. The output | ||||
|             data type will be complex. | ||||
|       )pbdoc"); | ||||
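|   // Shape example (illustrative): a real input of shape (4, 4) with the | ||||
|   // default ``s`` yields an rfft2 output of shape (4, 3), since | ||||
|   // 4 // 2 + 1 == 3. | ||||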
|   m.def( | ||||
|       "irfft2", | ||||
|       [](const array& a, | ||||
|          const std::optional<std::vector<int>>& n, | ||||
|          const std::optional<std::vector<int>>& axes, | ||||
|          StreamOrDevice s) { | ||||
|         if (axes.has_value() && n.has_value()) { | ||||
|           return fft::irfftn(a, n.value(), axes.value(), s); | ||||
|         } else if (axes.has_value()) { | ||||
|           return fft::irfftn(a, axes.value(), s); | ||||
|         } else if (n.has_value()) { | ||||
|           std::vector<int> axes_(n.value().size()); | ||||
|           std::iota(axes_.begin(), axes_.end(), -n.value().size()); | ||||
|           return fft::irfftn(a, n.value(), axes_, s); | ||||
|         } else { | ||||
|           return fft::irfftn(a, s); | ||||
|         } | ||||
|       }, | ||||
|       "a"_a, | ||||
|       "s"_a = none, | ||||
|       "axes"_a = std::vector<int>{-2, -1}, | ||||
|       "stream"_a = none, | ||||
|       R"pbdoc( | ||||
|         The inverse of :func:`rfft2`. | ||||
|  | ||||
|         Note the input is generally complex. The dimensions of the input | ||||
|         specified in ``axes`` are padded or truncated to match the sizes | ||||
|         from ``s``. The last axis in ``axes`` is treated as the real axis | ||||
|         and will have size ``s[-1] // 2 + 1``. | ||||
|  | ||||
|         Args: | ||||
|             a (array): The input array. | ||||
|             s (list(int), optional): Sizes of the transformed axes. The | ||||
|                corresponding axes in the input are truncated or padded with | ||||
|                zeros to match the sizes in ``s`` except for the last axis | ||||
|                which has size ``s[-1] // 2 + 1``. The default value is the | ||||
|                sizes of ``a`` along ``axes``. | ||||
|             axes (list(int), optional): Axes along which to perform the FFT. | ||||
|                The default is ``[-2, -1]``. | ||||
|  | ||||
|         Returns: | ||||
|             array: The real array containing the inverse of :func:`rfft2`. | ||||
|       )pbdoc"); | ||||
|   m.def( | ||||
|       "rfftn", | ||||
|       [](const array& a, | ||||
|          const std::optional<std::vector<int>>& n, | ||||
|          const std::optional<std::vector<int>>& axes, | ||||
|          StreamOrDevice s) { | ||||
|         if (axes.has_value() && n.has_value()) { | ||||
|           return fft::rfftn(a, n.value(), axes.value(), s); | ||||
|         } else if (axes.has_value()) { | ||||
|           return fft::rfftn(a, axes.value(), s); | ||||
|         } else if (n.has_value()) { | ||||
|           std::vector<int> axes_(n.value().size()); | ||||
|           std::iota(axes_.begin(), axes_.end(), -n.value().size()); | ||||
|           return fft::rfftn(a, n.value(), axes_, s); | ||||
|         } else { | ||||
|           return fft::rfftn(a, s); | ||||
|         } | ||||
|       }, | ||||
|       "a"_a, | ||||
|       "s"_a = none, | ||||
|       "axes"_a = none, | ||||
|       "stream"_a = none, | ||||
|       R"pbdoc( | ||||
|         n-dimensional real discrete Fourier Transform. | ||||
|  | ||||
|         The output has the same shape as the input except along the dimensions in | ||||
|         ``axes`` in which case it has sizes from ``s``. The last axis in ``axes`` is | ||||
|         treated as the real axis and will have size ``s[-1] // 2 + 1``. | ||||
|  | ||||
|         Args: | ||||
|             a (array): The input array. If the array is complex it will be silently | ||||
|                cast to a real type. | ||||
|             s (list(int), optional): Sizes of the transformed axes. The | ||||
|                corresponding axes in the input are truncated or padded with | ||||
|                zeros to match the sizes in ``s``. The default value is the | ||||
|                sizes of ``a`` along ``axes``. | ||||
|             axes (list(int), optional): Axes along which to perform the FFT. | ||||
|                The default is ``None`` in which case the FFT is over the last | ||||
|                ``len(s)`` axes or all axes if ``s`` is also ``None``. | ||||
|  | ||||
|         Returns: | ||||
|             array: The real DFT of the input along the given axes. The output | ||||
|             data type will be complex. | ||||
|       )pbdoc"); | ||||
|   m.def( | ||||
|       "irfftn", | ||||
|       [](const array& a, | ||||
|          const std::optional<std::vector<int>>& n, | ||||
|          const std::optional<std::vector<int>>& axes, | ||||
|          StreamOrDevice s) { | ||||
|         if (axes.has_value() && n.has_value()) { | ||||
|           return fft::irfftn(a, n.value(), axes.value(), s); | ||||
|         } else if (axes.has_value()) { | ||||
|           return fft::irfftn(a, axes.value(), s); | ||||
|         } else if (n.has_value()) { | ||||
|           std::vector<int> axes_(n.value().size()); | ||||
|           std::iota(axes_.begin(), axes_.end(), -n.value().size()); | ||||
|           return fft::irfftn(a, n.value(), axes_, s); | ||||
|         } else { | ||||
|           return fft::irfftn(a, s); | ||||
|         } | ||||
|       }, | ||||
|       "a"_a, | ||||
|       "s"_a = none, | ||||
|       "axes"_a = none, | ||||
|       "stream"_a = none, | ||||
|       R"pbdoc( | ||||
|         The inverse of :func:`rfftn`. | ||||
|  | ||||
|         Note the input is generally complex. The dimensions of the input | ||||
|         specified in ``axes`` are padded or truncated to match the sizes | ||||
|         from ``s``. The last axis in ``axes`` is treated as the real axis | ||||
|         and will have size ``s[-1] // 2 + 1``. | ||||
|  | ||||
|         Args: | ||||
|             a (array): The input array. | ||||
|             s (list(int), optional): Sizes of the transformed axes. The | ||||
|                corresponding axes in the input are truncated or padded with | ||||
|                zeros to match the sizes in ``s``. The default value is the | ||||
|                sizes of ``a`` along ``axes``. | ||||
|             axes (list(int), optional): Axes along which to perform the FFT. | ||||
|                The default is ``None`` in which case the FFT is over the last | ||||
|                ``len(s)`` axes or all axes if ``s`` is also ``None``. | ||||
|  | ||||
|         Returns: | ||||
|             array: The real array containing the inverse of :func:`rfftn`. | ||||
|       )pbdoc"); | ||||
| } | ||||
							
								
								
									
635  python/src/indexing.cpp  Normal file
							| @@ -0,0 +1,635 @@ | ||||
| #include <numeric> | ||||
| #include <sstream> | ||||
|  | ||||
| #include "python/src/indexing.h" | ||||
|  | ||||
| #include "mlx/ops.h" | ||||
|  | ||||
| bool is_none_slice(const py::slice& in_slice) { | ||||
|   return ( | ||||
|       py::getattr(in_slice, "start").is_none() && | ||||
|       py::getattr(in_slice, "stop").is_none() && | ||||
|       py::getattr(in_slice, "step").is_none()); | ||||
| } | ||||
|  | ||||
| int get_slice_int(py::object obj, int default_val) { | ||||
|   if (!obj.is_none()) { | ||||
|     if (!py::isinstance<py::int_>(obj)) { | ||||
|       throw std::invalid_argument("Slice indices must be integers or None."); | ||||
|     } | ||||
|     return py::cast<int>(py::cast<py::int_>(obj)); | ||||
|   } | ||||
|   return default_val; | ||||
| } | ||||
|  | ||||
| void get_slice_params( | ||||
|     int& starts, | ||||
|     int& ends, | ||||
|     int& strides, | ||||
|     const py::slice& in_slice, | ||||
|     int axis_size) { | ||||
|   // Following numpy's convention | ||||
|   //    Assume n is the number of elements in the dimension being sliced. | ||||
|   //    Then, if i is not given it defaults to 0 for k > 0 and n - 1 for | ||||
|   //    k < 0 . If j is not given it defaults to n for k > 0 and -n-1 for | ||||
|   //    k < 0 . If k is not given it defaults to 1 | ||||
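|   //    For example, with n = 5 the slice a[::-1] resolves under these | ||||
|   //    defaults to start = 4, stop = -6, step = -1. | ||||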
|  | ||||
|   strides = get_slice_int(py::getattr(in_slice, "step"), 1); | ||||
|   starts = get_slice_int( | ||||
|       py::getattr(in_slice, "start"), strides < 0 ? axis_size - 1 : 0); | ||||
|   ends = get_slice_int( | ||||
|       py::getattr(in_slice, "stop"), strides < 0 ? -axis_size - 1 : axis_size); | ||||
|  | ||||
|   // starts = (starts < 0) ? starts + axis_size : starts; | ||||
|   // ends = (ends < 0) ? ends + axis_size : ends; | ||||
| } | ||||
|  | ||||
| array get_int_index(py::object idx, int axis_size) { | ||||
|   int idx_ = py::cast<int>(idx); | ||||
|   idx_ = (idx_ < 0) ? idx_ + axis_size : idx_; | ||||
|  | ||||
|   return array(idx_, uint32); | ||||
| } | ||||
|  | ||||
| bool is_valid_index_type(const py::object& obj) { | ||||
|   return py::isinstance<py::slice>(obj) || py::isinstance<py::int_>(obj) || | ||||
|       py::isinstance<array>(obj) || obj.is_none() || py::ellipsis().is(obj); | ||||
| } | ||||
|  | ||||
| array mlx_get_item_slice(const array& src, const py::slice& in_slice) { | ||||
|   // Check input and raise error if 0 dim for parity with np | ||||
|   if (src.ndim() == 0) { | ||||
|     throw std::invalid_argument( | ||||
|         "too many indices for array: array is 0-dimensional"); | ||||
|   } | ||||
|  | ||||
|   // Return a copy of the array if a none slice is requested | ||||
|   if (is_none_slice(in_slice)) { | ||||
|     return src; | ||||
|   } | ||||
|  | ||||
|   std::vector<int> starts(src.ndim(), 0); | ||||
|   std::vector<int> ends = src.shape(); | ||||
|   std::vector<int> strides(src.ndim(), 1); | ||||
|  | ||||
|   // Check and update slice params | ||||
|   get_slice_params(starts[0], ends[0], strides[0], in_slice, ends[0]); | ||||
|   return slice(src, starts, ends, strides); | ||||
| } | ||||
|  | ||||
| array mlx_get_item_array(const array& src, const array& indices) { | ||||
|   // Check input and raise error if 0 dim for parity with np | ||||
|   if (src.ndim() == 0) { | ||||
|     throw std::invalid_argument( | ||||
|         "too many indices for array: array is 0-dimensional"); | ||||
|   } | ||||
|  | ||||
|   if (indices.dtype() == bool_) { | ||||
|     throw std::invalid_argument("boolean indices are not yet supported"); | ||||
|   } | ||||
|  | ||||
|   // If only one input array is mentioned, we set axis=0 in take | ||||
|   // for parity with np | ||||
|   return take(src, indices, 0); | ||||
| } | ||||
|  | ||||
| array mlx_get_item_int(const array& src, const py::int_& idx) { | ||||
|   // Check input and raise error if 0 dim for parity with np | ||||
|   if (src.ndim() == 0) { | ||||
|     throw std::invalid_argument( | ||||
|         "too many indices for array: array is 0-dimensional"); | ||||
|   } | ||||
|  | ||||
|   // If only one input idx is mentioned, we set axis=0 in take | ||||
|   // for parity with np | ||||
|   return take(src, get_int_index(idx, src.shape(0)), 0); | ||||
| } | ||||
|  | ||||
| array mlx_gather_nd( | ||||
|     array src, | ||||
|     const std::vector<py::object>& indices, | ||||
|     bool gather_first, | ||||
|     int& max_dims) { | ||||
|   max_dims = 0; | ||||
|   std::vector<array> gather_indices; | ||||
|   std::vector<bool> is_slice(indices.size(), false); | ||||
|   int num_slices = 0; | ||||
|   // gather all the arrays | ||||
|   for (int i = 0; i < indices.size(); i++) { | ||||
|     auto& idx = indices[i]; | ||||
|  | ||||
|     if (py::isinstance<py::slice>(idx)) { | ||||
|       int start, end, stride; | ||||
|       get_slice_params(start, end, stride, idx, src.shape(i)); | ||||
|       gather_indices.push_back(arange(start, end, stride, uint32)); | ||||
|       num_slices++; | ||||
|       is_slice[i] = true; | ||||
|     } else if (py::isinstance<py::int_>(idx)) { | ||||
|       gather_indices.push_back(get_int_index(idx, src.shape(i))); | ||||
|     } else if (py::isinstance<array>(idx)) { | ||||
|       auto arr = py::cast<array>(idx); | ||||
|       max_dims = std::max(static_cast<int>(arr.ndim()), max_dims); | ||||
|       gather_indices.push_back(arr); | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   // reshape them so that the int/array indices are first | ||||
|   if (gather_first) { | ||||
|     int slice_index = 0; | ||||
|     for (int i = 0; i < gather_indices.size(); i++) { | ||||
|       if (is_slice[i]) { | ||||
|         std::vector<int> index_shape(max_dims + num_slices, 1); | ||||
|         index_shape[max_dims + slice_index] = gather_indices[i].shape(0); | ||||
|         gather_indices[i] = reshape(gather_indices[i], index_shape); | ||||
|         slice_index++; | ||||
|       } else { | ||||
|         std::vector<int> index_shape = gather_indices[i].shape(); | ||||
|         index_shape.insert(index_shape.end(), num_slices, 1); | ||||
|         gather_indices[i] = reshape(gather_indices[i], index_shape); | ||||
|       } | ||||
|     } | ||||
|   } else { | ||||
|     // reshape them so that the int/array indices are last | ||||
|     for (int i = 0; i < gather_indices.size(); i++) { | ||||
|       if (i < num_slices) { | ||||
|         std::vector<int> index_shape(max_dims + num_slices, 1); | ||||
|         index_shape[i] = gather_indices[i].shape(0); | ||||
|         gather_indices[i] = reshape(gather_indices[i], index_shape); | ||||
|       } | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   // Do the gather | ||||
|   std::vector<int> axes(indices.size()); | ||||
|   std::iota(axes.begin(), axes.end(), 0); | ||||
|   std::vector<int> slice_sizes = src.shape(); | ||||
|   std::fill(slice_sizes.begin(), slice_sizes.begin() + indices.size(), 1); | ||||
|   src = gather(src, gather_indices, axes, slice_sizes); | ||||
|  | ||||
|   // Squeeze the dims | ||||
|   std::vector<int> out_shape; | ||||
|   out_shape.insert( | ||||
|       out_shape.end(), | ||||
|       src.shape().begin(), | ||||
|       src.shape().begin() + max_dims + num_slices); | ||||
|   out_shape.insert( | ||||
|       out_shape.end(), | ||||
|       src.shape().begin() + max_dims + num_slices + indices.size(), | ||||
|       src.shape().end()); | ||||
|   src = reshape(src, out_shape); | ||||
|  | ||||
|   return src; | ||||
| } | ||||
|  | ||||
| array mlx_get_item_nd(array src, const py::tuple& entries) { | ||||
|   // No indices make this a noop | ||||
|   if (entries.size() == 0) { | ||||
|     return src; | ||||
|   } | ||||
|  | ||||
|   // The plan is as follows: | ||||
|   // 1. Replace the ellipsis with a series of slice(None) | ||||
|   // 2. Loop over the indices and calculate the gather indices | ||||
|   // 3. Calculate the remaining slices and reshapes | ||||
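|   // | ||||
|   // For example (illustrative): indexing a 3-d array x with x[..., 0] | ||||
|   // expands the ellipsis so the index behaves like x[:, :, 0] before the | ||||
|   // gather/slice logic below runs. | ||||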
|  | ||||
|   // Ellipsis handling | ||||
|   std::vector<py::object> indices; | ||||
|   { | ||||
|     int non_none_indices_before = 0; | ||||
|     int non_none_indices_after = 0; | ||||
|     std::vector<py::object> r_indices; | ||||
|     int i = 0; | ||||
|     for (; i < entries.size(); i++) { | ||||
|       auto idx = entries[i]; | ||||
|       if (!is_valid_index_type(idx)) { | ||||
|         throw std::invalid_argument( | ||||
|             "Cannot index mlx array using the given type yet"); | ||||
|       } | ||||
|       if (!py::ellipsis().is(idx)) { | ||||
|         indices.push_back(idx); | ||||
|         non_none_indices_before += !idx.is_none(); | ||||
|       } else { | ||||
|         break; | ||||
|       } | ||||
|     } | ||||
|     for (int j = entries.size() - 1; j > i; j--) { | ||||
|       auto idx = entries[j]; | ||||
|       if (!is_valid_index_type(idx)) { | ||||
|         throw std::invalid_argument( | ||||
|             "Cannot index mlx array using the given type yet"); | ||||
|       } | ||||
|       if (py::ellipsis().is(idx)) { | ||||
|         throw std::invalid_argument( | ||||
|             "An index can only have a single ellipsis (...)"); | ||||
|       } | ||||
|       r_indices.push_back(idx); | ||||
|       non_none_indices_after += !idx.is_none(); | ||||
|     } | ||||
|     for (int axis = non_none_indices_before; | ||||
|          axis < src.ndim() - non_none_indices_after; | ||||
|          axis++) { | ||||
|       indices.push_back(py::slice(0, src.shape(axis), 1)); | ||||
|     } | ||||
|     indices.insert(indices.end(), r_indices.rbegin(), r_indices.rend()); | ||||
|   } | ||||
|  | ||||
|   // Check for the number of indices passed | ||||
|   { | ||||
|     int cnt = src.ndim(); | ||||
|     for (auto& idx : indices) { | ||||
|       if (!idx.is_none()) { | ||||
|         cnt--; | ||||
|       } | ||||
|     } | ||||
|     if (cnt < 0) { | ||||
|       std::ostringstream msg; | ||||
|       msg << "Too many indices for array with " << src.ndim() << " dimensions."; | ||||
|       throw std::invalid_argument(msg.str()); | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   // Gather handling | ||||
|   // | ||||
|   // Check whether we have arrays or integer indices and delegate to gather_nd | ||||
|   // after removing the slices at the end and all Nones. | ||||
|   std::vector<py::object> remaining_indices; | ||||
|   bool have_array = false; | ||||
|   { | ||||
|     // First check whether the results of gather are going to be 1st or | ||||
|     // normally in between. | ||||
|     bool have_non_array = false; | ||||
|     bool gather_first = false; | ||||
|     for (auto& idx : indices) { | ||||
|       if (py::isinstance<array>(idx) || py::isinstance<py::int_>(idx)) { | ||||
|         if (have_array && have_non_array) { | ||||
|           gather_first = true; | ||||
|           break; | ||||
|         } | ||||
|         have_array = true; | ||||
|       } else { | ||||
|         have_non_array |= have_array; | ||||
|       } | ||||
|     } | ||||
|  | ||||
|     if (have_array) { | ||||
|       int last_array; | ||||
|       // Then find the last array | ||||
|       for (last_array = indices.size() - 1; last_array >= 0; last_array--) { | ||||
|         auto& idx = indices[last_array]; | ||||
|         if (py::isinstance<array>(idx) || py::isinstance<py::int_>(idx)) { | ||||
|           break; | ||||
|         } | ||||
|       } | ||||
|  | ||||
|       std::vector<py::object> gather_indices; | ||||
|       for (int i = 0; i <= last_array; i++) { | ||||
|         auto& idx = indices[i]; | ||||
|         if (!idx.is_none()) { | ||||
|           gather_indices.push_back(idx); | ||||
|         } | ||||
|       } | ||||
|       int max_dims; | ||||
|       src = mlx_gather_nd(src, gather_indices, gather_first, max_dims); | ||||
|  | ||||
|       // Reassemble the indices for the slicing or reshaping if there are any | ||||
|       if (gather_first) { | ||||
|         for (int i = 0; i < max_dims; i++) { | ||||
|           remaining_indices.push_back( | ||||
|               py::slice(py::none(), py::none(), py::none())); | ||||
|         } | ||||
|         for (int i = 0; i < last_array; i++) { | ||||
|           auto& idx = indices[i]; | ||||
|           if (idx.is_none()) { | ||||
|             remaining_indices.push_back(indices[i]); | ||||
|           } else if (py::isinstance<py::slice>(idx)) { | ||||
|             remaining_indices.push_back( | ||||
|                 py::slice(py::none(), py::none(), py::none())); | ||||
|           } | ||||
|         } | ||||
|         for (int i = last_array + 1; i < indices.size(); i++) { | ||||
|           remaining_indices.push_back(indices[i]); | ||||
|         } | ||||
|       } else { | ||||
|         for (int i = 0; i < indices.size(); i++) { | ||||
|           auto& idx = indices[i]; | ||||
|           if (py::isinstance<array>(idx) || py::isinstance<py::int_>(idx)) { | ||||
|             break; | ||||
|           } else if (idx.is_none()) { | ||||
|             remaining_indices.push_back(idx); | ||||
|           } else { | ||||
|             remaining_indices.push_back( | ||||
|                 py::slice(py::none(), py::none(), py::none())); | ||||
|           } | ||||
|         } | ||||
|         for (int i = 0; i < max_dims; i++) { | ||||
|           remaining_indices.push_back( | ||||
|               py::slice(py::none(), py::none(), py::none())); | ||||
|         } | ||||
|         for (int i = last_array + 1; i < indices.size(); i++) { | ||||
|           remaining_indices.push_back(indices[i]); | ||||
|         } | ||||
|       } | ||||
|     } | ||||
|   } | ||||
|   if (have_array && remaining_indices.empty()) { | ||||
|     return src; | ||||
|   } | ||||
|   if (remaining_indices.empty()) { | ||||
|     remaining_indices = indices; | ||||
|   } | ||||
|  | ||||
|   // Slice handling | ||||
|   { | ||||
|     std::vector<int> starts(src.ndim(), 0); | ||||
|     std::vector<int> ends = src.shape(); | ||||
|     std::vector<int> strides(src.ndim(), 1); | ||||
|     int axis = 0; | ||||
|     for (auto& idx : remaining_indices) { | ||||
|       if (!idx.is_none()) { | ||||
|         get_slice_params( | ||||
|             starts[axis], ends[axis], strides[axis], idx, ends[axis]); | ||||
|         axis++; | ||||
|       } | ||||
|     } | ||||
|     src = slice(src, starts, ends, strides); | ||||
|   } | ||||
|  | ||||
|   // Unsqueeze handling | ||||
|   if (remaining_indices.size() > src.ndim()) { | ||||
|     std::vector<int> out_shape; | ||||
|     int axis = 0; | ||||
|     for (auto& idx : remaining_indices) { | ||||
|       if (idx.is_none()) { | ||||
|         out_shape.push_back(1); | ||||
|       } else { | ||||
|         out_shape.push_back(src.shape(axis++)); | ||||
|       } | ||||
|     } | ||||
|     src = reshape(src, out_shape); | ||||
|   } | ||||
|  | ||||
|   return src; | ||||
| } | ||||
|  | ||||
| array mlx_get_item(const array& src, const py::object& obj) { | ||||
|   if (py::isinstance<py::slice>(obj)) { | ||||
|     return mlx_get_item_slice(src, obj); | ||||
|   } else if (py::isinstance<array>(obj)) { | ||||
|     return mlx_get_item_array(src, py::cast<array>(obj)); | ||||
|   } else if (py::isinstance<py::int_>(obj)) { | ||||
|     return mlx_get_item_int(src, obj); | ||||
|   } else if (py::isinstance<py::tuple>(obj)) { | ||||
|     return mlx_get_item_nd(src, obj); | ||||
|   } else if (obj.is_none()) { | ||||
|     std::vector<int> s(1, 1); | ||||
|     s.insert(s.end(), src.shape().begin(), src.shape().end()); | ||||
|     return reshape(src, s); | ||||
|   } | ||||
|   throw std::invalid_argument("Cannot index mlx array using the given type."); | ||||
| } | ||||
|  | ||||
| array mlx_set_item_int( | ||||
|     const array& src, | ||||
|     const py::int_& idx, | ||||
|     const array& update) { | ||||
|   if (src.ndim() == 0) { | ||||
|     throw std::invalid_argument( | ||||
|         "too many indices for array: array is 0-dimensional"); | ||||
|   } | ||||
|  | ||||
|   // Remove any leading singleton dimensions from the update | ||||
|   // and then broadcast update to shape of src[0, ...] | ||||
|   int s = 0; | ||||
|   for (; s < update.ndim() && update.shape(s) == 1; s++) | ||||
|     ; | ||||
|   auto up_shape = | ||||
|       std::vector<int>(update.shape().begin() + s, update.shape().end()); | ||||
|   auto shape = src.shape(); | ||||
|   shape[0] = 1; | ||||
|   return scatter( | ||||
|       src, | ||||
|       get_int_index(idx, src.shape(0)), | ||||
|       broadcast_to(reshape(update, up_shape), shape), | ||||
|       0); | ||||
| } | ||||
|  | ||||
| array mlx_set_item_array( | ||||
|     const array& src, | ||||
|     const array& indices, | ||||
|     const array& update) { | ||||
|   if (src.ndim() == 0) { | ||||
|     throw std::invalid_argument( | ||||
|         "too many indices for array: array is 0-dimensional"); | ||||
|   } | ||||
|  | ||||
|   // Remove any leading singleton dimensions from the update | ||||
|   int s = 0; | ||||
|   for (; s < update.ndim() && update.shape(s) == 1; s++) | ||||
|     ; | ||||
|   auto up_shape = | ||||
|       std::vector<int>(update.shape().begin() + s, update.shape().end()); | ||||
|   auto up = reshape(update, up_shape); | ||||
|  | ||||
|   // The update shape must broadcast with indices.shape + [1] + src.shape[1:] | ||||
|   up_shape = indices.shape(); | ||||
|   up_shape.insert(up_shape.end(), src.shape().begin() + 1, src.shape().end()); | ||||
|   up = broadcast_to(up, up_shape); | ||||
|   up_shape.insert(up_shape.begin() + indices.ndim(), 1); | ||||
|   up = reshape(up, up_shape); | ||||
|  | ||||
|   return scatter(src, indices, up, 0); | ||||
| } | ||||
|  | ||||
| array mlx_set_item_slice( | ||||
|     const array& src, | ||||
|     const py::slice& in_slice, | ||||
|     const array& update) { | ||||
|   // Check input and raise error if 0 dim for parity with np | ||||
|   if (src.ndim() == 0) { | ||||
|     throw std::invalid_argument( | ||||
|         "too many indices for array: array is 0-dimensional"); | ||||
|   } | ||||
|  | ||||
|   // If none slice is requested broadcast the update | ||||
|   // to the src size and return it. | ||||
|   if (is_none_slice(in_slice)) { | ||||
|     int s = 0; | ||||
|     for (; s < update.ndim() && update.shape(s) == 1; s++) | ||||
|       ; | ||||
|     auto up_shape = | ||||
|         std::vector<int>(update.shape().begin() + s, update.shape().end()); | ||||
|     return broadcast_to(reshape(update, up_shape), src.shape()); | ||||
|   } | ||||
|  | ||||
|   int start = 0; | ||||
|   int end = src.shape(0); | ||||
|   int stride = 1; | ||||
|  | ||||
|   // Check and update slice params | ||||
|   get_slice_params(start, end, stride, in_slice, end); | ||||
|  | ||||
|   return mlx_set_item_array(src, arange(start, end, stride, uint32), update); | ||||
| } | ||||
|  | ||||
| array mlx_set_item_nd( | ||||
|     const array& src, | ||||
|     const py::tuple& entries, | ||||
|     const array& update) { | ||||
|   std::vector<py::object> indices; | ||||
|   int non_none_indices = 0; | ||||
|  | ||||
|   // Expand ellipses into a series of ':' slices | ||||
|   { | ||||
|     int non_none_indices_before = 0; | ||||
|     int non_none_indices_after = 0; | ||||
|     bool has_ellipsis = false; | ||||
|     int indices_before = 0; | ||||
|     for (int i = 0; i < entries.size(); ++i) { | ||||
|       auto idx = entries[i]; | ||||
|       if (!is_valid_index_type(idx)) { | ||||
|         throw std::invalid_argument( | ||||
|             "Cannot index mlx array using the given type yet"); | ||||
|       } else if (!py::ellipsis().is(idx)) { | ||||
|         if (!has_ellipsis) { | ||||
|           indices_before++; | ||||
|           non_none_indices_before += !idx.is_none(); | ||||
|         } else { | ||||
|           non_none_indices_after += !idx.is_none(); | ||||
|         } | ||||
|         indices.push_back(idx); | ||||
|       } else if (has_ellipsis) { | ||||
|         throw std::invalid_argument( | ||||
|             "An index can only have a single ellipsis (...)"); | ||||
|       } else { | ||||
|         has_ellipsis = true; | ||||
|       } | ||||
|     } | ||||
|     if (has_ellipsis) { | ||||
|       for (int axis = non_none_indices_before; | ||||
|            axis < src.ndim() - non_none_indices_after; | ||||
|            axis++) { | ||||
|         indices.insert( | ||||
|             indices.begin() + indices_before, py::slice(0, src.shape(axis), 1)); | ||||
|       } | ||||
|       non_none_indices = src.ndim(); | ||||
|     } else { | ||||
|       non_none_indices = non_none_indices_before + non_none_indices_after; | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   if (non_none_indices > src.ndim()) { | ||||
|     std::ostringstream msg; | ||||
|     msg << "Too many indices for array with " << src.ndim() << " dimensions."; | ||||
|     throw std::invalid_argument(msg.str()); | ||||
|   } | ||||
|  | ||||
|   // Remove leading singleton dimensions from the update | ||||
|   int s = 0; | ||||
|   for (; s < update.ndim() && update.shape(s) == 1; s++) { | ||||
|   } | ||||
|   auto up_shape = | ||||
|       std::vector<int>(update.shape().begin() + s, update.shape().end()); | ||||
|   auto up = reshape(update, up_shape); | ||||
|  | ||||
|   // If no non-None indices return the broadcasted update | ||||
|   if (non_none_indices == 0) { | ||||
|     return broadcast_to(up, src.shape()); | ||||
|   } | ||||
|  | ||||
|   unsigned long max_dim = 0; | ||||
|   bool arrays_first = false; | ||||
|   int num_slices = 0; | ||||
|   int num_arrays = 0; | ||||
|   { | ||||
|     bool have_array = false; | ||||
|     bool have_non_array = false; | ||||
|     for (auto& idx : indices) { | ||||
|       if (py::isinstance<py::slice>(idx) || idx.is_none()) { | ||||
|         have_non_array = have_array; | ||||
|         num_slices++; | ||||
|       } else if (py::isinstance<array>(idx)) { | ||||
|         have_array = true; | ||||
|         if (have_array && have_non_array) { | ||||
|           arrays_first = true; | ||||
|         } | ||||
|         max_dim = std::max(py::cast<array>(idx).ndim(), max_dim); | ||||
|         num_arrays++; | ||||
|       } | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   std::vector<array> arr_indices; | ||||
|   int slice_num = 0; | ||||
|   int array_num = 0; | ||||
|   int ax = 0; | ||||
|   for (int i = 0; i < indices.size(); ++i) { | ||||
|     auto& pyidx = indices[i]; | ||||
|     if (py::isinstance<py::slice>(pyidx)) { | ||||
|       int start, end, stride; | ||||
|       get_slice_params(start, end, stride, pyidx, src.shape(ax++)); | ||||
|       auto idx = arange(start, end, stride, uint32); | ||||
|       std::vector<int> idx_shape(max_dim + num_slices, 1); | ||||
|       auto loc = slice_num + (arrays_first ? max_dim : 0); | ||||
|       slice_num++; | ||||
|       idx_shape[loc] = idx.size(); | ||||
|       arr_indices.push_back(reshape(idx, idx_shape)); | ||||
|     } else if (py::isinstance<py::int_>(pyidx)) { | ||||
|       arr_indices.push_back(get_int_index(pyidx, src.shape(ax++))); | ||||
|     } else if (pyidx.is_none()) { | ||||
|       slice_num++; | ||||
|     } else if (py::isinstance<array>(pyidx)) { | ||||
|       ax++; | ||||
|       auto idx = py::cast<array>(pyidx); | ||||
|       std::vector<int> idx_shape; | ||||
|       if (!arrays_first) { | ||||
|         idx_shape.insert(idx_shape.end(), slice_num, 1); | ||||
|       } | ||||
|       idx_shape.insert(idx_shape.end(), max_dim - idx.ndim(), 1); | ||||
|       idx_shape.insert(idx_shape.end(), idx.shape().begin(), idx.shape().end()); | ||||
|       idx_shape.insert( | ||||
|           idx_shape.end(), num_slices - (arrays_first ? 0 : slice_num), 1); | ||||
|       arr_indices.push_back(reshape(idx, idx_shape)); | ||||
|       if (!arrays_first && ++array_num == num_arrays) { | ||||
|         slice_num += max_dim; | ||||
|       } | ||||
|     } else { | ||||
|       throw std::invalid_argument( | ||||
|           "Cannot index mlx array using the given type yet"); | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   arr_indices = broadcast_arrays(arr_indices); | ||||
|   up_shape = arr_indices[0].shape(); | ||||
|   up_shape.insert( | ||||
|       up_shape.end(), | ||||
|       src.shape().begin() + non_none_indices, | ||||
|       src.shape().end()); | ||||
|   up = broadcast_to(up, up_shape); | ||||
|   up_shape.insert( | ||||
|       up_shape.begin() + arr_indices[0].ndim(), non_none_indices, 1); | ||||
|   up = reshape(up, up_shape); | ||||
|  | ||||
|   std::vector<int> axes(arr_indices.size(), 0); | ||||
|   std::iota(axes.begin(), axes.end(), 0); | ||||
|   return scatter(src, arr_indices, up, axes); | ||||
| } | ||||
|  | ||||
| void mlx_set_item(array& src, const py::object& obj, const ScalarOrArray& v) { | ||||
|   auto vals = to_array(v, src.dtype()); | ||||
|   auto impl = [&src, &obj, &vals]() { | ||||
|     if (py::isinstance<py::slice>(obj)) { | ||||
|       return mlx_set_item_slice(src, obj, vals); | ||||
|     } else if (py::isinstance<array>(obj)) { | ||||
|       return mlx_set_item_array(src, py::cast<array>(obj), vals); | ||||
|     } else if (py::isinstance<py::int_>(obj)) { | ||||
|       return mlx_set_item_int(src, obj, vals); | ||||
|     } else if (py::isinstance<py::tuple>(obj)) { | ||||
|       return mlx_set_item_nd(src, obj, vals); | ||||
|     } else if (obj.is_none()) { | ||||
|       return broadcast_to(vals, src.shape()); | ||||
|     } | ||||
|     throw std::invalid_argument("Cannot index mlx array using the given type."); | ||||
|   }; | ||||
|   auto out = impl(); | ||||
|   src.overwrite_descriptor(out); | ||||
| } | ||||
							
								
								
									
12  python/src/indexing.h  Normal file
							| @@ -0,0 +1,12 @@ | ||||
| #pragma once | ||||
|  | ||||
| #include <pybind11/pybind11.h> | ||||
|  | ||||
| #include "mlx/array.h" | ||||
| #include "python/src/utils.h" | ||||
|  | ||||
| namespace py = pybind11; | ||||
| using namespace mlx::core; | ||||
|  | ||||
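| // These entry points back array.__getitem__ and array.__setitem__; the | ||||
| // bindings that call them live elsewhere in python/src. | ||||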
| array mlx_get_item(const array& src, const py::object& obj); | ||||
| void mlx_set_item(array& src, const py::object& obj, const ScalarOrArray& v); | ||||
							
								
								
									
290  python/src/load.cpp  Normal file
							| @@ -0,0 +1,290 @@ | ||||
| #include <pybind11/pybind11.h> | ||||
| #include <pybind11/stl.h> | ||||
|  | ||||
| #include <cstring> | ||||
| #include <fstream> | ||||
| #include <stdexcept> | ||||
| #include <string> | ||||
| #include <string_view> | ||||
| #include <unordered_map> | ||||
| #include <vector> | ||||
|  | ||||
| #include <iostream> | ||||
|  | ||||
| #include "mlx/load.h" | ||||
| #include "mlx/ops.h" | ||||
| #include "mlx/utils.h" | ||||
| #include "python/src/load.h" | ||||
| #include "python/src/utils.h" | ||||
|  | ||||
| namespace py = pybind11; | ||||
| using namespace py::literals; | ||||
| using namespace mlx::core; | ||||
|  | ||||
| /////////////////////////////////////////////////////////////////////////////// | ||||
| // Helpers | ||||
| /////////////////////////////////////////////////////////////////////////////// | ||||
|  | ||||
| bool is_istream_object(const py::object& file) { | ||||
|   return py::hasattr(file, "read") && py::hasattr(file, "seek") && | ||||
|       py::hasattr(file, "tell") && py::hasattr(file, "closed"); | ||||
| } | ||||
|  | ||||
| bool is_ostream_object(const py::object& file) { | ||||
|   return py::hasattr(file, "write") && py::hasattr(file, "seek") && | ||||
|       py::hasattr(file, "tell") && py::hasattr(file, "closed"); | ||||
| } | ||||
|  | ||||
| bool is_zip_file(const py::module_& zipfile, const py::object& file) { | ||||
|   if (is_istream_object(file)) { | ||||
|     auto st_pos = file.attr("tell")(); | ||||
|     bool r = (zipfile.attr("is_zipfile")(file)).cast<bool>(); | ||||
|     file.attr("seek")(st_pos, 0); | ||||
|     return r; | ||||
|   } | ||||
|   return zipfile.attr("is_zipfile")(file).cast<bool>(); | ||||
| } | ||||
|  | ||||
| class ZipFileWrapper { | ||||
|  public: | ||||
|   ZipFileWrapper( | ||||
|       const py::module_& zipfile, | ||||
|       const py::object& file, | ||||
|       char mode = 'r', | ||||
|       int compression = 0) | ||||
|       : zipfile_module_(zipfile), | ||||
|         zipfile_object_(zipfile.attr("ZipFile")( | ||||
|             file, | ||||
|             "mode"_a = mode, | ||||
|             "compression"_a = compression, | ||||
|             "allowZip64"_a = true)), | ||||
|         files_list_(zipfile_object_.attr("namelist")()), | ||||
|         open_func_(zipfile_object_.attr("open")), | ||||
|         read_func_(zipfile_object_.attr("read")), | ||||
|         close_func_(zipfile_object_.attr("close")) {} | ||||
|  | ||||
|   std::vector<std::string> namelist() const { | ||||
|     return files_list_.cast<std::vector<std::string>>(); | ||||
|   } | ||||
|  | ||||
|   py::object open(const std::string& key, char mode = 'r') { | ||||
|     // Following numpy : | ||||
|     // https://github.com/numpy/numpy/blob/db4f43983cb938f12c311e1f5b7165e270c393b4/numpy/lib/npyio.py#L742C36-L742C47 | ||||
|     if (mode == 'w') { | ||||
|       return open_func_(key, "mode"_a = mode, "force_zip64"_a = true); | ||||
|     } | ||||
|     return open_func_(key, "mode"_a = mode); | ||||
|   } | ||||
|  | ||||
|  private: | ||||
|   py::module_ zipfile_module_; | ||||
|   py::object zipfile_object_; | ||||
|   py::list files_list_; | ||||
|   py::object open_func_; | ||||
|   py::object read_func_; | ||||
|   py::object close_func_; | ||||
| }; | ||||
|  | ||||
| /////////////////////////////////////////////////////////////////////////////// | ||||
| // Loading | ||||
| /////////////////////////////////////////////////////////////////////////////// | ||||
|  | ||||
| class PyFileReader : public io::Reader { | ||||
|  public: | ||||
|   PyFileReader(py::object file) | ||||
|       : pyistream_(file), | ||||
|         readinto_func_(file.attr("readinto")), | ||||
|         seek_func_(file.attr("seek")), | ||||
|         tell_func_(file.attr("tell")) {} | ||||
|  | ||||
|   bool is_open() const override { | ||||
|     return !pyistream_.attr("closed").cast<bool>(); | ||||
|   } | ||||
|  | ||||
|   bool good() const override { | ||||
|     return !pyistream_.is_none(); | ||||
|   } | ||||
|  | ||||
|   size_t tell() const override { | ||||
|     return tell_func_().cast<size_t>(); | ||||
|   } | ||||
|  | ||||
|   void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg) | ||||
|       override { | ||||
|     seek_func_(off, (int)way); | ||||
|   } | ||||
|  | ||||
|   void read(char* data, size_t n) override { | ||||
|     py::object bytes_read = | ||||
|         readinto_func_(py::memoryview::from_buffer(data, {n}, {sizeof(char)})); | ||||
|     if (bytes_read.is_none() || py::cast<size_t>(bytes_read) < n) { | ||||
|       throw std::runtime_error("[load] Failed to read from python stream"); | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   std::string label() const override { | ||||
|     return "python file object"; | ||||
|   } | ||||
|  | ||||
|  private: | ||||
|   py::object pyistream_; | ||||
|   py::object readinto_func_; | ||||
|   py::object seek_func_; | ||||
|   py::object tell_func_; | ||||
| }; | ||||
|  | ||||
| DictOrArray mlx_load_helper(py::object file, StreamOrDevice s) { | ||||
|   py::module_ zipfile = py::module_::import("zipfile"); | ||||
|  | ||||
|   // Assume .npz file if it is zipped | ||||
|   if (is_zip_file(zipfile, file)) { | ||||
|     // Output dictionary filename in zip -> loaded array | ||||
|     std::unordered_map<std::string, array> array_dict; | ||||
|  | ||||
|     // Create python ZipFile object | ||||
|     ZipFileWrapper zipfile_object(zipfile, file); | ||||
|     for (const std::string& st : zipfile_object.namelist()) { | ||||
|       // Open zip file as a python file stream | ||||
|       py::object sub_file = zipfile_object.open(st); | ||||
|  | ||||
|       // Create array from python file stream | ||||
|       auto arr = load(std::make_shared<PyFileReader>(sub_file), s); | ||||
|  | ||||
|       // Remove .npy from file if it is there | ||||
|       auto key = st; | ||||
|       if (st.length() > 4 && st.substr(st.length() - 4, 4) == ".npy") | ||||
|         key = st.substr(0, st.length() - 4); | ||||
|  | ||||
|       // Add array to dict | ||||
|       array_dict.insert({key, arr}); | ||||
|     } | ||||
|  | ||||
|     // Eval immediately: the lazy loads reference zip member streams that | ||||
|     // will not outlive this call | ||||
|     for (auto& [key, arr] : array_dict) { | ||||
|       arr.eval(); | ||||
|     } | ||||
|  | ||||
|     return {array_dict}; | ||||
|   } else if (py::isinstance<py::str>(file)) { // Assume .npy file path string | ||||
|     return {load(py::cast<std::string>(file), s)}; | ||||
|   } else if (is_istream_object(file)) { | ||||
|     // If we don't own the stream and it was passed to us, eval immediately | ||||
|     auto arr = load(std::make_shared<PyFileReader>(file), s); | ||||
|     arr.eval(); | ||||
|     return {arr}; | ||||
|   } | ||||
|  | ||||
|   throw std::invalid_argument( | ||||
|       "[load] Input must be a file-like object, string, or pathlib.Path"); | ||||
| } | ||||
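|  | ||||
| // Python-side usage sketch, assuming this helper backs mlx.core.load: | ||||
| //   mx.load("arrs.npz")  # dict mapping names (".npy" stripped) to arrays | ||||
| //   mx.load("a.npy")     # a single array | ||||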
|  | ||||
| /////////////////////////////////////////////////////////////////////////////// | ||||
| // Saving | ||||
| /////////////////////////////////////////////////////////////////////////////// | ||||
|  | ||||
| class PyFileWriter : public io::Writer { | ||||
|  public: | ||||
|   PyFileWriter(py::object file) | ||||
|       : pyostream_(file), | ||||
|         write_func_(file.attr("write")), | ||||
|         seek_func_(file.attr("seek")), | ||||
|         tell_func_(file.attr("tell")) {} | ||||
|  | ||||
|   bool is_open() const override { | ||||
|     return !pyostream_.attr("closed").cast<bool>(); | ||||
|   } | ||||
|  | ||||
|   bool good() const override { | ||||
|     return !pyostream_.is_none(); | ||||
|   } | ||||
|  | ||||
|   size_t tell() const override { | ||||
|     return tell_func_().cast<size_t>(); | ||||
|   } | ||||
|  | ||||
|   void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg) | ||||
|       override { | ||||
|     seek_func_(off, (int)way); | ||||
|   } | ||||
|  | ||||
|   void write(const char* data, size_t n) override { | ||||
|     py::object bytes_written = | ||||
|         write_func_(py::memoryview::from_buffer(data, {n}, {sizeof(char)})); | ||||
|     if (bytes_written.is_none() || py::cast<size_t>(bytes_written) < n) { | ||||
|       throw std::runtime_error("[load] Failed to write to python stream"); | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   std::string label() const override { | ||||
|     return "python file object"; | ||||
|   } | ||||
|  | ||||
|  private: | ||||
|   py::object pyostream_; | ||||
|   py::object write_func_; | ||||
|   py::object seek_func_; | ||||
|   py::object tell_func_; | ||||
| }; | ||||
|  | ||||
| void mlx_save_helper(py::object file, array a, bool retain_graph) { | ||||
|   if (py::isinstance<py::str>(file)) { | ||||
|     save(py::cast<std::string>(file), a, retain_graph); | ||||
|     return; | ||||
|   } else if (is_ostream_object(file)) { | ||||
|     save(std::make_shared<PyFileWriter>(file), a, retain_graph); | ||||
|     return; | ||||
|   } | ||||
|  | ||||
|   throw std::invalid_argument( | ||||
|       "[save] Input must be a file-like object, string, or pathlib.Path"); | ||||
| } | ||||
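|  | ||||
| // Python-side usage sketch, assuming this helper backs mlx.core.save: | ||||
| //   mx.save("a.npy", mx.ones((2, 2)))  # or pass an open binary file object | ||||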
|  | ||||
| void mlx_savez_helper( | ||||
|     py::object file_, | ||||
|     py::args args, | ||||
|     const py::kwargs& kwargs, | ||||
|     bool compressed) { | ||||
|   // Add .npz to the end of the filename if not already there | ||||
|   py::object file = file_; | ||||
|  | ||||
|   if (py::isinstance<py::str>(file_)) { | ||||
|     std::string fname = file_.cast<std::string>(); | ||||
|  | ||||
|     // Add .npz to file name if it is not there | ||||
|     if (fname.length() < 4 || fname.substr(fname.length() - 4, 4) != ".npz") | ||||
|       fname += ".npz"; | ||||
|  | ||||
|     file = py::str(fname); | ||||
|   } | ||||
|  | ||||
|   // Collect args and kwargs | ||||
|   auto arrays_dict = kwargs.cast<std::unordered_map<std::string, array>>(); | ||||
|   auto arrays_list = args.cast<std::vector<array>>(); | ||||
|  | ||||
|   for (size_t i = 0; i < arrays_list.size(); i++) { | ||||
|     std::string arr_name = "arr_" + std::to_string(i); | ||||
|  | ||||
|     if (arrays_dict.count(arr_name) > 0) { | ||||
|       throw std::invalid_argument( | ||||
|           "[savez] Cannot use un-named variables and keyword " + arr_name); | ||||
|     } | ||||
|  | ||||
|     arrays_dict.insert({arr_name, arrays_list[i]}); | ||||
|   } | ||||
|  | ||||
|   // Create python ZipFile object depending on compression | ||||
|   py::module_ zipfile = py::module_::import("zipfile"); | ||||
|   int compression = compressed ? zipfile.attr("ZIP_DEFLATED").cast<int>() | ||||
|                                : zipfile.attr("ZIP_STORED").cast<int>(); | ||||
|   char mode = 'w'; | ||||
|   ZipFileWrapper zipfile_object(zipfile, file, mode, compression); | ||||
|  | ||||
|   // Save each array | ||||
|   for (auto [k, a] : arrays_dict) { | ||||
|     std::string fname = k + ".npy"; | ||||
|     auto py_ostream = zipfile_object.open(fname, 'w'); | ||||
|     save(std::make_shared<PyFileWriter>(py_ostream), a); | ||||
|   } | ||||
|  | ||||
|   return; | ||||
| } | ||||
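|  | ||||
| // Python-side usage sketch, assuming this helper backs mlx.core.savez: | ||||
| //   mx.savez("arrs", mx.ones(2), y=mx.zeros(3)) | ||||
| //   # writes arrs.npz with entries "arr_0.npy" and "y.npy" | ||||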
2422  python/src/ops.cpp  (Normal file; file diff suppressed because it is too large)

289  python/src/random.cpp  (Normal file)
							| @@ -0,0 +1,289 @@ | ||||
| #include <pybind11/pybind11.h> | ||||
| #include <pybind11/stl.h> | ||||
|  | ||||
| #include "python/src/utils.h" | ||||
|  | ||||
| #include "mlx/ops.h" | ||||
| #include "mlx/random.h" | ||||
|  | ||||
| namespace py = pybind11; | ||||
| using namespace py::literals; | ||||
| using namespace mlx::core; | ||||
| using namespace mlx::core::random; | ||||
|  | ||||
| void init_random(py::module_& parent_module) { | ||||
|   auto m = parent_module.def_submodule( | ||||
|       "random", | ||||
|       "mlx.core.random: functionality related to random number generation"); | ||||
|   m.def( | ||||
|       "seed", | ||||
|       &seed, | ||||
|       "seed"_a, | ||||
|       R"pbdoc( | ||||
|         Seed the global PRNG. | ||||
|  | ||||
|         Args: | ||||
|             seed (int): Seed for the global PRNG. | ||||
|       )pbdoc"); | ||||
|   m.def( | ||||
|       "key", | ||||
|       &key, | ||||
|       "seed"_a, | ||||
|       R"pbdoc( | ||||
|         Get a PRNG key from a seed. | ||||
|  | ||||
|         Args: | ||||
|             seed (int): Seed for the PRNG. | ||||
|  | ||||
|         Returns: | ||||
|             array: The PRNG key array. | ||||
|       )pbdoc"); | ||||
|   m.def( | ||||
|       "split", | ||||
|       py::overload_cast<const array&, int, StreamOrDevice>(&random::split), | ||||
|       "key"_a, | ||||
|       "num"_a = 2, | ||||
|       "stream"_a = none, | ||||
|       R"pbdoc( | ||||
|         Split a PRNG key into sub keys. | ||||
|  | ||||
|         Args: | ||||
|             key (array): Input key to split. | ||||
|             num (int, optional): Number of sub keys. Default is 2. | ||||
|  | ||||
|         Returns: | ||||
|             array: The array of sub keys with ``num`` as its first dimension. | ||||
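|  | ||||
|         Example: | ||||
|             >>> keys = mx.random.split(mx.random.key(0)) | ||||
|             >>> x = mx.random.uniform(key=keys[0]) | ||||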
|       )pbdoc"); | ||||
|   m.def( | ||||
|       "uniform", | ||||
|       [](const ScalarOrArray& low, | ||||
|          const ScalarOrArray& high, | ||||
|          const std::vector<int>& shape, | ||||
|          Dtype type, | ||||
|          const std::optional<array>& key, | ||||
|          StreamOrDevice s) { | ||||
|         return uniform(to_array(low), to_array(high), shape, type, key, s); | ||||
|       }, | ||||
|       "low"_a = 0, | ||||
|       "high"_a = 1, | ||||
|       "shape"_a = std::vector<int>{}, | ||||
|       "dtype"_a = float32, | ||||
|       "key"_a = none, | ||||
|       "stream"_a = none, | ||||
|       R"pbdoc( | ||||
|         Generate uniformly distributed random numbers. | ||||
|  | ||||
|         The values are sampled uniformly in the half-open interval ``[low, high)``. | ||||
|         The lower and upper bound can be scalars or arrays and must be | ||||
|         broadcastable to ``shape``. | ||||
|  | ||||
|         Args: | ||||
|             low (scalar or array, optional): Lower bound of the distribution. Default is ``0``. | ||||
|             high (scalar or array, optional): Upper bound of the distribution. Default is ``1``. | ||||
|             shape (list(int), optional): Shape of the output. Default is ``()``. | ||||
|             dtype (Dtype, optional): Type of the output. Default is ``float32``. | ||||
|             key (array, optional): A PRNG key. Default: ``None``. | ||||
|  | ||||
|         Returns: | ||||
|             array: The output array of random values. | ||||
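|  | ||||
|         Example: | ||||
|             >>> x = mx.random.uniform(low=-1, high=1, shape=(2, 2)) | ||||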
|       )pbdoc"); | ||||
|   m.def( | ||||
|       "normal", | ||||
|       [](const std::vector<int>& shape, | ||||
|          Dtype type, | ||||
|          const std::optional<array>& key, | ||||
|          StreamOrDevice s) { return normal(shape, type, key, s); }, | ||||
|       "shape"_a = std::vector<int>{}, | ||||
|       "dtype"_a = float32, | ||||
|       "key"_a = none, | ||||
|       "stream"_a = none, | ||||
|       R"pbdoc( | ||||
|         Generate normally distributed random numbers. | ||||
|  | ||||
|         Args: | ||||
|             shape (list(int), optional): Shape of the output. Default is ``()``. | ||||
|             dtype (Dtype, optional): Type of the output. Default is ``float32``. | ||||
|             key (array, optional): A PRNG key. Default: None. | ||||
|  | ||||
|         Returns: | ||||
|             array: The output array of random values. | ||||
|       )pbdoc"); | ||||
|   m.def( | ||||
|       "randint", | ||||
|       [](const ScalarOrArray& low, | ||||
|          const ScalarOrArray& high, | ||||
|          const std::vector<int>& shape, | ||||
|          Dtype type, | ||||
|          const std::optional<array>& key, | ||||
|          StreamOrDevice s) { | ||||
|         return randint(to_array(low), to_array(high), shape, type, key, s); | ||||
|       }, | ||||
|       "low"_a, | ||||
|       "high"_a, | ||||
|       "shape"_a = std::vector<int>{}, | ||||
|       "dtype"_a = int32, | ||||
|       "key"_a = none, | ||||
|       "stream"_a = none, | ||||
|       R"pbdoc( | ||||
|         Generate random integers from the given interval. | ||||
|  | ||||
|         The values are sampled with equal probability from the integers in | ||||
|         the half-open interval ``[low, high)``. The lower and upper bound can | ||||
|         be scalars or arrays and must be broadcastable to ``shape``. | ||||
|  | ||||
|         Args: | ||||
|             low (scalar or array): Lower bound of the interval. | ||||
|             high (scalar or array): Upper bound of the interval. | ||||
|             shape (list(int), optional): Shape of the output. Defaults to ``()``. | ||||
|             dtype (Dtype, optional): Type of the output. Defaults to ``int32``. | ||||
|             key (array, optional): A PRNG key. Default: None. | ||||
|  | ||||
|         Returns: | ||||
|             array: The array of random integers. | ||||
|       )pbdoc"); | ||||
|   m.def( | ||||
|       "bernoulli", | ||||
|       [](const ScalarOrArray& p_, | ||||
|          const std::optional<std::vector<int>> shape, | ||||
|          const std::optional<array>& key, | ||||
|          StreamOrDevice s) { | ||||
|         auto p = to_array(p_); | ||||
|         if (shape.has_value()) { | ||||
|           return bernoulli(p, shape.value(), key, s); | ||||
|         } else { | ||||
|           return bernoulli(p, key, s); | ||||
|         } | ||||
|       }, | ||||
|       "p"_a = 0.5, | ||||
|       "shape"_a = none, | ||||
|       "key"_a = none, | ||||
|       "stream"_a = none, | ||||
|       R"pbdoc( | ||||
|         Generate Bernoulli random values. | ||||
|  | ||||
|         The values are sampled from the Bernoulli distribution with parameter | ||||
|         ``p``. The parameter ``p`` can be a :obj:`float` or :obj:`array` and | ||||
|         must be broadcastable to ``shape``. | ||||
|  | ||||
|         Args: | ||||
|             p (float or array, optional): Parameter of the Bernoulli | ||||
|               distribution. Default is 0.5. | ||||
|             shape (list(int), optional): Shape of the output. The default | ||||
|               shape is ``p.shape``. | ||||
|             key (array, optional): A PRNG key. Default: None. | ||||
|  | ||||
|         Returns: | ||||
|             array: The array of random Bernoulli values. | ||||
|       )pbdoc"); | ||||
|   m.def( | ||||
|       "truncated_normal", | ||||
|       [](const ScalarOrArray& lower_, | ||||
|          const ScalarOrArray& upper_, | ||||
|          const std::optional<std::vector<int>> shape_, | ||||
|          Dtype dtype, | ||||
|          const std::optional<array>& key, | ||||
|          StreamOrDevice s) { | ||||
|         auto lower = to_array(lower_); | ||||
|         auto upper = to_array(upper_); | ||||
|         if (shape_.has_value()) { | ||||
|           return truncated_normal(lower, upper, shape_.value(), dtype, key, s); | ||||
|         } else { | ||||
|           return truncated_normal(lower, upper, dtype, key, s); | ||||
|         } | ||||
|       }, | ||||
|       "lower"_a, | ||||
|       "upper"_a, | ||||
|       "shape"_a = none, | ||||
|       "dtype"_a = float32, | ||||
|       "key"_a = none, | ||||
|       "stream"_a = none, | ||||
|       R"pbdoc( | ||||
|         Generate values from a truncated normal distribution. | ||||
|  | ||||
|         The values are sampled from the truncated normal distribution | ||||
|         on the domain ``(lower, upper)``. The bounds ``lower`` and ``upper`` | ||||
|         can be scalars or arrays and must be broadcastable to ``shape``. | ||||
|  | ||||
|         Args: | ||||
|             lower (scalar or array): Lower bound of the domain. | ||||
|             upper (scalar or array): Upper bound of the domain. | ||||
|             shape (list(int), optional): The shape of the output. | ||||
|               Default is ``()``. | ||||
|             dtype (Dtype, optional): The data type of the output. | ||||
|               Default is ``float32``. | ||||
|             key (array, optional): A PRNG key. Default: None. | ||||
|  | ||||
|         Returns: | ||||
|             array: The output array of random values. | ||||
|       )pbdoc"); | ||||
|   m.def( | ||||
|       "gumbel", | ||||
|       &gumbel, | ||||
|       "shape"_a = std::vector<int>{}, | ||||
|       "dtype"_a = float32, | ||||
|       "stream"_a = none, | ||||
|       "key"_a = none, | ||||
|       R"pbdoc( | ||||
|         Sample from the standard Gumbel distribution. | ||||
|  | ||||
|         The values are sampled from a standard Gumbel distribution | ||||
|         whose CDF is ``exp(-exp(-x))``. | ||||
|  | ||||
|         Args: | ||||
|             shape (list(int)): The shape of the output. | ||||
|             dtype (Dtype, optional): The data type of the output. Default is ``float32``. | ||||
|             key (array, optional): A PRNG key. Default: ``None``. | ||||
|  | ||||
|         Returns: | ||||
|             array: The :class:`array` with shape ``shape`` and values | ||||
|                    distributed according to the Gumbel distribution. | ||||
|       )pbdoc"); | ||||
|   m.def( | ||||
|       "categorical", | ||||
|       [](const array& logits, | ||||
|          int axis, | ||||
|          const std::optional<std::vector<int>> shape, | ||||
|          const std::optional<int> num_samples, | ||||
|          const std::optional<array>& key, | ||||
|          StreamOrDevice s) { | ||||
|         if (shape.has_value() && num_samples.has_value()) { | ||||
|           throw std::invalid_argument( | ||||
|               "[categorical] At most one of shape or num_samples can be specified."); | ||||
|         } else if (shape.has_value()) { | ||||
|           return categorical(logits, axis, shape.value(), key, s); | ||||
|         } else if (num_samples.has_value()) { | ||||
|           return categorical(logits, axis, num_samples.value(), key, s); | ||||
|         } else { | ||||
|           return categorical(logits, axis, key, s); | ||||
|         } | ||||
|       }, | ||||
|       "logits"_a, | ||||
|       "axis"_a = -1, | ||||
|       "shape"_a = none, | ||||
|       "num_samples"_a = none, | ||||
|       "key"_a = none, | ||||
|       "stream"_a = none, | ||||
|       R"pbdoc( | ||||
|         Sample from a categorical distribution. | ||||
|  | ||||
|         The values are sampled from the categorical distribution specified by | ||||
|         the unnormalized values in ``logits``. Note, at most one of ``shape`` | ||||
|         or ``num_samples`` can be specified. If both are ``None``, the output | ||||
|         has the same shape as ``logits`` with the ``axis`` dimension removed. | ||||
|  | ||||
|         Args: | ||||
|             logits (array): The *unnormalized* categorical distribution(s). | ||||
|             axis (int, optional): The axis which specifies the distribution. | ||||
|                Default is ``-1``. | ||||
|             shape (list(int), optional): The shape of the output. This must | ||||
|                be broadcast compatible with ``logits.shape`` with the ``axis`` | ||||
|                dimension removed. Default: ``None``. | ||||
|             num_samples (int, optional): The number of samples to draw from each | ||||
|               of the categorical distributions in ``logits``. The output will have | ||||
|               ``num_samples`` in the last dimension. Default: ``None``. | ||||
|             key (array, optional): A PRNG key. Default: None. | ||||
|  | ||||
|         Returns: | ||||
|             array: The ``shape``-sized output array with type ``uint32``. | ||||
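|  | ||||
|         Example: | ||||
|             >>> logits = mx.zeros((4, 10)) | ||||
|             >>> samples = mx.random.categorical(logits, num_samples=3) | ||||
|             >>> # samples has shape (4, 3) | ||||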
|       )pbdoc"); | ||||
| } | ||||
71  python/src/utils.h  (Normal file)
							| @@ -0,0 +1,71 @@ | ||||
| #pragma once | ||||
| #include <numeric> | ||||
| #include <variant> | ||||
|  | ||||
| #include <pybind11/complex.h> | ||||
| #include <pybind11/pybind11.h> | ||||
| #include <pybind11/stl.h> | ||||
|  | ||||
| #include "mlx/array.h" | ||||
|  | ||||
| namespace py = pybind11; | ||||
|  | ||||
| using namespace mlx::core; | ||||
|  | ||||
| using IntOrVec = std::variant<std::monostate, int, std::vector<int>>; | ||||
| using ScalarOrArray = | ||||
|     std::variant<py::bool_, py::int_, py::float_, std::complex<float>, array>; | ||||
| static constexpr std::monostate none{}; | ||||
|  | ||||
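| // Resolve a Python axis argument: no value selects every axis, an int | ||||
| // selects one, and a list passes through unchanged. | ||||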
| inline std::vector<int> get_reduce_axes(const IntOrVec& v, int dims) { | ||||
|   std::vector<int> axes; | ||||
|   if (std::holds_alternative<std::monostate>(v)) { | ||||
|     axes.resize(dims); | ||||
|     std::iota(axes.begin(), axes.end(), 0); | ||||
|   } else if (auto pv = std::get_if<int>(&v); pv) { | ||||
|     axes.push_back(*pv); | ||||
|   } else { | ||||
|     axes = std::get<std::vector<int>>(v); | ||||
|   } | ||||
|   return axes; | ||||
| } | ||||
|  | ||||
| inline array to_array( | ||||
|     const ScalarOrArray& v, | ||||
|     std::optional<Dtype> dtype = std::nullopt) { | ||||
|   if (auto pv = std::get_if<py::bool_>(&v); pv) { | ||||
|     return array(py::cast<bool>(*pv), dtype.value_or(bool_)); | ||||
|   } else if (auto pv = std::get_if<py::int_>(&v); pv) { | ||||
|     auto out_t = dtype.value_or(int32); | ||||
|     // bool_ is an exception and is always promoted | ||||
|     return array(py::cast<int>(*pv), (out_t == bool_) ? int32 : out_t); | ||||
|   } else if (auto pv = std::get_if<py::float_>(&v); pv) { | ||||
|     auto out_t = dtype.value_or(float32); | ||||
|     return array( | ||||
|         py::cast<float>(*pv), is_floating_point(out_t) ? out_t : float32); | ||||
|   } else if (auto pv = std::get_if<std::complex<float>>(&v); pv) { | ||||
|     return array(static_cast<complex64_t>(*pv), complex64); | ||||
|   } else { | ||||
|     return std::get<array>(v); | ||||
|   } | ||||
| } | ||||
|  | ||||
| inline std::pair<array, array> to_arrays( | ||||
|     const ScalarOrArray& a, | ||||
|     const ScalarOrArray& b) { | ||||
|   // Four cases: | ||||
|   // - If both a and b are arrays leave their types alone | ||||
|   // - If a is an array but b is not, treat b as a weak python type | ||||
|   // - If b is an array but a is not, treat a as a weak python type | ||||
|   // - If neither is an array convert to arrays but leave their types alone | ||||
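|   //   e.g. {array(float16), py::float_} -> both float16, while | ||||
|   //   {py::int_, py::float_} -> {int32, float32} | ||||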
|   if (auto pa = std::get_if<array>(&a); pa) { | ||||
|     if (auto pb = std::get_if<array>(&b); pb) { | ||||
|       return {*pa, *pb}; | ||||
|     } | ||||
|     return {*pa, to_array(b, pa->dtype())}; | ||||
|   } else if (auto pb = std::get_if<array>(&b); pb) { | ||||
|     return {to_array(a, pb->dtype()), *pb}; | ||||
|   } else { | ||||
|     return {to_array(a), to_array(b)}; | ||||
|   } | ||||
| } | ||||
1041  python/tests/test_array.py  (Normal file; file diff suppressed because it is too large)

263  python/tests/test_autograd.py  (Normal file)
							| @@ -0,0 +1,263 @@ | ||||
| import unittest | ||||
|  | ||||
| import mlx.core as mx | ||||
|  | ||||
| import mlx_tests | ||||
|  | ||||
|  | ||||
| class TestAutograd(mlx_tests.MLXTestCase): | ||||
|     def test_jvp(self): | ||||
|         fun = lambda x: 2 * x | ||||
|         out, dout = mx.jvp(fun, [mx.array(1.0)], [mx.array(2.0)]) | ||||
|         self.assertEqual(out[0].item(), 2.0) | ||||
|         self.assertEqual(dout[0].item(), 4.0) | ||||
|  | ||||
|         fun = lambda x, y: x * y | ||||
|         _, out = mx.jvp( | ||||
|             fun, [mx.array(4.0), mx.array(2.0)], [mx.array(3.0), mx.array(2.0)] | ||||
|         ) | ||||
|         self.assertEqual(out[0].item(), 4.0 * 2.0 + 2.0 * 3.0) | ||||
|  | ||||
|         fun = lambda x, y, z: (x * y, y * z) | ||||
|         _, out = mx.jvp( | ||||
|             fun, | ||||
|             [mx.array(2.0), mx.array(4.0), mx.array(6.0)], | ||||
|             [mx.array(1.0), mx.array(3.0), mx.array(1.0)], | ||||
|         ) | ||||
|         self.assertEqual(len(out), 2) | ||||
|         self.assertEqual(out[0].item(), 4.0 * 1.0 + 2.0 * 3.0) | ||||
|         self.assertEqual(out[1].item(), 4.0 * 1.0 + 6.0 * 3.0) | ||||
|  | ||||
|     def test_vjp(self): | ||||
|         fun = lambda x: 2 * x | ||||
|         out, dout = mx.vjp(fun, [mx.array(1.0)], [mx.array(2.0)]) | ||||
|         self.assertEqual(out[0].item(), 2.0) | ||||
|         self.assertEqual(dout[0].item(), 4.0) | ||||
|  | ||||
|         fun = lambda x, y: x * y | ||||
|         _, dout = mx.vjp(fun, [mx.array(4.0), mx.array(2.0)], [mx.array(3.0)]) | ||||
|         self.assertEqual(dout[0].item(), 6.0) | ||||
|         self.assertEqual(dout[1].item(), 12.0) | ||||
|  | ||||
|         fun = lambda x, y, z: (x * y, y * z) | ||||
|         _, out = mx.vjp( | ||||
|             fun, | ||||
|             [mx.array(2.0), mx.array(4.0), mx.array(6.0)], | ||||
|             [mx.array(1.0), mx.array(3.0)], | ||||
|         ) | ||||
|         self.assertEqual(len(out), 3) | ||||
|         self.assertEqual(out[0].item(), 4.0 * 1.0) | ||||
|         self.assertEqual(out[1].item(), 2.0 * 1.0 + 6.0 * 3.0) | ||||
|         self.assertEqual(out[2].item(), 4.0 * 3.0) | ||||
|  | ||||
|     def test_grad(self): | ||||
|         fun = lambda x: x * x | ||||
|  | ||||
|         value, dfdx = mx.value_and_grad(fun)(mx.array(0.5)) | ||||
|         self.assertEqual(value.item(), 0.25) | ||||
|         self.assertEqual(dfdx.item(), 1.0) | ||||
|  | ||||
|         dfdx = mx.grad(fun)(mx.array(0.5)) | ||||
|         self.assertEqual(dfdx.item(), 1.0) | ||||
|  | ||||
|         df2dx2 = mx.grad(mx.grad(fun))(mx.array(0.5)) | ||||
|         self.assertEqual(df2dx2.item(), 2.0) | ||||
|         df3dx3 = mx.grad(mx.grad(mx.grad(fun)))(mx.array(0.5)) | ||||
|         self.assertEqual(df3dx3.item(), 0.0) | ||||
|  | ||||
|         fun = lambda x, y: x * y | ||||
|         x = mx.array(2.0) | ||||
|         y = mx.array(3.0) | ||||
|         dfdx = mx.grad(fun, argnums=0)(x, y) | ||||
|         self.assertEqual(dfdx.item(), 3.0) | ||||
|         dfdx = mx.grad(fun, argnums=1)(x, y) | ||||
|         self.assertEqual(dfdx.item(), 2.0) | ||||
|  | ||||
|         # Pass non array args to functions works | ||||
|         fun = lambda x, y: x | ||||
|         value, dfdx = mx.value_and_grad(fun)(mx.array(2.0), "hello") | ||||
|         self.assertEqual(value.item(), 2.0) | ||||
|         self.assertEqual(dfdx.item(), 1.0) | ||||
|  | ||||
|         dfdx = mx.grad(fun)(mx.array(2.0), "hello") | ||||
|         self.assertEqual(dfdx.item(), 1.0) | ||||
|  | ||||
|         # Raises when function does not return array | ||||
|         fun = lambda x: "hello" | ||||
|         with self.assertRaises(ValueError): | ||||
|             mx.grad(fun)(mx.array(2.0)) | ||||
|  | ||||
|         # Raises for invalid argument number or argument type | ||||
|         fun = lambda x: x | ||||
|         with self.assertRaises(ValueError): | ||||
|             mx.grad(fun, argnums=2)(mx.array(2.0)) | ||||
|         with self.assertRaises(ValueError): | ||||
|             mx.grad(fun, argnums=-2)(mx.array(2.0)) | ||||
|         with self.assertRaises(ValueError): | ||||
|             mx.grad(fun)("hello") | ||||
|  | ||||
|         # Raises when output is not a scalar array | ||||
|         fun = lambda x: mx.sum(x, keepdims=True) | ||||
|         with self.assertRaises(ValueError): | ||||
|             mx.grad(fun)(mx.ones((2, 2))) | ||||
|  | ||||
|     def test_grad_trees(self): | ||||
|         fun = lambda x, y: x * y | ||||
|         value, dfdx = mx.value_and_grad(fun, (0, 1))(mx.array(0.5), mx.array(2.0)) | ||||
|         self.assertEqual(value.item(), 1.0) | ||||
|         self.assertTrue(isinstance(dfdx, tuple)) | ||||
|         self.assertEqual(dfdx[0].item(), 2.0) | ||||
|         self.assertEqual(dfdx[1].item(), 0.5) | ||||
|  | ||||
|         fun = lambda x, y: x * y | ||||
|         value, dfdx = mx.value_and_grad(fun, 1)(mx.array(0.5), mx.array(2.0)) | ||||
|         self.assertEqual(value.item(), 1.0) | ||||
|         self.assertEqual(dfdx.item(), 0.5) | ||||
|  | ||||
|         fun = lambda p: p["x"] * p["y"] | ||||
|         value, dfdx = mx.value_and_grad(fun)({"x": mx.array(0.5), "y": mx.array(2.0)}) | ||||
|         self.assertEqual(value.item(), 1.0) | ||||
|         self.assertEqual(dfdx["x"].item(), 2.0) | ||||
|         self.assertEqual(dfdx["y"].item(), 0.5) | ||||
|  | ||||
|         fun = lambda p: p["x"] * p["y"] | ||||
|         with self.assertRaises(ValueError): | ||||
|             mx.value_and_grad(fun)({"x": 0.5, "y": mx.array(2.0)}) | ||||
|         with self.assertRaises(ValueError): | ||||
|             mx.value_and_grad(fun, (0, 1))({"x": mx.array(0.5), "y": mx.array(2.0)}) | ||||
|  | ||||
|         fun = lambda p, b: mx.square(p[0]["foo"][2]) * b | ||||
|         value, dfdx = mx.value_and_grad(fun)( | ||||
|             [{"foo": [[], [], mx.array(2.0)]}], mx.array(0.5) | ||||
|         ) | ||||
|         self.assertEqual(value.item(), 2.0) | ||||
|         self.assertEqual(dfdx[0]["foo"][2].item(), 2.0) | ||||
|  | ||||
|         fun = lambda x: x | ||||
|         with self.assertRaises(TypeError): | ||||
|             mx.value_and_grad(fun, (None, None)) | ||||
|         with self.assertRaises(ValueError): | ||||
|             mx.value_and_grad(fun, tuple()) | ||||
|  | ||||
|     def test_auxiliary_values(self): | ||||
|         def fun(x, y): | ||||
|             l = (x * y).sum() | ||||
|             extra = {"loss": l, "foo": y.square() + x.square(), "bar": [1, 2, 3, y, x]} | ||||
|             return l, extra | ||||
|  | ||||
|         fun_value_grad = mx.value_and_grad(fun) | ||||
|         fun_grad = mx.grad(fun) | ||||
|  | ||||
|         (loss, a), b = fun_value_grad(mx.ones((2, 2)), mx.ones((2, 2))) | ||||
|         self.assertEqual(a["loss"].item(), 4) | ||||
|         self.assertTrue(mx.array_equal(b, mx.ones((2, 2)))) | ||||
|         self.assertTrue(mx.array_equal(a["foo"], 2 * mx.ones((2, 2)))) | ||||
|         self.assertEqual(a["bar"][:3], [1, 2, 3]) | ||||
|         self.assertTrue(mx.array_equal(a["bar"][3], mx.ones((2, 2)))) | ||||
|         self.assertTrue(mx.array_equal(a["bar"][4], mx.ones((2, 2)))) | ||||
|  | ||||
|         with self.assertRaises(ValueError): | ||||
|             _ = fun_grad(mx.ones((2, 2)), mx.ones((2, 2))) | ||||
|  | ||||
|     def test_grad_kwargs(self): | ||||
|         fun = lambda x, y: x * y | ||||
|         a, b = mx.array(0.5), mx.array(2.0) | ||||
|         dfdx = mx.grad(fun) | ||||
|         self.assertEqual(dfdx(a, b).item(), 2.0) | ||||
|         self.assertEqual(dfdx(a, y=b).item(), 2.0) | ||||
|         with self.assertRaises(ValueError): | ||||
|             dfdx(x=a, y=b).item() | ||||
|  | ||||
|         dfdy = mx.grad(fun, argnums=[], argnames=["y"]) | ||||
|         with self.assertRaises(ValueError): | ||||
|             dfdy(a, b) | ||||
|         grads = dfdy(a, y=b) | ||||
|         self.assertTrue(isinstance(grads, tuple)) | ||||
|         self.assertTrue(grads[0] is None) | ||||
|         self.assertTrue(isinstance(grads[1], dict)) | ||||
|         self.assertEqual(grads[1]["y"].item(), 0.5) | ||||
|         grads = dfdy(x=a, y=b) | ||||
|         self.assertEqual(grads[1]["y"].item(), 0.5) | ||||
|         self.assertEqual(len(grads[1]), 1) | ||||
|  | ||||
|         dfdxy = mx.grad(fun, argnums=[0], argnames=["y"]) | ||||
|         with self.assertRaises(ValueError): | ||||
|             dfdxy(a, b) | ||||
|         with self.assertRaises(ValueError): | ||||
|             dfdxy(x=a, y=b) | ||||
|         grads = dfdxy(a, y=b) | ||||
|         self.assertTrue(isinstance(grads, tuple)) | ||||
|         self.assertEqual(grads[0].item(), 2.0) | ||||
|         self.assertTrue(isinstance(grads[1], dict)) | ||||
|         self.assertEqual(grads[1]["y"].item(), 0.5) | ||||
|  | ||||
|         fun = lambda x, y, z: x * y * z | ||||
|         dfdxyz = mx.grad(fun, argnums=[0, 1], argnames=["z"]) | ||||
|         c = mx.array(4.0) | ||||
|         grads = dfdxyz(a, b, z=c) | ||||
|         self.assertTrue(isinstance(grads, tuple)) | ||||
|         self.assertTrue(isinstance(grads[0], tuple)) | ||||
|         self.assertEqual(grads[0][0].item(), 8.0) | ||||
|         self.assertEqual(grads[0][1].item(), 2.0) | ||||
|         self.assertTrue(isinstance(grads[1], dict)) | ||||
|         self.assertEqual(grads[1]["z"].item(), 1.0) | ||||
|  | ||||
|         fun = lambda x, y: x * y | ||||
|         dfdy = mx.grad(fun, argnames=["y"]) | ||||
|         grads = dfdy(a, y=b) | ||||
|         self.assertTrue(isinstance(grads, tuple)) | ||||
|         self.assertTrue(grads[0] is None) | ||||
|         self.assertTrue(isinstance(grads[1], dict)) | ||||
|         self.assertEqual(grads[1]["y"].item(), 0.5) | ||||
|  | ||||
|     def test_captured(self): | ||||
|         a = mx.array(5.0) | ||||
|         f = lambda x: a + x | ||||
|         g = lambda x: a + a | ||||
|         h = lambda x: x + x | ||||
|  | ||||
|         dfdx = mx.grad(f) | ||||
|         self.assertEqual(dfdx(a).item(), 1.0) | ||||
|  | ||||
|         dgdx = mx.grad(g) | ||||
|         self.assertEqual(dgdx(a).item(), 0.0) | ||||
|  | ||||
|         dhdx = mx.grad(h) | ||||
|         self.assertEqual(dhdx(a).item(), 2.0) | ||||
|  | ||||
|         d2fdx2 = mx.grad(dfdx) | ||||
|         self.assertEqual(d2fdx2(a).item(), 0.0) | ||||
|  | ||||
|         d2gdx2 = mx.grad(dgdx) | ||||
|         self.assertEqual(d2gdx2(a).item(), 0.0) | ||||
|  | ||||
|         d2hdx2 = mx.grad(dhdx) | ||||
|         self.assertEqual(d2hdx2(a).item(), 0.0) | ||||
|  | ||||
|     def test_stop_gradient(self): | ||||
|         shape_in = (4, 4) | ||||
|         w_in = mx.ones(shape_in) | ||||
|         x_in = mx.ones(shape_in) | ||||
|         cotan = mx.ones(shape_in) | ||||
|  | ||||
|         def h(w, x): | ||||
|             x1 = 2 * x | ||||
|             y = mx.stop_gradient(x1) | ||||
|             y1 = 3 * y | ||||
|             return w @ y1 | ||||
|  | ||||
|         vals, vjps = mx.vjp(h, [w_in, x_in], [cotan]) | ||||
|         mx.eval(vjps) | ||||
|  | ||||
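|         # y1 == 6 * ones, so dh/dw pulls back to cotan @ y1.T == 24 * ones; | ||||
|         # the stop_gradient blocks any gradient from flowing to x | ||||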
|         self.assertTrue(mx.allclose(vjps[0], 24.0 * mx.ones(shape_in))) | ||||
|         self.assertTrue(mx.allclose(vjps[1], mx.zeros(shape_in))) | ||||
|  | ||||
|         g = lambda x: h(w_in, x) | ||||
|         vals, vjps = mx.vjp(g, [x_in], [cotan]) | ||||
|         mx.eval(vjps) | ||||
|  | ||||
|         self.assertTrue(mx.allclose(vjps[0], mx.zeros(shape_in))) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     unittest.main() | ||||
105  python/tests/test_device.py  (Normal file)
							| @@ -0,0 +1,105 @@ | ||||
| import unittest | ||||
|  | ||||
| import mlx.core as mx | ||||
|  | ||||
| import mlx_tests | ||||
|  | ||||
|  | ||||
| # Don't inherit from MLXTestCase to avoid call to setUp | ||||
| class TestDefaultDevice(unittest.TestCase): | ||||
|     def test_mlx_default_device(self): | ||||
|         device = mx.default_device() | ||||
|         if mx.metal.is_available(): | ||||
|             self.assertEqual(device, mx.Device(mx.gpu)) | ||||
|             self.assertEqual(str(device), "Device(gpu, 0)") | ||||
|             self.assertEqual(device, mx.gpu) | ||||
|             self.assertEqual(mx.gpu, device) | ||||
|         else: | ||||
|             self.assertEqual(device, mx.Device(mx.cpu)) | ||||
|             with self.assertRaises(ValueError): | ||||
|                 mx.set_default_device(mx.gpu) | ||||
|  | ||||
|  | ||||
| class TestDevice(mlx_tests.MLXTestCase): | ||||
|     def test_device(self): | ||||
|         device = mx.default_device() | ||||
|  | ||||
|         cpu = mx.Device(mx.cpu) | ||||
|         mx.set_default_device(cpu) | ||||
|         self.assertEqual(mx.default_device(), cpu) | ||||
|         self.assertEqual(str(cpu), "Device(cpu, 0)") | ||||
|  | ||||
|         mx.set_default_device(mx.cpu) | ||||
|         self.assertEqual(mx.default_device(), mx.cpu) | ||||
|         self.assertEqual(cpu, mx.cpu) | ||||
|         self.assertEqual(mx.cpu, cpu) | ||||
|  | ||||
|         # Restore device | ||||
|         mx.set_default_device(device) | ||||
|  | ||||
|     def test_op_on_device(self): | ||||
|         x = mx.array(1.0) | ||||
|         y = mx.array(1.0) | ||||
|  | ||||
|         a = mx.add(x, y, stream=None) | ||||
|         b = mx.add(x, y, stream=mx.default_device()) | ||||
|         self.assertEqual(a.item(), b.item()) | ||||
|         b = mx.add(x, y, stream=mx.cpu) | ||||
|         self.assertEqual(a.item(), b.item()) | ||||
|  | ||||
|         if mx.metal.is_available(): | ||||
|             b = mx.add(x, y, stream=mx.gpu) | ||||
|             self.assertEqual(a.item(), b.item()) | ||||
|  | ||||
|  | ||||
| class TestStream(mlx_tests.MLXTestCase): | ||||
|     def test_stream(self): | ||||
|         s1 = mx.default_stream(mx.default_device()) | ||||
|         self.assertEqual(s1.device, mx.default_device()) | ||||
|  | ||||
|         s2 = mx.new_stream(mx.default_device()) | ||||
|         self.assertEqual(s2.device, mx.default_device()) | ||||
|         self.assertNotEqual(s1, s2) | ||||
|  | ||||
|         if mx.metal.is_available(): | ||||
|             s_gpu = mx.default_stream(mx.gpu) | ||||
|             self.assertEqual(s_gpu.device, mx.gpu) | ||||
|         else: | ||||
|             with self.assertRaises(ValueError): | ||||
|                 mx.default_stream(mx.gpu) | ||||
|  | ||||
|         s_cpu = mx.default_stream(mx.cpu) | ||||
|         self.assertEqual(s_cpu.device, mx.cpu) | ||||
|  | ||||
|         s_cpu = mx.new_stream(mx.cpu) | ||||
|         self.assertEqual(s_cpu.device, mx.cpu) | ||||
|  | ||||
|         if mx.metal.is_available(): | ||||
|             s_gpu = mx.new_stream(mx.gpu) | ||||
|             self.assertEqual(s_gpu.device, mx.gpu) | ||||
|         else: | ||||
|             with self.assertRaises(ValueError): | ||||
|                 mx.new_stream(mx.gpu) | ||||
|  | ||||
|     def test_op_on_stream(self): | ||||
|         x = mx.array(1.0) | ||||
|         y = mx.array(1.0) | ||||
|  | ||||
|         a = mx.add(x, y, stream=mx.default_stream(mx.default_device())) | ||||
|  | ||||
|         if mx.metal.is_available(): | ||||
|             b = mx.add(x, y, stream=mx.default_stream(mx.gpu)) | ||||
|             self.assertEqual(a.item(), b.item()) | ||||
|             s_gpu = mx.new_stream(mx.gpu) | ||||
|             b = mx.add(x, y, stream=s_gpu) | ||||
|             self.assertEqual(a.item(), b.item()) | ||||
|  | ||||
|         b = mx.add(x, y, stream=mx.default_stream(mx.cpu)) | ||||
|         self.assertEqual(a.item(), b.item()) | ||||
|         s_cpu = mx.new_stream(mx.cpu) | ||||
|         b = mx.add(x, y, stream=s_cpu) | ||||
|         self.assertEqual(a.item(), b.item()) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     unittest.main() | ||||
34  python/tests/test_eval.py  (Normal file)
							| @@ -0,0 +1,34 @@ | ||||
| from functools import partial | ||||
|  | ||||
| import unittest | ||||
|  | ||||
| import mlx.core as mx | ||||
|  | ||||
| import mlx_tests | ||||
|  | ||||
|  | ||||
| class TestEval(mlx_tests.MLXTestCase): | ||||
|     def test_eval(self): | ||||
|         arrs = [mx.ones((2, 2)) for _ in range(4)] | ||||
|         mx.eval(*arrs) | ||||
|         for x in arrs: | ||||
|             self.assertEqual(x.tolist(), [[1, 1], [1, 1]]) | ||||
|  | ||||
|     def test_retain_graph(self): | ||||
|         def fun(x, retain_graph): | ||||
|             y = 3 * x | ||||
|             mx.eval(y, retain_graph=retain_graph) | ||||
|             return 2 * y | ||||
|  | ||||
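|         # Evaluating inside fun with retain_graph=False frees the graph, so | ||||
|         # differentiating through it afterwards should raise | ||||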
|         dfun_dx_1 = mx.grad(partial(fun, retain_graph=False)) | ||||
|         dfun_dx_2 = mx.grad(partial(fun, retain_graph=True)) | ||||
|  | ||||
|         with self.assertRaises(ValueError): | ||||
|             dfun_dx_1(mx.array(1.0)) | ||||
|  | ||||
|         y = dfun_dx_2(mx.array(1.0)) | ||||
|         self.assertEqual(y.item(), 6.0) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     unittest.main() | ||||
90  python/tests/test_fft.py  (Normal file)
							| @@ -0,0 +1,90 @@ | ||||
| import unittest | ||||
|  | ||||
| import itertools | ||||
| import mlx.core as mx | ||||
| import numpy as np | ||||
|  | ||||
| import mlx_tests | ||||
|  | ||||
|  | ||||
| class TestFFT(mlx_tests.MLXTestCase): | ||||
|     def check_mx_np(self, op, a_np, axes, s): | ||||
|         with self.subTest(op=op, axes=axes, s=s): | ||||
|             op_np = getattr(np.fft, op) | ||||
|             op_mx = getattr(mx.fft, op) | ||||
|             out_np = op_np(a_np, s=s, axes=axes) | ||||
|             a_mx = mx.array(a_np) | ||||
|             out_mx = op_mx(a_mx, s=s, axes=axes) | ||||
|             self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)) | ||||
|  | ||||
|     def test_fft(self): | ||||
|         default = mx.default_device() | ||||
|         mx.set_default_device(mx.cpu) | ||||
|  | ||||
|         def check_mx_np(op_mx, op_np, a_np, **kwargs): | ||||
|             out_np = op_np(a_np, **kwargs) | ||||
|             a_mx = mx.array(a_np) | ||||
|             out_mx = op_mx(a_mx, **kwargs) | ||||
|             self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)) | ||||
|  | ||||
|         r = np.random.rand(100).astype(np.float32) | ||||
|         i = np.random.rand(100).astype(np.float32) | ||||
|         a_np = r + 1j * i | ||||
|         check_mx_np(mx.fft.fft, np.fft.fft, a_np) | ||||
|  | ||||
|         # Check with slicing and padding | ||||
|         r = np.random.rand(100).astype(np.float32) | ||||
|         i = np.random.rand(100).astype(np.float32) | ||||
|         a_np = r + 1j * i | ||||
|         check_mx_np(mx.fft.fft, np.fft.fft, a_np, n=80) | ||||
|         check_mx_np(mx.fft.fft, np.fft.fft, a_np, n=120) | ||||
|  | ||||
|         # Check different axes | ||||
|         r = np.random.rand(100, 100).astype(np.float32) | ||||
|         i = np.random.rand(100, 100).astype(np.float32) | ||||
|         a_np = r + 1j * i | ||||
|         check_mx_np(mx.fft.fft, np.fft.fft, a_np, axis=0) | ||||
|         check_mx_np(mx.fft.fft, np.fft.fft, a_np, axis=1) | ||||
|  | ||||
|         # Check real fft | ||||
|         a_np = np.random.rand(100).astype(np.float32) | ||||
|         check_mx_np(mx.fft.rfft, np.fft.rfft, a_np) | ||||
|         check_mx_np(mx.fft.rfft, np.fft.rfft, a_np, n=80) | ||||
|         check_mx_np(mx.fft.rfft, np.fft.rfft, a_np, n=120) | ||||
|  | ||||
|         # Check real inverse | ||||
|         r = np.random.rand(100, 100).astype(np.float32) | ||||
|         i = np.random.rand(100, 100).astype(np.float32) | ||||
|         a_np = r + 1j * i | ||||
|         check_mx_np(mx.fft.ifft, np.fft.ifft, a_np) | ||||
|         check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, n=80) | ||||
|         check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, n=120) | ||||
|         check_mx_np(mx.fft.irfft, np.fft.irfft, a_np) | ||||
|         check_mx_np(mx.fft.irfft, np.fft.irfft, a_np, n=80) | ||||
|         check_mx_np(mx.fft.irfft, np.fft.irfft, a_np, n=120) | ||||
|  | ||||
|         mx.set_default_device(default) | ||||
|  | ||||
|     def test_fftn(self): | ||||
|         default = mx.default_device() | ||||
|         mx.set_default_device(mx.cpu) | ||||
|  | ||||
|         r = np.random.randn(8, 8, 8).astype(np.float32) | ||||
|         i = np.random.randn(8, 8, 8).astype(np.float32) | ||||
|         a = r + 1j * i | ||||
|  | ||||
|         axes = [None, (1, 2), (2, 1), (0, 2)] | ||||
|         shapes = [None, (10, 5), (5, 10)] | ||||
|         ops = ["fft2", "ifft2", "rfft2", "irfft2", "fftn", "ifftn", "rfftn", "irfftn"] | ||||
|  | ||||
|         for op, ax, s in itertools.product(ops, axes, shapes): | ||||
|             x = a | ||||
|             if op in ["rfft2", "rfftn"]: | ||||
|                 x = r | ||||
|             self.check_mx_np(op, x, axes=ax, s=s) | ||||
|  | ||||
|         mx.set_default_device(default) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     unittest.main() | ||||
1283  python/tests/test_ops.py  (Normal file; file diff suppressed because it is too large)

118  python/tests/test_reduce.py  (Normal file)
							| @@ -0,0 +1,118 @@ | ||||
| import unittest | ||||
| from itertools import permutations, combinations | ||||
|  | ||||
| import mlx.core as mx | ||||
| import numpy as np | ||||
|  | ||||
| import mlx_tests | ||||
|  | ||||
|  | ||||
| class TestReduce(mlx_tests.MLXTestCase): | ||||
|     def test_axis_permutation_sums(self): | ||||
|         x_npy = np.random.randn(5, 5, 5, 5, 5).astype(np.float32) | ||||
|         x_mlx = mx.array(x_npy) | ||||
|         for t in permutations(range(5)): | ||||
|             with self.subTest(t=t): | ||||
|                 y_npy = np.transpose(x_npy, t) | ||||
|                 y_mlx = mx.transpose(x_mlx, t) | ||||
|                 for n in range(1, 6): | ||||
|                     for a in combinations(range(5), n): | ||||
|                         with self.subTest(a=a): | ||||
|                             z_npy = np.sum(y_npy, axis=a) | ||||
|                             z_mlx = mx.sum(y_mlx, axis=a) | ||||
|                             mx.eval(z_mlx) | ||||
|                             self.assertTrue( | ||||
|                                 np.allclose(z_npy, np.array(z_mlx), atol=1e-4) | ||||
|                             ) | ||||
|  | ||||
|     def test_expand_sums(self): | ||||
|         x_npy = np.random.randn(5, 1, 5, 1, 5, 1).astype(np.float32) | ||||
|         x_mlx = mx.array(x_npy) | ||||
|         for m in range(1, 4): | ||||
|             for ax in combinations([1, 3, 5], m): | ||||
|                 shape = np.array([5, 1, 5, 1, 5, 1]) | ||||
|                 shape[list(ax)] = 5 | ||||
|                 shape = shape.tolist() | ||||
|                 with self.subTest(shape=shape): | ||||
|                     y_npy = np.broadcast_to(x_npy, shape) | ||||
|                     y_mlx = mx.broadcast_to(x_mlx, shape) | ||||
|                     for n in range(1, 7): | ||||
|                         for a in combinations(range(6), n): | ||||
|                             with self.subTest(a=a): | ||||
|                                 z_npy = np.sum(y_npy, axis=a) / 1000 | ||||
|                                 z_mlx = mx.sum(y_mlx, axis=a) / 1000 | ||||
|                                 mx.eval(z_mlx) | ||||
|                                 self.assertTrue( | ||||
|                                     np.allclose(z_npy, np.array(z_mlx), atol=1e-4) | ||||
|                                 ) | ||||
|  | ||||
|     def test_dtypes(self): | ||||
|         int_dtypes = [ | ||||
|             "int8", | ||||
|             "int16", | ||||
|             "int32", | ||||
|             "uint8", | ||||
|             "uint16", | ||||
|             "uint32", | ||||
|         ] | ||||
|         float_dtypes = ["float32"] | ||||
|  | ||||
|         for dtype in int_dtypes + float_dtypes: | ||||
|             with self.subTest(dtype=dtype): | ||||
|                 x = np.random.uniform(0, 2, size=(3, 3, 3)).astype(getattr(np, dtype)) | ||||
|                 y = mx.array(x) | ||||
|  | ||||
|                 for op in ("sum", "prod", "min", "max"): | ||||
|                     with self.subTest(op=op): | ||||
|  | ||||
|                         np_op = getattr(np, op) | ||||
|                         mlx_op = getattr(mx, op) | ||||
|  | ||||
|                         for axes in (None, 0, 1, 2, (0, 1), (0, 2), (1, 2), (0, 1, 2)): | ||||
|                             with self.subTest(axes=axes): | ||||
|                                 if op in ("sum", "prod"): | ||||
|                                     r_np = np_op( | ||||
|                                         x, axis=axes, dtype=(getattr(np, dtype)) | ||||
|                                     ) | ||||
|                                 else: | ||||
|                                     r_np = np_op(x, axis=axes) | ||||
|                                 r_mlx = mlx_op(y, axis=axes) | ||||
|                                 mx.eval(r_mlx) | ||||
|                                 self.assertTrue(np.allclose(r_np, r_mlx, atol=1e-4)) | ||||
|  | ||||
|     def test_arg_reduce(self): | ||||
|         dtypes = [ | ||||
|             "uint8", | ||||
|             "uint16", | ||||
|             "uint32", | ||||
|             "uint64", | ||||
|             "int8", | ||||
|             "int16", | ||||
|             "int32", | ||||
|             "int64", | ||||
|             "float16", | ||||
|             "float32", | ||||
|         ] | ||||
|         for dtype in dtypes: | ||||
|             with self.subTest(dtype=dtype): | ||||
|  | ||||
|                 data = np.random.rand(10, 12, 13).astype(getattr(np, dtype)) | ||||
|                 x = mx.array(data) | ||||
|                 for op in ["argmin", "argmax"]: | ||||
|                     for axis in range(3): | ||||
|                         for kd in [True, False]: | ||||
|                             a = getattr(mx, op)(x, axis, kd) | ||||
|                             b = getattr(np, op)(data, axis, keepdims=kd) | ||||
|                             self.assertEqual(a.tolist(), b.tolist()) | ||||
|  | ||||
|                 for op in ["argmin", "argmax"]: | ||||
|                     a = getattr(mx, op)(x, keepdims=True) | ||||
|                     b = getattr(np, op)(data, keepdims=True) | ||||
|                     self.assertEqual(a.tolist(), b.tolist()) | ||||
|                     a = getattr(mx, op)(x) | ||||
|                     b = getattr(np, op)(data) | ||||
|                     self.assertEqual(a.item(), b) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     unittest.main(failfast=True) | ||||