diff --git a/python/src/array.cpp b/python/src/array.cpp index 22ef8e273..143d2e6f5 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -28,30 +28,45 @@ class ArrayAt { public: ArrayAt(mx::array x) : x_(std::move(x)) {} ArrayAt& set_indices(nb::object indices) { + initialized_ = true; indices_ = indices; return *this; } + void check_initialized() { + if (!initialized_) { + throw std::invalid_argument( + "Must give indices to array.at (e.g. `x.at[0].add(4)`)."); + } + } + mx::array add(const ScalarOrArray& v) { + check_initialized(); return mlx_add_item(x_, indices_, v); } mx::array subtract(const ScalarOrArray& v) { + check_initialized(); return mlx_subtract_item(x_, indices_, v); } mx::array multiply(const ScalarOrArray& v) { + check_initialized(); return mlx_multiply_item(x_, indices_, v); } mx::array divide(const ScalarOrArray& v) { + check_initialized(); return mlx_divide_item(x_, indices_, v); } mx::array maximum(const ScalarOrArray& v) { + check_initialized(); return mlx_maximum_item(x_, indices_, v); } mx::array minimum(const ScalarOrArray& v) { + check_initialized(); return mlx_minimum_item(x_, indices_, v); } private: mx::array x_; + bool initialized_{false}; nb::object indices_; }; diff --git a/python/tests/test_array.py b/python/tests/test_array.py index 3ab41bef7..ae1cb784f 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -1365,6 +1365,9 @@ class TestArray(mlx_tests.MLXTestCase): def test_array_at(self): a = mx.array(1) + with self.assertRaises(ValueError): + a.at.add(1) + a = a.at[None].add(1) self.assertEqual(a.item(), 2)