Compare commits

..

8 Commits

Author SHA1 Message Date
Angelos Katharopoulos
a22d0bf273 Add stricter condition to matrix sdpa 2025-08-06 19:51:14 -07:00
Jagrit Digani
99d8de8445 Fix cudnn routing 2025-08-06 15:05:58 -07:00
Jagrit Digani
c66b76a8c8 Update routing 2025-08-06 15:01:15 -07:00
Jagrit Digani
f81edd184f Complete 2 pass sdpav 2025-08-06 13:57:40 -07:00
Jagrit Digani
7f8ba2a003 [WIP] 2 pass sdpav 2025-08-06 09:56:39 -07:00
Jagrit Digani
c28249b81a Add more nvtx range for debug 2025-08-06 09:56:39 -07:00
Jagrit Digani
e74bcdc5e3 Add sdpa file 2025-08-06 09:56:39 -07:00
Jagrit Digani
d8ed6c1aa3 Add base cudnn attention support 2025-08-06 09:56:39 -07:00
9 changed files with 1160 additions and 94 deletions

View File

@@ -51,14 +51,14 @@ the saved state. Here's a simple example:
optimizer.update(model, grads) optimizer.update(model, grads)
# Save the state # Save the state
state = tree_flatten(optimizer.state, destination={}) state = tree_flatten(optimizer.state)
mx.save_safetensors("optimizer.safetensors", state) mx.save_safetensors("optimizer.safetensors", dict(state))
# Later on, for example when loading from a checkpoint, # Later on, for example when loading from a checkpoint,
# recreate the optimizer and load the state # recreate the optimizer and load the state
optimizer = optim.Adam(learning_rate=1e-2) optimizer = optim.Adam(learning_rate=1e-2)
state = tree_unflatten(mx.load("optimizer.safetensors")) state = tree_unflatten(list(mx.load("optimizer.safetensors").items()))
optimizer.state = state optimizer.state = state
Note, not every optimizer configuation parameter is saved in the state. For Note, not every optimizer configuation parameter is saved in the state. For

View File

@@ -151,7 +151,7 @@ parameters, pass them as inputs to the ``call`` wrapper:
model.update(tree_unflatten(list(params.items()))) model.update(tree_unflatten(list(params.items())))
return model(x) return model(x)
params = tree_flatten(model.parameters(), destination={}) params = dict(tree_flatten(model.parameters()))
mx.export_function("model.mlxfn", call, (mx.zeros(4),), params) mx.export_function("model.mlxfn", call, (mx.zeros(4),), params)

View File

@@ -39,6 +39,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu ${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu
${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu ${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu
${CMAKE_CURRENT_SOURCE_DIR}/rope.cu ${CMAKE_CURRENT_SOURCE_DIR}/rope.cu
${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cu
${CMAKE_CURRENT_SOURCE_DIR}/scan.cu ${CMAKE_CURRENT_SOURCE_DIR}/scan.cu
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cu ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cu

View File

@@ -6,17 +6,6 @@
namespace mlx::core { namespace mlx::core {
bool fast::ScaledDotProductAttention::use_fallback(
const array& q,
const array& k,
const array& v,
bool has_mask,
bool has_arr_mask,
bool do_causal,
Stream s) {
return true;
}
#define NO_GPU_MULTI(func) \ #define NO_GPU_MULTI(func) \
void func::eval_gpu( \ void func::eval_gpu( \
const std::vector<array>& inputs, std::vector<array>& outputs) { \ const std::vector<array>& inputs, std::vector<array>& outputs) { \
@@ -53,7 +42,6 @@ NO_GPU_MULTI(Eig)
NO_GPU_MULTI(Eigh) NO_GPU_MULTI(Eigh)
namespace fast { namespace fast {
NO_GPU(ScaledDotProductAttention)
NO_GPU_MULTI(CustomKernel) NO_GPU_MULTI(CustomKernel)
} // namespace fast } // namespace fast

File diff suppressed because it is too large Load Diff

View File

@@ -104,7 +104,7 @@ struct CommandEncoder {
}; };
// Outputs of all kernels in the encoder including temporaries // Outputs of all kernels in the encoder including temporaries
std::unordered_set<const void*>& outputs() { std::unordered_set<const void*> outputs() {
return all_outputs_; return all_outputs_;
}; };

View File

@@ -178,7 +178,7 @@ class Module(dict):
if strict: if strict:
new_weights = dict(weights) new_weights = dict(weights)
curr_weights = tree_flatten(self.parameters(), destination={}) curr_weights = dict(tree_flatten(self.parameters()))
if extras := (new_weights.keys() - curr_weights.keys()): if extras := (new_weights.keys() - curr_weights.keys()):
num_extra = len(extras) num_extra = len(extras)
extras = ",\n".join(sorted(extras)) extras = ",\n".join(sorted(extras))
@@ -212,7 +212,7 @@ class Module(dict):
- ``.npz`` will use :func:`mx.savez` - ``.npz`` will use :func:`mx.savez`
- ``.safetensors`` will use :func:`mx.save_safetensors` - ``.safetensors`` will use :func:`mx.save_safetensors`
""" """
params_dict = tree_flatten(self.parameters(), destination={}) params_dict = dict(tree_flatten(self.parameters()))
if file.endswith(".npz"): if file.endswith(".npz"):
mx.savez(file, **params_dict) mx.savez(file, **params_dict)

View File

@@ -1,7 +1,7 @@
# Copyright © 2023 Apple Inc. # Copyright © 2023 Apple Inc.
from collections import defaultdict from collections import defaultdict
from itertools import zip_longest from itertools import zip_longest
from typing import Any, Callable, Dict, List, Optional, Tuple, Union from typing import Any, Callable, List, Optional, Tuple
def tree_map( def tree_map(
@@ -114,11 +114,8 @@ def tree_map_with_path(
def tree_flatten( def tree_flatten(
tree: Any, tree: Any, prefix: str = "", is_leaf: Optional[Callable] = None
prefix: str = "", ) -> Any:
is_leaf: Optional[Callable] = None,
destination: Optional[Union[List[Tuple[str, Any]], Dict[str, Any]]] = None,
) -> Union[List[Tuple[str, Any]], Dict[str, Any]]:
"""Flattens a Python tree to a list of key, value tuples. """Flattens a Python tree to a list of key, value tuples.
The keys are using the dot notation to define trees of arbitrary depth and The keys are using the dot notation to define trees of arbitrary depth and
@@ -131,12 +128,9 @@ def tree_flatten(
print(tree_flatten([[[0]]])) print(tree_flatten([[[0]]]))
# [("0.0.0", 0)] # [("0.0.0", 0)]
print(tree_flatten([[[0]]], prefix=".hello")) print(tree_flatten([[[0]]], ".hello"))
# [("hello.0.0.0", 0)] # [("hello.0.0.0", 0)]
tree_flatten({"a": {"b": 1}}, destination={})
{"a.b": 1}
.. note:: .. note::
Dictionaries should have keys that are valid Python identifiers. Dictionaries should have keys that are valid Python identifiers.
@@ -146,50 +140,26 @@ def tree_flatten(
always discarded. always discarded.
is_leaf (callable): An optional callable that returns True if the is_leaf (callable): An optional callable that returns True if the
passed object is considered a leaf or False otherwise. passed object is considered a leaf or False otherwise.
destination (list or dict, optional): A list or dictionary to store the
flattened tree. If None an empty list will be used. Default: ``None``.
Returns: Returns:
Union[List[Tuple[str, Any]], Dict[str, Any]]: The flat representation of List[Tuple[str, Any]]: The flat representation of the Python tree.
the Python tree.
""" """
if destination is None: flat_tree = []
destination = []
# Create the function to update the destination. We are taking advantage of if is_leaf is None or not is_leaf(tree):
# the fact that list.extend and dict.update have the same API to simplify
# the code a bit.
if isinstance(destination, list):
_add_to_destination = destination.extend
elif isinstance(destination, dict):
_add_to_destination = destination.update
else:
raise ValueError("Destination should be either a list or a dictionary or None")
# Leaf identified by is_leaf so add it and return
if is_leaf is not None and is_leaf(tree):
_add_to_destination([(prefix[1:], tree)])
return destination
# List or tuple so recursively add each subtree
if isinstance(tree, (list, tuple)): if isinstance(tree, (list, tuple)):
for i, item in enumerate(tree): for i, t in enumerate(tree):
tree_flatten(item, f"{prefix}.{i}", is_leaf, destination) flat_tree.extend(tree_flatten(t, f"{prefix}.{i}", is_leaf))
return destination return flat_tree
# Dictionary so recursively add each subtree
if isinstance(tree, dict): if isinstance(tree, dict):
for key, value in tree.items(): for k, t in tree.items():
tree_flatten(value, f"{prefix}.{key}", is_leaf, destination) flat_tree.extend(tree_flatten(t, f"{prefix}.{k}", is_leaf))
return destination return flat_tree
# Leaf so add it and return return [(prefix[1:], tree)]
_add_to_destination([(prefix[1:], tree)])
return destination
def tree_unflatten(tree: Union[List[Tuple[str, Any]], Dict[str, Any]]) -> Any: def tree_unflatten(tree: List[Tuple[str, Any]]) -> Any:
"""Recreate a Python tree from its flat representation. """Recreate a Python tree from its flat representation.
.. code-block:: python .. code-block:: python
@@ -200,34 +170,31 @@ def tree_unflatten(tree: Union[List[Tuple[str, Any]], Dict[str, Any]]) -> Any:
print(d) print(d)
# {"hello": {"world": 42}} # {"hello": {"world": 42}}
d = tree_unflatten({"hello.world": 42})
print(d)
# {"hello": {"world": 42}}
Args: Args:
tree (list[tuple[str, Any]] or dict[str, Any]): The flat representation of a Python tree. tree (list[tuple[str, Any]]): The flat representation of a Python tree.
For instance as returned by :meth:`tree_flatten`. For instance as returned by :meth:`tree_flatten`.
Returns: Returns:
A Python tree. A Python tree.
""" """
items = tree.items() if isinstance(tree, dict) else tree if len(tree) == 1 and tree[0][0] == "":
return tree[0][1]
# Special case when we have just one element in the tree ie not a tree try:
if len(items) == 1: int(tree[0][0].split(".", maxsplit=1)[0])
key, value = next(iter(items)) is_list = True
if key == "": except ValueError:
return value is_list = False
# collect children # collect children
children = defaultdict(list) children = defaultdict(list)
for key, value in items: for key, value in tree:
current_idx, *next_idx = key.split(".", maxsplit=1) current_idx, *next_idx = key.split(".", maxsplit=1)
next_idx = "" if not next_idx else next_idx[0] next_idx = "" if not next_idx else next_idx[0]
children[current_idx].append((next_idx, value)) children[current_idx].append((next_idx, value))
# Assume they are a list and fail to dict if the keys are not all integers # recursively map them to the original container
try: if is_list:
keys = sorted((int(idx), idx) for idx in children.keys()) keys = sorted((int(idx), idx) for idx in children.keys())
l = [] l = []
for i, k in keys: for i, k in keys:
@@ -235,7 +202,7 @@ def tree_unflatten(tree: Union[List[Tuple[str, Any]], Dict[str, Any]]) -> Any:
l.extend([{} for _ in range(i - len(l))]) l.extend([{} for _ in range(i - len(l))])
l.append(tree_unflatten(children[k])) l.append(tree_unflatten(children[k]))
return l return l
except ValueError: else:
return {k: tree_unflatten(v) for k, v in children.items()} return {k: tree_unflatten(v) for k, v in children.items()}

View File

@@ -80,7 +80,7 @@ class TestBase(mlx_tests.MLXTestCase):
self.weights = {"w1": mx.zeros((2, 2)), "w2": mx.ones((2, 2))} self.weights = {"w1": mx.zeros((2, 2)), "w2": mx.ones((2, 2))}
model = DictModule() model = DictModule()
params = tree_flatten(model.parameters(), destination={}) params = dict(tree_flatten(model.parameters()))
self.assertEqual(len(params), 2) self.assertEqual(len(params), 2)
self.assertTrue(mx.array_equal(params["weights.w1"], mx.zeros((2, 2)))) self.assertTrue(mx.array_equal(params["weights.w1"], mx.zeros((2, 2))))
self.assertTrue(mx.array_equal(params["weights.w2"], mx.ones((2, 2)))) self.assertTrue(mx.array_equal(params["weights.w2"], mx.ones((2, 2))))