* einsum initial

* fix comma break

* sum axis was wrong

* small cleanups

* python binding

* changed bindings to resemble numpy

* remove todo comment

* comment changes

* add count of operands/inputs

* fail fast if operands list is empty

* ignore comma if no output

* einsum path matching numpy

* getting somewhere with path

* remove print

* it passes the first test

* moved einsum tests to separate file

* separated einsum path

* moved einsum naive

* remove space from equation

* fast fail if no operands passed

* update tests and remove printf

* small cleanup

* some more cleanups

* removed python helper file

* ack

* utilize std for finding min in vector

* duplicate def

* remove the tuple as it was unreadable

* moved einsum_naive back to ops

* remaining isn't needed

* avoid creating another set

* cleanup

* greedy path, start of naive einsum

* more einsum

* fix some bugs

* some more fixes, tests pass

* benchmark

* some simplify

* fix einsum and test

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>

* add a bunch more tests and fix a bunch more bugs

* some docs nits

---------

Co-authored-by: dc-dc-dc <dgcruz983@gmail.com>
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
Awni Hannun
2024-07-25 09:36:44 -07:00
committed by GitHub
parent 7f914365fd
commit baf9fa5f42
13 changed files with 1498 additions and 65 deletions

View File

@@ -12,6 +12,7 @@
#include <nanobind/stl/variant.h>
#include <nanobind/stl/vector.h>
#include "mlx/einsum.h"
#include "mlx/ops.h"
#include "mlx/utils.h"
#include "python/src/load.h"
@@ -40,15 +41,6 @@ double scalar_to_double(Scalar s) {
}
void init_ops(nb::module_& m) {
// TODO, remove deprecation errors in a future release
m.def("block_sparse_mm", [](nb::args, nb::kwargs) {
throw std::invalid_argument(
"block_sparse_mm is deprecated. Please use gather_mm which has the same signature");
});
m.def("block_sparse_qmm", [](nb::args, nb::kwargs) {
throw std::invalid_argument(
"block_sparse_qmm is deprecated. Please use gather_qmm which has the same signature");
});
m.def(
"reshape",
&reshape,
@@ -1238,7 +1230,8 @@ void init_ops(nb::module_& m) {
a (array): Input array.
Returns:
array: The unchanged input ``a`` but without gradient flowing
array:
The unchanged input ``a`` but without gradient flowing
through it.
)pbdoc");
m.def(
@@ -2936,6 +2929,9 @@ void init_ops(nb::module_& m) {
reverse (bool): Perform the cumulative sum in reverse.
inclusive (bool): The i-th element of the output includes the i-th
element of the input.
Returns:
array: The output array.
)pbdoc");
m.def(
"cumprod",
@@ -2969,6 +2965,9 @@ void init_ops(nb::module_& m) {
reverse (bool): Perform the cumulative product in reverse.
inclusive (bool): The i-th element of the output includes the i-th
element of the input.
Returns:
array: The output array.
)pbdoc");
m.def(
"cummax",
@@ -3002,6 +3001,9 @@ void init_ops(nb::module_& m) {
reverse (bool): Perform the cumulative maximum in reverse.
inclusive (bool): The i-th element of the output includes the i-th
element of the input.
Returns:
array: The output array.
)pbdoc");
m.def(
"cummin",
@@ -3035,6 +3037,9 @@ void init_ops(nb::module_& m) {
reverse (bool): Perform the cumulative minimum in reverse.
inclusive (bool): The i-th element of the output includes the i-th
element of the input.
Returns:
array: The output array.
)pbdoc");
m.def(
"conj",
@@ -3052,6 +3057,9 @@ void init_ops(nb::module_& m) {
Args:
a (array): Input array
Returns:
array: The output array.
)pbdoc");
m.def(
"conjugate",
@@ -3069,6 +3077,9 @@ void init_ops(nb::module_& m) {
Args:
a (array): Input array
Returns:
array: The output array.
)pbdoc");
m.def(
"convolve",
@@ -3492,14 +3503,11 @@ void init_ops(nb::module_& m) {
Args:
file (file, str): File in which the array is saved.
format (str, optional): Format of the file. If ``None``, the
format
is inferred from the file extension. Supported formats:
``npy``,
``npz``, and ``safetensors``. Default: ``None``.
format is inferred from the file extension. Supported formats:
``npy``, ``npz``, and ``safetensors``. Default: ``None``.
return_metadata (bool, optional): Load the metadata for formats
which
support metadata. The metadata will be returned as an
additional dictionary.
which support metadata. The metadata will be returned as an
additional dictionary. Default: ``False``.
Returns:
array or dict:
A single array if loading from a ``.npy`` file or a dict
@@ -3551,9 +3559,9 @@ void init_ops(nb::module_& m) {
Args:
file (file, str): File in which the array is saved.
arrays (dict(str, array)): The dictionary of names to arrays to
be saved. metadata (dict(str, Union[array, str, list(str)])):
The dictionary of
metadata to be saved. The values can be a scalar or 1D
be saved.
metadata (dict(str, Union[array, str, list(str)])): The dictionary
of metadata to be saved. The values can be a scalar or 1D
:obj:`array`, a :obj:`str`, or a :obj:`list` of :obj:`str`.
)pbdoc");
m.def(
@@ -3643,11 +3651,11 @@ void init_ops(nb::module_& m) {
biases (array): The biases to use per ``group_size`` elements of ``w``
transpose (bool, optional): Defines whether to multiply with the
transposed ``w`` or not, namely whether we are performing
``x @ w.T`` or ``x @ w``. (default: ``True``)
``x @ w.T`` or ``x @ w``. Default: ``True``.
group_size (int, optional): The size of the group in ``w`` that
shares a scale and bias. (default: ``64``)
shares a scale and bias. Default: ``64``.
bits (int, optional): The number of bits occupied by each element in
``w``. (default: ``4``)
``w``. Default: ``4``.
Returns:
array: The result of the multiplication of ``x`` with ``w``.
@@ -3700,9 +3708,9 @@ void init_ops(nb::module_& m) {
Args:
w (array): Matrix to be quantized
group_size (int, optional): The size of the group in ``w`` that shares a
scale and bias. (default: ``64``)
scale and bias. Default: ``64``.
bits (int, optional): The number of bits occupied by each element of
``w`` in the returned quantized matrix. (default: ``4``)
``w`` in the returned quantized matrix. Default: ``4``.
Returns:
tuple: A tuple containing
@@ -3740,9 +3748,9 @@ void init_ops(nb::module_& m) {
scales (array): The scales to use per ``group_size`` elements of ``w``
biases (array): The biases to use per ``group_size`` elements of ``w``
group_size (int, optional): The size of the group in ``w`` that shares a
scale and bias. (default: ``64``)
scale and bias. Default: ``64``.
bits (int, optional): The number of bits occupied by each element in
``w``. (default: ``4``)
``w``. Default: ``4``.
Returns:
array: The dequantized version of ``w``
@@ -3779,15 +3787,15 @@ void init_ops(nb::module_& m) {
w (array): Quantized matrix packed in unsigned integers
scales (array): The scales to use per ``group_size`` elements of ``w``
biases (array): The biases to use per ``group_size`` elements of ``w``
lhs_indices (array, optional): Integer indices for ``x`` (default: ``None``)
rhs_indices (array, optional): Integer indices for ``w`` (default: ``None``)
lhs_indices (array, optional): Integer indices for ``x``. Default: ``None``.
rhs_indices (array, optional): Integer indices for ``w``. Default: ``None``.
transpose (bool, optional): Defines whether to multiply with the
transposed ``w`` or not, namely whether we are performing
``x @ w.T`` or ``x @ w``. (default: ``True``)
``x @ w.T`` or ``x @ w``. Default: ``True``.
group_size (int, optional): The size of the group in ``w`` that
shares a scale and bias. (default: ``64``)
shares a scale and bias. Default: ``64``.
bits (int, optional): The number of bits occupied by each element in
``w``. (default: ``4``)
``w``. Default: ``4``.
Returns:
array: The result of the multiplication of ``x`` with ``w``
@@ -3827,7 +3835,7 @@ void init_ops(nb::module_& m) {
sum over. If an integer is provided, then sum over the last
``axes`` dimensions of ``a`` and the first ``axes`` dimensions of
``b``. If a list of lists is provided, then sum over the
corresponding dimensions of ``a`` and ``b``. (default: 2)
corresponding dimensions of ``a`` and ``b``. Default: 2.
Returns:
array: The tensor dot product.
@@ -3958,11 +3966,13 @@ void init_ops(nb::module_& m) {
Args:
a (array): Input array or scalar.
b (array): Input array or scalar.
block_size (int): Size of blocks to be masked. Must be ``32`` or ``64`` (default: ``64``)
mask_out (array, optional): Mask for output (default: ``None``)
mask_lhs (array, optional): Mask for a (default: ``None``)
mask_rhs (array, optional): Mask for b (default: ``None``)
block_size (int): Size of blocks to be masked. Must be ``32`` or ``64``. Default: ``64``.
mask_out (array, optional): Mask for output. Default: ``None``.
mask_lhs (array, optional): Mask for ``a``. Default: ``None``.
mask_rhs (array, optional): Mask for ``b``. Default: ``None``.
Returns:
array: The output array.
)pbdoc");
m.def(
"gather_mm",
@@ -3996,9 +4006,11 @@ void init_ops(nb::module_& m) {
Args:
a (array): Input array.
b (array): Input array.
lhs_indices (array, optional): Integer indices for ``a`` (default: ``None``)
rhs_indices (array, optional): Integer indices for ``b`` (default: ``None``)
lhs_indices (array, optional): Integer indices for ``a``. Default: ``None``.
rhs_indices (array, optional): Integer indices for ``b``. Default: ``None``.
Returns:
array: The output array.
)pbdoc");
m.def(
"diagonal",
@@ -4406,4 +4418,57 @@ void init_ops(nb::module_& m) {
Returns:
array: The transformed array.
)pbdoc");
m.def(
    "einsum_path",
    // Python-facing wrapper: gathers the variadic operands, asks the core
    // einsum_path() for a contraction order, and returns it together with
    // the accompanying description string.
    [](const std::string& equation, const nb::args& operands) {
      // Materialize the variadic Python arguments as a vector of arrays.
      auto inputs = nb::cast<std::vector<array>>(operands);
      auto [contractions, summary] = einsum_path(equation, inputs);

      // Hand every contraction step back to Python as a tuple.
      std::vector<nb::tuple> steps;
      for (auto& step : contractions) {
        nb::tuple as_tuple(nb::cast(step));
        steps.push_back(as_tuple);
      }
      return std::make_pair(steps, summary);
    },
    "subscripts"_a,
    "operands"_a,
    nb::sig("def einsum_path(subscripts: str, *operands)"),
    R"pbdoc(
Compute the contraction order for the given Einstein summation.
Args:
subscripts (str): The Einstein summation convention equation.
*operands (array): The input arrays.
Returns:
tuple(list(tuple(int, int)), str):
The einsum path and a string containing information about the
chosen path.
)pbdoc");
m.def(
    "einsum",
    // Python-facing wrapper: collects the variadic operands into a vector of
    // arrays and forwards them, with the subscripts and the optional
    // stream/device, to the core einsum().
    [](const std::string& subscripts,
       const nb::args& operands,
       StreamOrDevice s) {
      auto arrays_list = nb::cast<std::vector<array>>(operands);
      return einsum(subscripts, arrays_list, s);
    },
    "subscripts"_a,
    "operands"_a,
    nb::kw_only(),
    "stream"_a = nb::none(),
    // Parameters that follow ``*operands`` are already keyword-only in
    // Python, and a second bare ``*`` after ``*operands`` is a syntax error.
    // The signature therefore lists ``stream`` directly after ``*operands``
    // so that it stays parseable (e.g. by stub generators).
    nb::sig(
        "def einsum(subscripts: str, *operands, stream: Union[None, Stream, Device] = None) -> array"),
    R"pbdoc(
Perform the Einstein summation convention on the operands.
Args:
subscripts (str): The Einstein summation convention equation.
*operands (array): The input arrays.
Returns:
array: The output array.
)pbdoc");
}