* einsum initial

* fix comma break

* sum axis was wrong

* small cleanups

* python binding

* changed bindings to resemble numpy

* remove todo comment

* comment changes

* add count of operands/inputs

* fail fast if operands list is empty

* ignore comma if no output

* einsum path matching numpy

* getting somewhere with path

* remove print

* it passes the first test

* moved einsum tests to seperate file

* seperated einsum path

* moved einsum naive

* remove space from equation

* fast fail if no operands passed

* update tests and remove printf

* small cleanup

* some more cleanups

* removed python helper file

* ack

* utilize std for finding min in vector

* duplicate def

* remove the tuple as it was unreadable

* moved einsum_naive back to ops

* remaining isn't needed

* avoid creating another set

* cleanup

* greedy path, start of naive einsum

* more einsum

* fix some bugs

* some more fixes, tests pass

* benchmark

* some simplify

* fix einsum and test

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>

* add a bunch more tests and fix a bunch more bugs

* some docs nits

---------

Co-authored-by: dc-dc-dc <dgcruz983@gmail.com>
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
Awni Hannun
2024-07-25 09:36:44 -07:00
committed by GitHub
parent 7f914365fd
commit baf9fa5f42
13 changed files with 1498 additions and 65 deletions

View File

@@ -99,7 +99,7 @@ void init_random(nb::module_& parent_module) {
Args:
key (array): Input key to split.
num (int, optional): Number of sub keys. Default is 2.
num (int, optional): Number of sub keys. Default: ``2``.
Returns:
array: The array of sub keys with ``num`` as its first dimension.
@@ -137,11 +137,13 @@ void init_random(nb::module_& parent_module) {
broadcastable to ``shape``.
Args:
low (scalar or array, optional): Lower bound of the distribution. Default is ``0``.
high (scalar or array, optional): Upper bound of the distribution. Default is ``1``.
shape (list(int), optional): Shape of the output. Default is ``()``.
low (scalar or array, optional): Lower bound of the distribution.
Default: ``0``.
high (scalar or array, optional): Upper bound of the distribution.
Default: ``1``.
shape (list(int), optional): Shape of the output. Default:``()``.
dtype (Dtype, optional): Type of the output. Default: ``float32``.
key (array, optional): A PRNG key. Default: ``None``.
dtype (Dtype, optional): Type of the output. Default is ``float32``.
Returns:
array: The output array random values.
@@ -250,9 +252,9 @@ void init_random(nb::module_& parent_module) {
Args:
low (scalar or array): Lower bound of the interval.
high (scalar or array): Upper bound of the interval.
shape (list(int), optional): Shape of the output. Defaults to ``()``.
dtype (Dtype, optional): Type of the output. Defaults to ``int32``.
key (array, optional): A PRNG key. Default: None.
shape (list(int), optional): Shape of the output. Default: ``()``.
dtype (Dtype, optional): Type of the output. Default: ``int32``.
key (array, optional): A PRNG key. Default: ``None``.
Returns:
array: The array of random integers.
@@ -286,10 +288,10 @@ void init_random(nb::module_& parent_module) {
Args:
p (float or array, optional): Parameter of the Bernoulli
distribution. Default is 0.5.
shape (list(int), optional): Shape of the output. The default
shape is ``p.shape``.
key (array, optional): A PRNG key. Default: None.
distribution. Default: ``0.5``.
shape (list(int), optional): Shape of the output.
Default: ``p.shape``.
key (array, optional): A PRNG key. Default: ``None``.
Returns:
array: The array of random integers.
@@ -331,10 +333,10 @@ void init_random(nb::module_& parent_module) {
lower (scalar or array): Lower bound of the domain.
upper (scalar or array): Upper bound of the domain.
shape (list(int), optional): The shape of the output.
Default is ``()``.
Default:``()``.
dtype (Dtype, optional): The data type of the output.
Default is ``float32``.
key (array, optional): A PRNG key. Default: None.
Default: ``float32``.
key (array, optional): A PRNG key. Default: ``None``.
Returns:
array: The output array of random values.
@@ -362,7 +364,7 @@ void init_random(nb::module_& parent_module) {
Args:
shape (list(int)): The shape of the output.
key (array, optional): A PRNG key. Default: None.
key (array, optional): A PRNG key. Default: ``None``.
Returns:
array: The :class:`array` with shape ``shape`` and
@@ -407,14 +409,14 @@ void init_random(nb::module_& parent_module) {
Args:
logits (array): The *unnormalized* categorical distribution(s).
axis (int, optional): The axis which specifies the distribution.
Default is ``-1``.
Default: ``-1``.
shape (list(int), optional): The shape of the output. This must
be broadcast compatable with ``logits.shape`` with the ``axis``
dimension removed. Default: ``None``
num_samples (int, optional): The number of samples to draw from each
of the categorical distributions in ``logits``. The output will have
``num_samples`` in the last dimension. Default: ``None``.
key (array, optional): A PRNG key. Default: None.
key (array, optional): A PRNG key. Default: ``None``.
Returns:
array: The ``shape``-sized output array with type ``uint32``.
@@ -442,11 +444,12 @@ void init_random(nb::module_& parent_module) {
Sample numbers from a Laplace distribution.
Args:
shape (list(int), optional): Shape of the output. Default is ``()``.
dtype (Dtype, optional): Type of the output. Default is ``float32``.
loc (float, optional): Mean of the distribution. Default is ``0.0``.
scale (float, optional): The scale "b" of the Laplace distribution. Default is ``1.0``.
key (array, optional): A PRNG key. Default: None.
shape (list(int), optional): Shape of the output. Default: ``()``.
dtype (Dtype, optional): Type of the output. Default: ``float32``.
loc (float, optional): Mean of the distribution. Default: ``0.0``.
scale (float, optional): The scale "b" of the Laplace distribution.
Default:``1.0``.
key (array, optional): A PRNG key. Default: ``None``.
Returns:
array: The output array of random values.