Implement diagonal operator (#562)

* Implement diagonal operator

This implements mx.diagonal in operator level, inspired by
@ManishAradwad.

* added `mx.diag` with tests

* corrected few things

* nits in bindings

* updates to diag

---------

Co-authored-by: ManishAradwad <manisharadwad@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Jacket
2024-01-30 11:45:48 -06:00
committed by GitHub
parent 65d0b8df9f
commit 3f7aba8498
8 changed files with 309 additions and 4 deletions

View File

@@ -395,7 +395,7 @@ class array {
// The ArrayDesc contains the details of the materialized array including the
// shape, strides, the data type. It also includes
// the primitive which knows how to compute the array's data from its inputs
// and a the list of array's inputs for the primitive.
// and the list of array's inputs for the primitive.
std::shared_ptr<ArrayDesc> array_desc_{nullptr};
};