mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-20 17:38:09 +08:00
Einsum (#1269)
* einsum initial * fix comma break * sum axis was wrong * small cleanups * python binding * changed bindings to resemble numpy * remove todo comment * comment changes * add count of operands/inputs * fail fast if operands list is empty * ignore comma if no output * einsum path matching numpy * getting somewhere with path * remove print * it passes the first test * moved einsum tests to seperate file * seperated einsum path * moved einsum naive * remove space from equation * fast fail if no operands passed * update tests and remove printf * small cleanup * some more cleanups * removed python helper file * ack * utilize std for finding min in vector * duplicate def * remove the tuple as it was unreadable * moved einsum_naive back to ops * remaining isn't needed * avoid creating another set * cleanup * greedy path, start of naive einsum * more einsum * fix some bugs * some more fixes, tests pass * benchmark * some simplify * fix einsum and test Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com> * add a bunch more tests and fix a bunch more bugs * some docs nits --------- Co-authored-by: dc-dc-dc <dgcruz983@gmail.com> Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
@@ -12,6 +12,7 @@
|
||||
#include <nanobind/stl/variant.h>
|
||||
#include <nanobind/stl/vector.h>
|
||||
|
||||
#include "mlx/einsum.h"
|
||||
#include "mlx/ops.h"
|
||||
#include "mlx/utils.h"
|
||||
#include "python/src/load.h"
|
||||
@@ -40,15 +41,6 @@ double scalar_to_double(Scalar s) {
|
||||
}
|
||||
|
||||
void init_ops(nb::module_& m) {
|
||||
// TODO, remove deprecation errors in a future release
|
||||
m.def("block_sparse_mm", [](nb::args, nb::kwargs) {
|
||||
throw std::invalid_argument(
|
||||
"block_sparse_mm is deprecated. Please use gather_mm which has the same signature");
|
||||
});
|
||||
m.def("block_sparse_qmm", [](nb::args, nb::kwargs) {
|
||||
throw std::invalid_argument(
|
||||
"block_sparse_qmm is deprecated. Please use gather_qmm which has the same signature");
|
||||
});
|
||||
m.def(
|
||||
"reshape",
|
||||
&reshape,
|
||||
@@ -1238,7 +1230,8 @@ void init_ops(nb::module_& m) {
|
||||
a (array): Input array.
|
||||
|
||||
Returns:
|
||||
array: The unchanged input ``a`` but without gradient flowing
|
||||
array:
|
||||
The unchanged input ``a`` but without gradient flowing
|
||||
through it.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
@@ -2936,6 +2929,9 @@ void init_ops(nb::module_& m) {
|
||||
reverse (bool): Perform the cumulative sum in reverse.
|
||||
inclusive (bool): The i-th element of the output includes the i-th
|
||||
element of the input.
|
||||
|
||||
Returns:
|
||||
array: The output array.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"cumprod",
|
||||
@@ -2969,6 +2965,9 @@ void init_ops(nb::module_& m) {
|
||||
reverse (bool): Perform the cumulative product in reverse.
|
||||
inclusive (bool): The i-th element of the output includes the i-th
|
||||
element of the input.
|
||||
|
||||
Returns:
|
||||
array: The output array.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"cummax",
|
||||
@@ -3002,6 +3001,9 @@ void init_ops(nb::module_& m) {
|
||||
reverse (bool): Perform the cumulative maximum in reverse.
|
||||
inclusive (bool): The i-th element of the output includes the i-th
|
||||
element of the input.
|
||||
|
||||
Returns:
|
||||
array: The output array.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"cummin",
|
||||
@@ -3035,6 +3037,9 @@ void init_ops(nb::module_& m) {
|
||||
reverse (bool): Perform the cumulative minimum in reverse.
|
||||
inclusive (bool): The i-th element of the output includes the i-th
|
||||
element of the input.
|
||||
|
||||
Returns:
|
||||
array: The output array.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"conj",
|
||||
@@ -3052,6 +3057,9 @@ void init_ops(nb::module_& m) {
|
||||
|
||||
Args:
|
||||
a (array): Input array
|
||||
|
||||
Returns:
|
||||
array: The output array.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"conjugate",
|
||||
@@ -3069,6 +3077,9 @@ void init_ops(nb::module_& m) {
|
||||
|
||||
Args:
|
||||
a (array): Input array
|
||||
|
||||
Returns:
|
||||
array: The output array.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"convolve",
|
||||
@@ -3492,14 +3503,11 @@ void init_ops(nb::module_& m) {
|
||||
Args:
|
||||
file (file, str): File in which the array is saved.
|
||||
format (str, optional): Format of the file. If ``None``, the
|
||||
format
|
||||
is inferred from the file extension. Supported formats:
|
||||
``npy``,
|
||||
``npz``, and ``safetensors``. Default: ``None``.
|
||||
format is inferred from the file extension. Supported formats:
|
||||
``npy``, ``npz``, and ``safetensors``. Default: ``None``.
|
||||
return_metadata (bool, optional): Load the metadata for formats
|
||||
which
|
||||
support matadata. The metadata will be returned as an
|
||||
additional dictionary.
|
||||
which support matadata. The metadata will be returned as an
|
||||
additional dictionary. Default: ``False``.
|
||||
Returns:
|
||||
array or dict:
|
||||
A single array if loading from a ``.npy`` file or a dict
|
||||
@@ -3551,9 +3559,9 @@ void init_ops(nb::module_& m) {
|
||||
Args:
|
||||
file (file, str): File in which the array is saved.
|
||||
arrays (dict(str, array)): The dictionary of names to arrays to
|
||||
be saved. metadata (dict(str, Union[array, str, list(str)])):
|
||||
The dictionary of
|
||||
metadata to be saved. The values can be a scalar or 1D
|
||||
be saved.
|
||||
metadata (dict(str, Union[array, str, list(str)])): The dictionary
|
||||
of metadata to be saved. The values can be a scalar or 1D
|
||||
obj:`array`, a :obj:`str`, or a :obj:`list` of :obj:`str`.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
@@ -3643,11 +3651,11 @@ void init_ops(nb::module_& m) {
|
||||
biases (array): The biases to use per ``group_size`` elements of ``w``
|
||||
transpose (bool, optional): Defines whether to multiply with the
|
||||
transposed ``w`` or not, namely whether we are performing
|
||||
``x @ w.T`` or ``x @ w``. (default: ``True``)
|
||||
``x @ w.T`` or ``x @ w``. Default: ``True``.
|
||||
group_size (int, optional): The size of the group in ``w`` that
|
||||
shares a scale and bias. (default: ``64``)
|
||||
shares a scale and bias. Default: ``64``.
|
||||
bits (int, optional): The number of bits occupied by each element in
|
||||
``w``. (default: ``4``)
|
||||
``w``. Default: ``4``.
|
||||
|
||||
Returns:
|
||||
array: The result of the multiplication of ``x`` with ``w``.
|
||||
@@ -3700,9 +3708,9 @@ void init_ops(nb::module_& m) {
|
||||
Args:
|
||||
w (array): Matrix to be quantized
|
||||
group_size (int, optional): The size of the group in ``w`` that shares a
|
||||
scale and bias. (default: ``64``)
|
||||
scale and bias. Default: ``64``.
|
||||
bits (int, optional): The number of bits occupied by each element of
|
||||
``w`` in the returned quantized matrix. (default: ``4``)
|
||||
``w`` in the returned quantized matrix. Default: ``4``.
|
||||
|
||||
Returns:
|
||||
tuple: A tuple containing
|
||||
@@ -3740,9 +3748,9 @@ void init_ops(nb::module_& m) {
|
||||
scales (array): The scales to use per ``group_size`` elements of ``w``
|
||||
biases (array): The biases to use per ``group_size`` elements of ``w``
|
||||
group_size (int, optional): The size of the group in ``w`` that shares a
|
||||
scale and bias. (default: ``64``)
|
||||
scale and bias. Default: ``64``.
|
||||
bits (int, optional): The number of bits occupied by each element in
|
||||
``w``. (default: ``4``)
|
||||
``w``. Default: ``4``.
|
||||
|
||||
Returns:
|
||||
array: The dequantized version of ``w``
|
||||
@@ -3779,15 +3787,15 @@ void init_ops(nb::module_& m) {
|
||||
w (array): Quantized matrix packed in unsigned integers
|
||||
scales (array): The scales to use per ``group_size`` elements of ``w``
|
||||
biases (array): The biases to use per ``group_size`` elements of ``w``
|
||||
lhs_indices (array, optional): Integer indices for ``x`` (default: ``None``)
|
||||
rhs_indices (array, optional): Integer indices for ``w`` (default: ``None``)
|
||||
lhs_indices (array, optional): Integer indices for ``x``. Default: ``None``.
|
||||
rhs_indices (array, optional): Integer indices for ``w``. Default: ``None``.
|
||||
transpose (bool, optional): Defines whether to multiply with the
|
||||
transposed ``w`` or not, namely whether we are performing
|
||||
``x @ w.T`` or ``x @ w``. (default: ``True``)
|
||||
``x @ w.T`` or ``x @ w``. Default: ``True``.
|
||||
group_size (int, optional): The size of the group in ``w`` that
|
||||
shares a scale and bias. (default: ``64``)
|
||||
shares a scale and bias. Default: ``64``.
|
||||
bits (int, optional): The number of bits occupied by each element in
|
||||
``w``. (default: ``4``)
|
||||
``w``. Default: ``4``.
|
||||
|
||||
Returns:
|
||||
array: The result of the multiplication of ``x`` with ``w``
|
||||
@@ -3827,7 +3835,7 @@ void init_ops(nb::module_& m) {
|
||||
sum over. If an integer is provided, then sum over the last
|
||||
``axes`` dimensions of ``a`` and the first ``axes`` dimensions of
|
||||
``b``. If a list of lists is provided, then sum over the
|
||||
corresponding dimensions of ``a`` and ``b``. (default: 2)
|
||||
corresponding dimensions of ``a`` and ``b``. Default: 2.
|
||||
|
||||
Returns:
|
||||
array: The tensor dot product.
|
||||
@@ -3958,11 +3966,13 @@ void init_ops(nb::module_& m) {
|
||||
Args:
|
||||
a (array): Input array or scalar.
|
||||
b (array): Input array or scalar.
|
||||
block_size (int): Size of blocks to be masked. Must be ``32`` or ``64`` (default: ``64``)
|
||||
mask_out (array, optional): Mask for output (default: ``None``)
|
||||
mask_lhs (array, optional): Mask for a (default: ``None``)
|
||||
mask_rhs (array, optional): Mask for b (default: ``None``)
|
||||
block_size (int): Size of blocks to be masked. Must be ``32`` or ``64``. Default: ``64``.
|
||||
mask_out (array, optional): Mask for output. Default: ``None``.
|
||||
mask_lhs (array, optional): Mask for ``a``. Default: ``None``.
|
||||
mask_rhs (array, optional): Mask for ``b``. Default: ``None``.
|
||||
|
||||
Returns:
|
||||
array: The output array.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"gather_mm",
|
||||
@@ -3996,9 +4006,11 @@ void init_ops(nb::module_& m) {
|
||||
Args:
|
||||
a (array): Input array.
|
||||
b (array): Input array.
|
||||
lhs_indices (array, optional): Integer indices for ``a`` (default: ``None``)
|
||||
rhs_indices (array, optional): Integer indices for ``b`` (default: ``None``)
|
||||
lhs_indices (array, optional): Integer indices for ``a``. Default: ``None``
|
||||
rhs_indices (array, optional): Integer indices for ``b``. Default: ``None``
|
||||
|
||||
Returns:
|
||||
array: The output array.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"diagonal",
|
||||
@@ -4406,4 +4418,57 @@ void init_ops(nb::module_& m) {
|
||||
Returns:
|
||||
array: The transformed array.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"einsum_path",
|
||||
[](const std::string& equation, const nb::args& operands) {
|
||||
auto arrays_list = nb::cast<std::vector<array>>(operands);
|
||||
auto [path, str] = einsum_path(equation, arrays_list);
|
||||
// Convert to list of tuples
|
||||
std::vector<nb::tuple> tuple_path;
|
||||
for (auto& p : path) {
|
||||
tuple_path.push_back(nb::tuple(nb::cast(p)));
|
||||
}
|
||||
return std::make_pair(tuple_path, str);
|
||||
},
|
||||
"subscripts"_a,
|
||||
"operands"_a,
|
||||
nb::sig("def einsum_path(subscripts: str, *operands)"),
|
||||
R"pbdoc(
|
||||
|
||||
Compute the contraction order for the given Einstein summation.
|
||||
|
||||
Args:
|
||||
subscripts (str): The Einstein summation convention equation.
|
||||
*operands (array): The input arrays.
|
||||
|
||||
Returns:
|
||||
tuple(list(tuple(int, int)), str):
|
||||
The einsum path and a string containing information about the
|
||||
chosen path.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"einsum",
|
||||
[](const std::string& subscripts,
|
||||
const nb::args& operands,
|
||||
StreamOrDevice s) {
|
||||
auto arrays_list = nb::cast<std::vector<array>>(operands);
|
||||
return einsum(subscripts, arrays_list, s);
|
||||
},
|
||||
"subscripts"_a,
|
||||
"operands"_a,
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def einsum(subscripts: str, *operands, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
|
||||
Perform the Einstein summation convention on the operands.
|
||||
|
||||
Args:
|
||||
subscripts (str): The Einstein summation convention equation.
|
||||
*operands (array): The input arrays.
|
||||
|
||||
Returns:
|
||||
array: The output array.
|
||||
)pbdoc");
|
||||
}
|
||||
|
Reference in New Issue
Block a user