mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Quantize embedding (#994)
* quantize embedding * rename as_linear + comment * consistency in docs * fix test
This commit is contained in:
@@ -3,7 +3,7 @@ from collections import defaultdict
|
||||
|
||||
|
||||
def tree_map(fn, tree, *rest, is_leaf=None):
|
||||
"""Applies ``fn`` to the leaves of the python tree ``tree`` and
|
||||
"""Applies ``fn`` to the leaves of the Python tree ``tree`` and
|
||||
returns a new collection with the results.
|
||||
|
||||
If ``rest`` is provided, every item is assumed to be a superset of ``tree``
|
||||
@@ -27,14 +27,14 @@ def tree_map(fn, tree, *rest, is_leaf=None):
|
||||
model.update(tree_map(lambda x: x*x, model.parameters()))
|
||||
|
||||
Args:
|
||||
fn (Callable): The function that processes the leaves of the tree
|
||||
tree (Any): The main python tree that will be iterated upon
|
||||
rest (Tuple[Any]): Extra trees to be iterated together with tree
|
||||
is_leaf (Optional[Callable]): An optional callable that returns True if
|
||||
the passed object is considered a leaf or False otherwise.
|
||||
fn (callable): The function that processes the leaves of the tree.
|
||||
tree (Any): The main Python tree that will be iterated upon.
|
||||
rest (tuple[Any]): Extra trees to be iterated together with ``tree``.
|
||||
is_leaf (callable, optional): An optional callable that returns ``True``
|
||||
if the passed object is considered a leaf or ``False`` otherwise.
|
||||
|
||||
Returns:
|
||||
A python tree with the new values returned by ``fn``.
|
||||
A Python tree with the new values returned by ``fn``.
|
||||
"""
|
||||
if is_leaf is not None and is_leaf(tree):
|
||||
return fn(tree, *rest)
|
||||
@@ -53,8 +53,57 @@ def tree_map(fn, tree, *rest, is_leaf=None):
|
||||
return fn(tree, *rest)
|
||||
|
||||
|
||||
def tree_map_with_path(fn, tree, *rest, is_leaf=None, path=None):
|
||||
"""Applies ``fn`` to the path and leaves of the Python tree ``tree`` and
|
||||
returns a new collection with the results.
|
||||
|
||||
This function is the same :func:`tree_map` but the ``fn`` takes the path as
|
||||
the first argument followed by the remaining tree nodes.
|
||||
|
||||
Args:
|
||||
fn (callable): The function that processes the leaves of the tree.
|
||||
tree (Any): The main Python tree that will be iterated upon.
|
||||
rest (tuple[Any]): Extra trees to be iterated together with ``tree``.
|
||||
is_leaf (callable, optional): An optional callable that returns ``True``
|
||||
if the passed object is considered a leaf or ``False`` otherwise.
|
||||
|
||||
Returns:
|
||||
A Python tree with the new values returned by ``fn``.
|
||||
|
||||
Example:
|
||||
>>> from mlx.utils import tree_map_with_path
|
||||
>>> tree = {"model": [{"w": 0, "b": 1}, {"w": 0, "b": 1}]}
|
||||
>>> new_tree = tree_map_with_path(lambda path, _: print(path), tree)
|
||||
model.0.w
|
||||
model.0.b
|
||||
model.1.w
|
||||
model.1.b
|
||||
"""
|
||||
if is_leaf is not None and is_leaf(tree):
|
||||
return fn(path, tree, *rest)
|
||||
elif isinstance(tree, (list, tuple)):
|
||||
prefix = f"{path}." if path else ""
|
||||
TreeType = type(tree)
|
||||
return TreeType(
|
||||
tree_map_with_path(
|
||||
fn, child, *(r[i] for r in rest), is_leaf=is_leaf, path=f"{prefix}{i}"
|
||||
)
|
||||
for i, child in enumerate(tree)
|
||||
)
|
||||
elif isinstance(tree, dict):
|
||||
prefix = f"{path}." if path else ""
|
||||
return {
|
||||
k: tree_map_with_path(
|
||||
fn, child, *(r[k] for r in rest), is_leaf=is_leaf, path=f"{prefix}{k}"
|
||||
)
|
||||
for k, child in tree.items()
|
||||
}
|
||||
else:
|
||||
return fn(path, tree, *rest)
|
||||
|
||||
|
||||
def tree_flatten(tree, prefix="", is_leaf=None):
|
||||
"""Flattens a python tree to a list of key, value tuples.
|
||||
"""Flattens a Python tree to a list of key, value tuples.
|
||||
|
||||
The keys are using the dot notation to define trees of arbitrary depth and
|
||||
complexity.
|
||||
@@ -70,17 +119,17 @@ def tree_flatten(tree, prefix="", is_leaf=None):
|
||||
# [("hello.0.0.0", 0)]
|
||||
|
||||
.. note::
|
||||
Dictionaries should have keys that are valid python identifiers.
|
||||
Dictionaries should have keys that are valid Python identifiers.
|
||||
|
||||
Args:
|
||||
tree (Any): The python tree to be flattened.
|
||||
tree (Any): The Python tree to be flattened.
|
||||
prefix (str): A prefix to use for the keys. The first character is
|
||||
always discarded.
|
||||
is_leaf (Callable): An optional callable that returns True if the
|
||||
is_leaf (callable): An optional callable that returns True if the
|
||||
passed object is considered a leaf or False otherwise.
|
||||
|
||||
Returns:
|
||||
List[Tuple[str, Any]]: The flat representation of the python tree.
|
||||
List[Tuple[str, Any]]: The flat representation of the Python tree.
|
||||
"""
|
||||
flat_tree = []
|
||||
|
||||
@@ -98,7 +147,7 @@ def tree_flatten(tree, prefix="", is_leaf=None):
|
||||
|
||||
|
||||
def tree_unflatten(tree):
|
||||
"""Recreate a python tree from its flat representation.
|
||||
"""Recreate a Python tree from its flat representation.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@@ -109,11 +158,11 @@ def tree_unflatten(tree):
|
||||
# {"hello": {"world": 42}}
|
||||
|
||||
Args:
|
||||
tree (List[Tuple[str, Any]]): The flat representation of a python tree.
|
||||
For instance as returned by :meth:`tree_flatten`.
|
||||
tree (list[tuple[str, Any]]): The flat representation of a Python tree.
|
||||
For instance as returned by :meth:`tree_flatten`.
|
||||
|
||||
Returns:
|
||||
A python tree.
|
||||
A Python tree.
|
||||
"""
|
||||
if len(tree) == 1 and tree[0][0] == "":
|
||||
return tree[0][1]
|
||||
|
||||
Reference in New Issue
Block a user