Compare commits

..

8 Commits

Author SHA1 Message Date
Angelos Katharopoulos
a22d0bf273 Add stricter condition to matrix sdpa 2025-08-06 19:51:14 -07:00
Jagrit Digani
99d8de8445 Fix cudnn routing 2025-08-06 15:05:58 -07:00
Jagrit Digani
c66b76a8c8 Update routing 2025-08-06 15:01:15 -07:00
Jagrit Digani
f81edd184f Complete 2 pass sdpav 2025-08-06 13:57:40 -07:00
Jagrit Digani
7f8ba2a003 [WIP] 2 pass sdpav 2025-08-06 09:56:39 -07:00
Jagrit Digani
c28249b81a Add more nvtx range for debug 2025-08-06 09:56:39 -07:00
Jagrit Digani
e74bcdc5e3 Add sdpa file 2025-08-06 09:56:39 -07:00
Jagrit Digani
d8ed6c1aa3 Add base cudnn attention support 2025-08-06 09:56:39 -07:00
9 changed files with 1160 additions and 94 deletions

View File

@@ -51,14 +51,14 @@ the saved state. Here's a simple example:
optimizer.update(model, grads) optimizer.update(model, grads)
# Save the state # Save the state
state = tree_flatten(optimizer.state, destination={}) state = tree_flatten(optimizer.state)
mx.save_safetensors("optimizer.safetensors", state) mx.save_safetensors("optimizer.safetensors", dict(state))
# Later on, for example when loading from a checkpoint, # Later on, for example when loading from a checkpoint,
# recreate the optimizer and load the state # recreate the optimizer and load the state
optimizer = optim.Adam(learning_rate=1e-2) optimizer = optim.Adam(learning_rate=1e-2)
state = tree_unflatten(mx.load("optimizer.safetensors")) state = tree_unflatten(list(mx.load("optimizer.safetensors").items()))
optimizer.state = state optimizer.state = state
Note, not every optimizer configuation parameter is saved in the state. For Note, not every optimizer configuation parameter is saved in the state. For

View File

@@ -7,17 +7,17 @@ Exporting Functions
MLX has an API to export and import functions to and from a file. This lets you MLX has an API to export and import functions to and from a file. This lets you
run computations written in one MLX front-end (e.g. Python) in another MLX run computations written in one MLX front-end (e.g. Python) in another MLX
front-end (e.g. C++). front-end (e.g. C++).
This guide walks through the basics of the MLX export API with some examples. This guide walks through the basics of the MLX export API with some examples.
To see the full list of functions check-out the :ref:`API documentation To see the full list of functions check-out the :ref:`API documentation
<export>`. <export>`.
Basics of Exporting Basics of Exporting
------------------- -------------------
Let's start with a simple example: Let's start with a simple example:
.. code-block:: python .. code-block:: python
def fun(x, y): def fun(x, y):
@@ -67,7 +67,7 @@ specified as variable positional arguments or as a tuple of arrays:
x = mx.array(1.0) x = mx.array(1.0)
y = mx.array(1.0) y = mx.array(1.0)
# Both arguments to fun are positional # Both arguments to fun are positional
mx.export_function("add.mlxfn", fun, x, y) mx.export_function("add.mlxfn", fun, x, y)
@@ -133,7 +133,7 @@ parameters are also saved to the ``model.mlxfn`` file.
For enclosed arrays inside an exported function, be extra careful to ensure For enclosed arrays inside an exported function, be extra careful to ensure
they are evaluated. The computation graph that gets exported will include they are evaluated. The computation graph that gets exported will include
the computation that produces enclosed inputs. the computation that produces enclosed inputs.
If the above example was missing ``mx.eval(model.parameters()``, the If the above example was missing ``mx.eval(model.parameters()``, the
exported function would include the random initialization of the exported function would include the random initialization of the
:obj:`mlx.nn.Module` parameters. :obj:`mlx.nn.Module` parameters.
@@ -150,8 +150,8 @@ parameters, pass them as inputs to the ``call`` wrapper:
# Set the model's parameters to the input parameters # Set the model's parameters to the input parameters
model.update(tree_unflatten(list(params.items()))) model.update(tree_unflatten(list(params.items())))
return model(x) return model(x)
params = tree_flatten(model.parameters(), destination={}) params = dict(tree_flatten(model.parameters()))
mx.export_function("model.mlxfn", call, (mx.zeros(4),), params) mx.export_function("model.mlxfn", call, (mx.zeros(4),), params)
@@ -169,8 +169,8 @@ to export a function which can be used for inputs with variable shapes:
# Ok # Ok
out, = imported_abs(mx.array(-1.0)) out, = imported_abs(mx.array(-1.0))
# Also ok # Also ok
out, = imported_abs(mx.array([-1.0, -2.0])) out, = imported_abs(mx.array([-1.0, -2.0]))
With ``shapeless=False`` (which is the default), the second call to With ``shapeless=False`` (which is the default), the second call to
@@ -197,7 +197,7 @@ a single file by creating an exporting context manager with :func:`exporter`:
def fun(x, y=None): def fun(x, y=None):
constant = mx.array(3.0) constant = mx.array(3.0)
if y is not None: if y is not None:
x += y x += y
return x + constant return x + constant
with mx.exporter("fun.mlxfn", fun) as exporter: with mx.exporter("fun.mlxfn", fun) as exporter:
@@ -215,7 +215,7 @@ a single file by creating an exporting context manager with :func:`exporter`:
print(out) print(out)
In the above example the function constant data, (i.e. ``constant``), is only In the above example the function constant data, (i.e. ``constant``), is only
saved once. saved once.
Transformations with Imported Functions Transformations with Imported Functions
--------------------------------------- ---------------------------------------
@@ -238,7 +238,7 @@ on imported functions just like regular Python functions:
# Prints: array(1, dtype=float32) # Prints: array(1, dtype=float32)
print(dfdx(x)) print(dfdx(x))
# Compile the imported function # Compile the imported function
mx.compile(imported_fun) mx.compile(imported_fun)
# Prints: array(0, dtype=float32) # Prints: array(0, dtype=float32)
print(compiled_fun(x)[0]) print(compiled_fun(x)[0])
@@ -275,7 +275,7 @@ Import and run the function in C++ with only a few lines of code:
// Prints: array(2, dtype=float32) // Prints: array(2, dtype=float32)
std::cout << outputs[0] << std::endl; std::cout << outputs[0] << std::endl;
Imported functions can be transformed in C++ just like in Python. Use Imported functions can be transformed in C++ just like in Python. Use
``std::vector<mx::array>`` for positional arguments and ``std::map<std::string, ``std::vector<mx::array>`` for positional arguments and ``std::map<std::string,
mx::array>`` for keyword arguments when calling imported functions in C++. mx::array>`` for keyword arguments when calling imported functions in C++.

View File

@@ -39,6 +39,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu ${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu
${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu ${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu
${CMAKE_CURRENT_SOURCE_DIR}/rope.cu ${CMAKE_CURRENT_SOURCE_DIR}/rope.cu
${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cu
${CMAKE_CURRENT_SOURCE_DIR}/scan.cu ${CMAKE_CURRENT_SOURCE_DIR}/scan.cu
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cu ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cu

View File

@@ -6,17 +6,6 @@
namespace mlx::core { namespace mlx::core {
bool fast::ScaledDotProductAttention::use_fallback(
const array& q,
const array& k,
const array& v,
bool has_mask,
bool has_arr_mask,
bool do_causal,
Stream s) {
return true;
}
#define NO_GPU_MULTI(func) \ #define NO_GPU_MULTI(func) \
void func::eval_gpu( \ void func::eval_gpu( \
const std::vector<array>& inputs, std::vector<array>& outputs) { \ const std::vector<array>& inputs, std::vector<array>& outputs) { \
@@ -53,7 +42,6 @@ NO_GPU_MULTI(Eig)
NO_GPU_MULTI(Eigh) NO_GPU_MULTI(Eigh)
namespace fast { namespace fast {
NO_GPU(ScaledDotProductAttention)
NO_GPU_MULTI(CustomKernel) NO_GPU_MULTI(CustomKernel)
} // namespace fast } // namespace fast

File diff suppressed because it is too large Load Diff

View File

@@ -104,7 +104,7 @@ struct CommandEncoder {
}; };
// Outputs of all kernels in the encoder including temporaries // Outputs of all kernels in the encoder including temporaries
std::unordered_set<const void*>& outputs() { std::unordered_set<const void*> outputs() {
return all_outputs_; return all_outputs_;
}; };

View File

@@ -178,7 +178,7 @@ class Module(dict):
if strict: if strict:
new_weights = dict(weights) new_weights = dict(weights)
curr_weights = tree_flatten(self.parameters(), destination={}) curr_weights = dict(tree_flatten(self.parameters()))
if extras := (new_weights.keys() - curr_weights.keys()): if extras := (new_weights.keys() - curr_weights.keys()):
num_extra = len(extras) num_extra = len(extras)
extras = ",\n".join(sorted(extras)) extras = ",\n".join(sorted(extras))
@@ -212,7 +212,7 @@ class Module(dict):
- ``.npz`` will use :func:`mx.savez` - ``.npz`` will use :func:`mx.savez`
- ``.safetensors`` will use :func:`mx.save_safetensors` - ``.safetensors`` will use :func:`mx.save_safetensors`
""" """
params_dict = tree_flatten(self.parameters(), destination={}) params_dict = dict(tree_flatten(self.parameters()))
if file.endswith(".npz"): if file.endswith(".npz"):
mx.savez(file, **params_dict) mx.savez(file, **params_dict)

View File

@@ -1,7 +1,7 @@
# Copyright © 2023 Apple Inc. # Copyright © 2023 Apple Inc.
from collections import defaultdict from collections import defaultdict
from itertools import zip_longest from itertools import zip_longest
from typing import Any, Callable, Dict, List, Optional, Tuple, Union from typing import Any, Callable, List, Optional, Tuple
def tree_map( def tree_map(
@@ -114,11 +114,8 @@ def tree_map_with_path(
def tree_flatten( def tree_flatten(
tree: Any, tree: Any, prefix: str = "", is_leaf: Optional[Callable] = None
prefix: str = "", ) -> Any:
is_leaf: Optional[Callable] = None,
destination: Optional[Union[List[Tuple[str, Any]], Dict[str, Any]]] = None,
) -> Union[List[Tuple[str, Any]], Dict[str, Any]]:
"""Flattens a Python tree to a list of key, value tuples. """Flattens a Python tree to a list of key, value tuples.
The keys are using the dot notation to define trees of arbitrary depth and The keys are using the dot notation to define trees of arbitrary depth and
@@ -131,12 +128,9 @@ def tree_flatten(
print(tree_flatten([[[0]]])) print(tree_flatten([[[0]]]))
# [("0.0.0", 0)] # [("0.0.0", 0)]
print(tree_flatten([[[0]]], prefix=".hello")) print(tree_flatten([[[0]]], ".hello"))
# [("hello.0.0.0", 0)] # [("hello.0.0.0", 0)]
tree_flatten({"a": {"b": 1}}, destination={})
{"a.b": 1}
.. note:: .. note::
Dictionaries should have keys that are valid Python identifiers. Dictionaries should have keys that are valid Python identifiers.
@@ -146,50 +140,26 @@ def tree_flatten(
always discarded. always discarded.
is_leaf (callable): An optional callable that returns True if the is_leaf (callable): An optional callable that returns True if the
passed object is considered a leaf or False otherwise. passed object is considered a leaf or False otherwise.
destination (list or dict, optional): A list or dictionary to store the
flattened tree. If None an empty list will be used. Default: ``None``.
Returns: Returns:
Union[List[Tuple[str, Any]], Dict[str, Any]]: The flat representation of List[Tuple[str, Any]]: The flat representation of the Python tree.
the Python tree.
""" """
if destination is None: flat_tree = []
destination = []
# Create the function to update the destination. We are taking advantage of if is_leaf is None or not is_leaf(tree):
# the fact that list.extend and dict.update have the same API to simplify if isinstance(tree, (list, tuple)):
# the code a bit. for i, t in enumerate(tree):
if isinstance(destination, list): flat_tree.extend(tree_flatten(t, f"{prefix}.{i}", is_leaf))
_add_to_destination = destination.extend return flat_tree
elif isinstance(destination, dict): if isinstance(tree, dict):
_add_to_destination = destination.update for k, t in tree.items():
else: flat_tree.extend(tree_flatten(t, f"{prefix}.{k}", is_leaf))
raise ValueError("Destination should be either a list or a dictionary or None") return flat_tree
# Leaf identified by is_leaf so add it and return return [(prefix[1:], tree)]
if is_leaf is not None and is_leaf(tree):
_add_to_destination([(prefix[1:], tree)])
return destination
# List or tuple so recursively add each subtree
if isinstance(tree, (list, tuple)):
for i, item in enumerate(tree):
tree_flatten(item, f"{prefix}.{i}", is_leaf, destination)
return destination
# Dictionary so recursively add each subtree
if isinstance(tree, dict):
for key, value in tree.items():
tree_flatten(value, f"{prefix}.{key}", is_leaf, destination)
return destination
# Leaf so add it and return
_add_to_destination([(prefix[1:], tree)])
return destination
def tree_unflatten(tree: Union[List[Tuple[str, Any]], Dict[str, Any]]) -> Any: def tree_unflatten(tree: List[Tuple[str, Any]]) -> Any:
"""Recreate a Python tree from its flat representation. """Recreate a Python tree from its flat representation.
.. code-block:: python .. code-block:: python
@@ -200,34 +170,31 @@ def tree_unflatten(tree: Union[List[Tuple[str, Any]], Dict[str, Any]]) -> Any:
print(d) print(d)
# {"hello": {"world": 42}} # {"hello": {"world": 42}}
d = tree_unflatten({"hello.world": 42})
print(d)
# {"hello": {"world": 42}}
Args: Args:
tree (list[tuple[str, Any]] or dict[str, Any]): The flat representation of a Python tree. tree (list[tuple[str, Any]]): The flat representation of a Python tree.
For instance as returned by :meth:`tree_flatten`. For instance as returned by :meth:`tree_flatten`.
Returns: Returns:
A Python tree. A Python tree.
""" """
items = tree.items() if isinstance(tree, dict) else tree if len(tree) == 1 and tree[0][0] == "":
return tree[0][1]
# Special case when we have just one element in the tree ie not a tree try:
if len(items) == 1: int(tree[0][0].split(".", maxsplit=1)[0])
key, value = next(iter(items)) is_list = True
if key == "": except ValueError:
return value is_list = False
# collect children # collect children
children = defaultdict(list) children = defaultdict(list)
for key, value in items: for key, value in tree:
current_idx, *next_idx = key.split(".", maxsplit=1) current_idx, *next_idx = key.split(".", maxsplit=1)
next_idx = "" if not next_idx else next_idx[0] next_idx = "" if not next_idx else next_idx[0]
children[current_idx].append((next_idx, value)) children[current_idx].append((next_idx, value))
# Assume they are a list and fail to dict if the keys are not all integers # recursively map them to the original container
try: if is_list:
keys = sorted((int(idx), idx) for idx in children.keys()) keys = sorted((int(idx), idx) for idx in children.keys())
l = [] l = []
for i, k in keys: for i, k in keys:
@@ -235,7 +202,7 @@ def tree_unflatten(tree: Union[List[Tuple[str, Any]], Dict[str, Any]]) -> Any:
l.extend([{} for _ in range(i - len(l))]) l.extend([{} for _ in range(i - len(l))])
l.append(tree_unflatten(children[k])) l.append(tree_unflatten(children[k]))
return l return l
except ValueError: else:
return {k: tree_unflatten(v) for k, v in children.items()} return {k: tree_unflatten(v) for k, v in children.items()}

View File

@@ -80,7 +80,7 @@ class TestBase(mlx_tests.MLXTestCase):
self.weights = {"w1": mx.zeros((2, 2)), "w2": mx.ones((2, 2))} self.weights = {"w1": mx.zeros((2, 2)), "w2": mx.ones((2, 2))}
model = DictModule() model = DictModule()
params = tree_flatten(model.parameters(), destination={}) params = dict(tree_flatten(model.parameters()))
self.assertEqual(len(params), 2) self.assertEqual(len(params), 2)
self.assertTrue(mx.array_equal(params["weights.w1"], mx.zeros((2, 2)))) self.assertTrue(mx.array_equal(params["weights.w1"], mx.zeros((2, 2))))
self.assertTrue(mx.array_equal(params["weights.w2"], mx.ones((2, 2)))) self.assertTrue(mx.array_equal(params["weights.w2"], mx.ones((2, 2))))