Quantize embedding (#994)

* quantize embedding

* rename as_linear + comment

* consistency in docs

* fix test
This commit is contained in:
Awni Hannun
2024-04-15 16:42:10 -07:00
committed by GitHub
parent 2e7c02d5cd
commit cd9e184529
9 changed files with 269 additions and 54 deletions

View File

@@ -3,7 +3,7 @@ from collections import defaultdict
def tree_map(fn, tree, *rest, is_leaf=None):
"""Applies ``fn`` to the leaves of the python tree ``tree`` and
"""Applies ``fn`` to the leaves of the Python tree ``tree`` and
returns a new collection with the results.
If ``rest`` is provided, every item is assumed to be a superset of ``tree``
@@ -27,14 +27,14 @@ def tree_map(fn, tree, *rest, is_leaf=None):
model.update(tree_map(lambda x: x*x, model.parameters()))
Args:
fn (Callable): The function that processes the leaves of the tree
tree (Any): The main python tree that will be iterated upon
rest (Tuple[Any]): Extra trees to be iterated together with tree
is_leaf (Optional[Callable]): An optional callable that returns True if
the passed object is considered a leaf or False otherwise.
fn (callable): The function that processes the leaves of the tree.
tree (Any): The main Python tree that will be iterated upon.
rest (tuple[Any]): Extra trees to be iterated together with ``tree``.
is_leaf (callable, optional): An optional callable that returns ``True``
if the passed object is considered a leaf or ``False`` otherwise.
Returns:
A python tree with the new values returned by ``fn``.
A Python tree with the new values returned by ``fn``.
"""
if is_leaf is not None and is_leaf(tree):
return fn(tree, *rest)
@@ -53,8 +53,57 @@ def tree_map(fn, tree, *rest, is_leaf=None):
return fn(tree, *rest)
def tree_map_with_path(fn, tree, *rest, is_leaf=None, path=None):
"""Applies ``fn`` to the path and leaves of the Python tree ``tree`` and
returns a new collection with the results.
This function is the same :func:`tree_map` but the ``fn`` takes the path as
the first argument followed by the remaining tree nodes.
Args:
fn (callable): The function that processes the leaves of the tree.
tree (Any): The main Python tree that will be iterated upon.
rest (tuple[Any]): Extra trees to be iterated together with ``tree``.
is_leaf (callable, optional): An optional callable that returns ``True``
if the passed object is considered a leaf or ``False`` otherwise.
Returns:
A Python tree with the new values returned by ``fn``.
Example:
>>> from mlx.utils import tree_map_with_path
>>> tree = {"model": [{"w": 0, "b": 1}, {"w": 0, "b": 1}]}
>>> new_tree = tree_map_with_path(lambda path, _: print(path), tree)
model.0.w
model.0.b
model.1.w
model.1.b
"""
if is_leaf is not None and is_leaf(tree):
return fn(path, tree, *rest)
elif isinstance(tree, (list, tuple)):
prefix = f"{path}." if path else ""
TreeType = type(tree)
return TreeType(
tree_map_with_path(
fn, child, *(r[i] for r in rest), is_leaf=is_leaf, path=f"{prefix}{i}"
)
for i, child in enumerate(tree)
)
elif isinstance(tree, dict):
prefix = f"{path}." if path else ""
return {
k: tree_map_with_path(
fn, child, *(r[k] for r in rest), is_leaf=is_leaf, path=f"{prefix}{k}"
)
for k, child in tree.items()
}
else:
return fn(path, tree, *rest)
def tree_flatten(tree, prefix="", is_leaf=None):
"""Flattens a python tree to a list of key, value tuples.
"""Flattens a Python tree to a list of key, value tuples.
The keys are using the dot notation to define trees of arbitrary depth and
complexity.
@@ -70,17 +119,17 @@ def tree_flatten(tree, prefix="", is_leaf=None):
# [("hello.0.0.0", 0)]
.. note::
Dictionaries should have keys that are valid python identifiers.
Dictionaries should have keys that are valid Python identifiers.
Args:
tree (Any): The python tree to be flattened.
tree (Any): The Python tree to be flattened.
prefix (str): A prefix to use for the keys. The first character is
always discarded.
is_leaf (Callable): An optional callable that returns True if the
is_leaf (callable): An optional callable that returns True if the
passed object is considered a leaf or False otherwise.
Returns:
List[Tuple[str, Any]]: The flat representation of the python tree.
List[Tuple[str, Any]]: The flat representation of the Python tree.
"""
flat_tree = []
@@ -98,7 +147,7 @@ def tree_flatten(tree, prefix="", is_leaf=None):
def tree_unflatten(tree):
"""Recreate a python tree from its flat representation.
"""Recreate a Python tree from its flat representation.
.. code-block:: python
@@ -109,11 +158,11 @@ def tree_unflatten(tree):
# {"hello": {"world": 42}}
Args:
tree (List[Tuple[str, Any]]): The flat representation of a python tree.
For instance as returned by :meth:`tree_flatten`.
tree (list[tuple[str, Any]]): The flat representation of a Python tree.
For instance as returned by :meth:`tree_flatten`.
Returns:
A python tree.
A Python tree.
"""
if len(tree) == 1 and tree[0][0] == "":
return tree[0][1]