mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-15 17:39:05 +08:00
Einsum (#1269)
* einsum initial * fix comma break * sum axis was wrong * small cleanups * python binding * changed bindings to resemble numpy * remove todo comment * comment changes * add count of operands/inputs * fail fast if operands list is empty * ignore comma if no output * einsum path matching numpy * getting somewhere with path * remove print * it passes the first test * moved einsum tests to seperate file * seperated einsum path * moved einsum naive * remove space from equation * fast fail if no operands passed * update tests and remove printf * small cleanup * some more cleanups * removed python helper file * ack * utilize std for finding min in vector * duplicate def * remove the tuple as it was unreadable * moved einsum_naive back to ops * remaining isn't needed * avoid creating another set * cleanup * greedy path, start of naive einsum * more einsum * fix some bugs * some more fixes, tests pass * benchmark * some simplify * fix einsum and test Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com> * add a bunch more tests and fix a bunch more bugs * some docs nits --------- Co-authored-by: dc-dc-dc <dgcruz983@gmail.com> Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
@@ -99,7 +99,7 @@ void init_random(nb::module_& parent_module) {
|
||||
|
||||
Args:
|
||||
key (array): Input key to split.
|
||||
num (int, optional): Number of sub keys. Default is 2.
|
||||
num (int, optional): Number of sub keys. Default: ``2``.
|
||||
|
||||
Returns:
|
||||
array: The array of sub keys with ``num`` as its first dimension.
|
||||
@@ -137,11 +137,13 @@ void init_random(nb::module_& parent_module) {
|
||||
broadcastable to ``shape``.
|
||||
|
||||
Args:
|
||||
low (scalar or array, optional): Lower bound of the distribution. Default is ``0``.
|
||||
high (scalar or array, optional): Upper bound of the distribution. Default is ``1``.
|
||||
shape (list(int), optional): Shape of the output. Default is ``()``.
|
||||
low (scalar or array, optional): Lower bound of the distribution.
|
||||
Default: ``0``.
|
||||
high (scalar or array, optional): Upper bound of the distribution.
|
||||
Default: ``1``.
|
||||
shape (list(int), optional): Shape of the output. Default:``()``.
|
||||
dtype (Dtype, optional): Type of the output. Default: ``float32``.
|
||||
key (array, optional): A PRNG key. Default: ``None``.
|
||||
dtype (Dtype, optional): Type of the output. Default is ``float32``.
|
||||
|
||||
Returns:
|
||||
array: The output array random values.
|
||||
@@ -250,9 +252,9 @@ void init_random(nb::module_& parent_module) {
|
||||
Args:
|
||||
low (scalar or array): Lower bound of the interval.
|
||||
high (scalar or array): Upper bound of the interval.
|
||||
shape (list(int), optional): Shape of the output. Defaults to ``()``.
|
||||
dtype (Dtype, optional): Type of the output. Defaults to ``int32``.
|
||||
key (array, optional): A PRNG key. Default: None.
|
||||
shape (list(int), optional): Shape of the output. Default: ``()``.
|
||||
dtype (Dtype, optional): Type of the output. Default: ``int32``.
|
||||
key (array, optional): A PRNG key. Default: ``None``.
|
||||
|
||||
Returns:
|
||||
array: The array of random integers.
|
||||
@@ -286,10 +288,10 @@ void init_random(nb::module_& parent_module) {
|
||||
|
||||
Args:
|
||||
p (float or array, optional): Parameter of the Bernoulli
|
||||
distribution. Default is 0.5.
|
||||
shape (list(int), optional): Shape of the output. The default
|
||||
shape is ``p.shape``.
|
||||
key (array, optional): A PRNG key. Default: None.
|
||||
distribution. Default: ``0.5``.
|
||||
shape (list(int), optional): Shape of the output.
|
||||
Default: ``p.shape``.
|
||||
key (array, optional): A PRNG key. Default: ``None``.
|
||||
|
||||
Returns:
|
||||
array: The array of random integers.
|
||||
@@ -331,10 +333,10 @@ void init_random(nb::module_& parent_module) {
|
||||
lower (scalar or array): Lower bound of the domain.
|
||||
upper (scalar or array): Upper bound of the domain.
|
||||
shape (list(int), optional): The shape of the output.
|
||||
Default is ``()``.
|
||||
Default:``()``.
|
||||
dtype (Dtype, optional): The data type of the output.
|
||||
Default is ``float32``.
|
||||
key (array, optional): A PRNG key. Default: None.
|
||||
Default: ``float32``.
|
||||
key (array, optional): A PRNG key. Default: ``None``.
|
||||
|
||||
Returns:
|
||||
array: The output array of random values.
|
||||
@@ -362,7 +364,7 @@ void init_random(nb::module_& parent_module) {
|
||||
|
||||
Args:
|
||||
shape (list(int)): The shape of the output.
|
||||
key (array, optional): A PRNG key. Default: None.
|
||||
key (array, optional): A PRNG key. Default: ``None``.
|
||||
|
||||
Returns:
|
||||
array: The :class:`array` with shape ``shape`` and
|
||||
@@ -407,14 +409,14 @@ void init_random(nb::module_& parent_module) {
|
||||
Args:
|
||||
logits (array): The *unnormalized* categorical distribution(s).
|
||||
axis (int, optional): The axis which specifies the distribution.
|
||||
Default is ``-1``.
|
||||
Default: ``-1``.
|
||||
shape (list(int), optional): The shape of the output. This must
|
||||
be broadcast compatable with ``logits.shape`` with the ``axis``
|
||||
dimension removed. Default: ``None``
|
||||
num_samples (int, optional): The number of samples to draw from each
|
||||
of the categorical distributions in ``logits``. The output will have
|
||||
``num_samples`` in the last dimension. Default: ``None``.
|
||||
key (array, optional): A PRNG key. Default: None.
|
||||
key (array, optional): A PRNG key. Default: ``None``.
|
||||
|
||||
Returns:
|
||||
array: The ``shape``-sized output array with type ``uint32``.
|
||||
@@ -442,11 +444,12 @@ void init_random(nb::module_& parent_module) {
|
||||
Sample numbers from a Laplace distribution.
|
||||
|
||||
Args:
|
||||
shape (list(int), optional): Shape of the output. Default is ``()``.
|
||||
dtype (Dtype, optional): Type of the output. Default is ``float32``.
|
||||
loc (float, optional): Mean of the distribution. Default is ``0.0``.
|
||||
scale (float, optional): The scale "b" of the Laplace distribution. Default is ``1.0``.
|
||||
key (array, optional): A PRNG key. Default: None.
|
||||
shape (list(int), optional): Shape of the output. Default: ``()``.
|
||||
dtype (Dtype, optional): Type of the output. Default: ``float32``.
|
||||
loc (float, optional): Mean of the distribution. Default: ``0.0``.
|
||||
scale (float, optional): The scale "b" of the Laplace distribution.
|
||||
Default:``1.0``.
|
||||
key (array, optional): A PRNG key. Default: ``None``.
|
||||
|
||||
Returns:
|
||||
array: The output array of random values.
|
||||
|
||||
Reference in New Issue
Block a user