in place ops behave in place, fix some overloads (#411)

This commit is contained in:
Awni Hannun 2024-01-09 16:05:38 -08:00 committed by GitHub
parent 961435a243
commit 1d90a76d63
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 202 additions and 19 deletions

View File

@ -711,6 +711,13 @@ void init_array(py::module_& m) {
return add(a, to_array(v, a.dtype()));
},
"other"_a)
.def(
"__iadd__",
[](array& a, const ScalarOrArray v) {
a.overwrite_descriptor(add(a, to_array(v, a.dtype())));
return a;
},
"other"_a)
.def(
"__radd__",
[](const array& a, const ScalarOrArray v) {
@ -723,6 +730,13 @@ void init_array(py::module_& m) {
return subtract(a, to_array(v, a.dtype()));
},
"other"_a)
.def(
"__isub__",
[](array& a, const ScalarOrArray v) {
a.overwrite_descriptor(subtract(a, to_array(v, a.dtype())));
return a;
},
"other"_a)
.def(
"__rsub__",
[](const array& a, const ScalarOrArray v) {
@ -735,6 +749,13 @@ void init_array(py::module_& m) {
return multiply(a, to_array(v, a.dtype()));
},
"other"_a)
.def(
"__imul__",
[](array& a, const ScalarOrArray v) {
a.overwrite_descriptor(multiply(a, to_array(v, a.dtype())));
return a;
},
"other"_a)
.def(
"__rmul__",
[](const array& a, const ScalarOrArray v) {
@ -748,16 +769,14 @@ void init_array(py::module_& m) {
},
"other"_a)
.def(
"__div__",
[](const array& a, const ScalarOrArray v) {
return divide(a, to_array(v, a.dtype()));
},
"other"_a)
.def(
"__floordiv__",
[](const array& a, const ScalarOrArray v) {
auto b = to_array(v, a.dtype());
return floor_divide(a, b);
"__itruediv__",
[](array& a, const ScalarOrArray v) {
if (!is_floating_point(a.dtype())) {
throw std::invalid_argument(
"In place division cannot cast to non-floating point type.");
}
a.overwrite_descriptor(divide(a, to_array(v, a.dtype())));
return a;
},
"other"_a)
.def(
@ -766,6 +785,31 @@ void init_array(py::module_& m) {
return divide(to_array(v, a.dtype()), a);
},
"other"_a)
.def(
"__div__",
[](const array& a, const ScalarOrArray v) {
return divide(a, to_array(v, a.dtype()));
},
"other"_a)
.def(
"__rdiv__",
[](const array& a, const ScalarOrArray v) {
return divide(to_array(v, a.dtype()), a);
},
"other"_a)
.def(
"__floordiv__",
[](const array& a, const ScalarOrArray v) {
return floor_divide(a, to_array(v, a.dtype()));
},
"other"_a)
.def(
"__ifloordiv__",
[](array& a, const ScalarOrArray v) {
a.overwrite_descriptor(floor_divide(a, to_array(v, a.dtype())));
return a;
},
"other"_a)
.def(
"__rfloordiv__",
[](const array& a, const ScalarOrArray v) {
@ -773,18 +817,19 @@ void init_array(py::module_& m) {
return floor_divide(b, a);
},
"other"_a)
.def(
"__rdiv__",
[](const array& a, const ScalarOrArray v) {
return divide(to_array(v, a.dtype()), a);
},
"other"_a)
.def(
"__mod__",
[](const array& a, const ScalarOrArray v) {
return remainder(a, to_array(v, a.dtype()));
},
"other"_a)
.def(
"__imod__",
[](array& a, const ScalarOrArray v) {
a.overwrite_descriptor(remainder(a, to_array(v, a.dtype())));
return a;
},
"other"_a)
.def(
"__rmod__",
[](const array& a, const ScalarOrArray v) {
@ -840,23 +885,102 @@ void init_array(py::module_& m) {
return os.str();
})
.def(
"__matmul__", [](array& a, array& other) { return matmul(a, other); })
"__matmul__",
[](const array& a, array& other) { return matmul(a, other); },
"other"_a)
.def(
"__imatmul__",
[](array& a, array& other) {
a.overwrite_descriptor(matmul(a, other));
return a;
},
"other"_a)
.def(
"__pow__",
[](const array& a, const ScalarOrArray v) {
return power(a, to_array(v, a.dtype()));
},
"other"_a)
.def(
"__ipow__",
[](array& a, const ScalarOrArray v) {
a.overwrite_descriptor(power(a, to_array(v, a.dtype())));
return a;
},
"other"_a)
.def(
"__invert__",
[](const array& a) {
if (is_floating_point(a.dtype())) {
throw std::invalid_argument(
"Floating point types not allowed with or bitwise inversion.");
}
if (a.dtype() != bool_) {
throw std::invalid_argument(
"Bitwise inversion not yet supported for integer types.");
}
return logical_not(a);
})
.def(
"__and__",
[](const array& a, const ScalarOrArray v) {
return logical_and(a, to_array(v, a.dtype()));
auto b = to_array(v, a.dtype());
if (is_floating_point(a.dtype()) || is_floating_point(b.dtype())) {
throw std::invalid_argument(
"Floating point types not allowed with bitwise and.");
}
if (a.dtype() != bool_ && b.dtype() != bool_) {
throw std::invalid_argument(
"Bitwise and not yet supported for integer types.");
}
return logical_and(a, b);
},
"other"_a)
.def(
"__iand__",
[](array& a, const ScalarOrArray v) {
auto b = to_array(v, a.dtype());
if (is_floating_point(a.dtype()) || is_floating_point(b.dtype())) {
throw std::invalid_argument(
"Floating point types not allowed with bitwise and.");
}
if (a.dtype() != bool_ && b.dtype() != bool_) {
throw std::invalid_argument(
"Bitwise and not yet supported for integer types.");
}
a.overwrite_descriptor(logical_and(a, b));
return a;
},
"other"_a)
.def(
"__or__",
[](const array& a, const ScalarOrArray v) {
return logical_or(a, to_array(v, a.dtype()));
auto b = to_array(v, a.dtype());
if (is_floating_point(a.dtype()) || is_floating_point(b.dtype())) {
throw std::invalid_argument(
"Floating point types not allowed with or bitwise or.");
}
if (a.dtype() != bool_ && b.dtype() != bool_) {
throw std::invalid_argument(
"Bitwise or not yet supported for integer types.");
}
return logical_or(a, b);
},
"other"_a)
.def(
"__ior__",
[](array& a, const ScalarOrArray v) {
auto b = to_array(v, a.dtype());
if (is_floating_point(a.dtype()) || is_floating_point(b.dtype())) {
throw std::invalid_argument(
"Floating point types not allowed with or bitwise or.");
}
if (a.dtype() != bool_ && b.dtype() != bool_) {
throw std::invalid_argument(
"Bitwise or not yet supported for integer types.");
}
a.overwrite_descriptor(logical_or(a, b));
return a;
},
"other"_a)
.def(

View File

@ -1313,6 +1313,65 @@ class TestArray(mlx_tests.MLXTestCase):
rtol=0,
)
def test_logical_overloads(self):
with self.assertRaises(ValueError):
mx.array(1.0) & mx.array(1)
with self.assertRaises(ValueError):
mx.array(1.0) | mx.array(1)
self.assertEqual((mx.array(True) & True).item(), True)
self.assertEqual((mx.array(True) & False).item(), False)
self.assertEqual((mx.array(True) | False).item(), True)
self.assertEqual((mx.array(False) | False).item(), False)
self.assertEqual((~mx.array(False)).item(), True)
def test_inplace(self):
iops = [
"__iadd__",
"__isub__",
"__imul__",
"__ifloordiv__",
"__imod__",
"__ipow__",
]
for op in iops:
a = mx.array([1, 2, 3])
a_np = np.array(a)
b = a
b = getattr(a, op)(3)
self.assertTrue(mx.array_equal(a, b))
out_np = getattr(a_np, op)(3)
self.assertTrue(np.array_equal(out_np, a))
with self.assertRaises(ValueError):
a = mx.array([1])
a /= 1
a = mx.array([2.0])
b = a
b /= 2
self.assertEqual(b.item(), 1.0)
self.assertEqual(b.item(), a.item())
a = mx.array(True)
b = a
b &= False
self.assertEqual(b.item(), False)
self.assertEqual(b.item(), a.item())
a = mx.array(False)
b = a
b |= True
self.assertEqual(b.item(), True)
self.assertEqual(b.item(), a.item())
# In-place matmul on its own
a = mx.array([[1.0, 2.0], [3.0, 4.0]])
b = a
b @= a
self.assertTrue(mx.array_equal(a, b))
if __name__ == "__main__":
unittest.main()