mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Adds C++ and nn quantization utilities (#230)
* Add C++ de-/quantize ops * Add quantize functions to the docs and tests * Add a QuantizedLinear module
This commit is contained in:
committed by
GitHub
parent
4912ff3ec2
commit
57fe918cf8
@@ -1,7 +1,7 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
|
||||
|
||||
def tree_map(fn, tree, *rest):
|
||||
def tree_map(fn, tree, *rest, is_leaf=None):
|
||||
"""Applies ``fn`` to the leaves of the python tree ``tree`` and
|
||||
returns a new collection with the results.
|
||||
|
||||
@@ -10,6 +10,9 @@ def tree_map(fn, tree, *rest):
|
||||
``fn``. In that respect, :meth:`tree_map` is closer to :func:`itertools.starmap`
|
||||
than to :func:`map`.
|
||||
|
||||
The keyword argument ``is_leaf`` decides what constitutes a leaf from
|
||||
``tree`` similar to :func:`tree_flatten`.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import mlx.nn as nn
|
||||
@@ -26,21 +29,28 @@ def tree_map(fn, tree, *rest):
|
||||
fn (Callable): The function that processes the leaves of the tree
|
||||
tree (Any): The main python tree that will be iterated upon
|
||||
rest (Tuple[Any]): Extra trees to be iterated together with tree
|
||||
is_leaf (Optional[Callable]): An optional callable that returns True if
|
||||
the passed object is considered a leaf or False otherwise.
|
||||
|
||||
Returns:
|
||||
A python tree with the new values returned by ``fn``.
|
||||
"""
|
||||
if isinstance(tree, list):
|
||||
if is_leaf is not None and is_leaf(tree):
|
||||
return fn(tree, *rest)
|
||||
elif isinstance(tree, list):
|
||||
return [
|
||||
tree_map(fn, child, *(r[i] for r in rest)) for i, child in enumerate(tree)
|
||||
tree_map(fn, child, *(r[i] for r in rest), is_leaf=is_leaf)
|
||||
for i, child in enumerate(tree)
|
||||
]
|
||||
elif isinstance(tree, tuple):
|
||||
return tuple(
|
||||
tree_map(fn, child, *(r[i] for r in rest)) for i, child in enumerate(tree)
|
||||
tree_map(fn, child, *(r[i] for r in rest), is_leaf=is_leaf)
|
||||
for i, child in enumerate(tree)
|
||||
)
|
||||
elif isinstance(tree, dict):
|
||||
return {
|
||||
k: tree_map(fn, child, *(r[k] for r in rest)) for k, child in tree.items()
|
||||
k: tree_map(fn, child, *(r[k] for r in rest), is_leaf=is_leaf)
|
||||
for k, child in tree.items()
|
||||
}
|
||||
else:
|
||||
return fn(tree, *rest)
|
||||
|
||||
Reference in New Issue
Block a user