improved error msg for invalid axis(mx.split) (#685)

* improved error msg for invalid axis(`mx.split`)

* Apply suggestions from code review

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* fixed formatting issue

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
This commit is contained in:
toji
2024-02-15 20:55:38 +05:30
committed by GitHub
parent 35431a4ac8
commit 85143fecdd
2 changed files with 10 additions and 0 deletions

View File

@@ -622,6 +622,13 @@ std::vector<array> split(
std::vector<array>
split(const array& a, int num_splits, int axis, StreamOrDevice s /* = {} */) {
auto ax = axis < 0 ? axis + a.ndim() : axis;
if (ax < 0 || ax >= a.ndim()) {
std::ostringstream msg;
msg << "Invalid axis " << axis << " passed to split"
<< " for array with shape " << a.shape() << ".";
throw std::invalid_argument(msg.str());
}
auto q_and_r = std::ldiv(a.shape(axis), num_splits);
if (q_and_r.rem) {
std::ostringstream msg;