in place ops behave in place, fix some overloads (#411)

This commit is contained in:
Awni Hannun 2024-01-09 16:05:38 -08:00 committed by GitHub
parent 961435a243
commit 1d90a76d63
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 202 additions and 19 deletions

View File

@ -711,6 +711,13 @@ void init_array(py::module_& m) {
return add(a, to_array(v, a.dtype())); return add(a, to_array(v, a.dtype()));
}, },
"other"_a) "other"_a)
.def(
"__iadd__",
[](array& a, const ScalarOrArray v) {
a.overwrite_descriptor(add(a, to_array(v, a.dtype())));
return a;
},
"other"_a)
.def( .def(
"__radd__", "__radd__",
[](const array& a, const ScalarOrArray v) { [](const array& a, const ScalarOrArray v) {
@ -723,6 +730,13 @@ void init_array(py::module_& m) {
return subtract(a, to_array(v, a.dtype())); return subtract(a, to_array(v, a.dtype()));
}, },
"other"_a) "other"_a)
.def(
"__isub__",
[](array& a, const ScalarOrArray v) {
a.overwrite_descriptor(subtract(a, to_array(v, a.dtype())));
return a;
},
"other"_a)
.def( .def(
"__rsub__", "__rsub__",
[](const array& a, const ScalarOrArray v) { [](const array& a, const ScalarOrArray v) {
@ -735,6 +749,13 @@ void init_array(py::module_& m) {
return multiply(a, to_array(v, a.dtype())); return multiply(a, to_array(v, a.dtype()));
}, },
"other"_a) "other"_a)
.def(
"__imul__",
[](array& a, const ScalarOrArray v) {
a.overwrite_descriptor(multiply(a, to_array(v, a.dtype())));
return a;
},
"other"_a)
.def( .def(
"__rmul__", "__rmul__",
[](const array& a, const ScalarOrArray v) { [](const array& a, const ScalarOrArray v) {
@ -748,16 +769,14 @@ void init_array(py::module_& m) {
}, },
"other"_a) "other"_a)
.def( .def(
"__div__", "__itruediv__",
[](const array& a, const ScalarOrArray v) { [](array& a, const ScalarOrArray v) {
return divide(a, to_array(v, a.dtype())); if (!is_floating_point(a.dtype())) {
}, throw std::invalid_argument(
"other"_a) "In place division cannot cast to non-floating point type.");
.def( }
"__floordiv__", a.overwrite_descriptor(divide(a, to_array(v, a.dtype())));
[](const array& a, const ScalarOrArray v) { return a;
auto b = to_array(v, a.dtype());
return floor_divide(a, b);
}, },
"other"_a) "other"_a)
.def( .def(
@ -766,6 +785,31 @@ void init_array(py::module_& m) {
return divide(to_array(v, a.dtype()), a); return divide(to_array(v, a.dtype()), a);
}, },
"other"_a) "other"_a)
.def(
"__div__",
[](const array& a, const ScalarOrArray v) {
return divide(a, to_array(v, a.dtype()));
},
"other"_a)
.def(
"__rdiv__",
[](const array& a, const ScalarOrArray v) {
return divide(to_array(v, a.dtype()), a);
},
"other"_a)
.def(
"__floordiv__",
[](const array& a, const ScalarOrArray v) {
return floor_divide(a, to_array(v, a.dtype()));
},
"other"_a)
.def(
"__ifloordiv__",
[](array& a, const ScalarOrArray v) {
a.overwrite_descriptor(floor_divide(a, to_array(v, a.dtype())));
return a;
},
"other"_a)
.def( .def(
"__rfloordiv__", "__rfloordiv__",
[](const array& a, const ScalarOrArray v) { [](const array& a, const ScalarOrArray v) {
@ -773,18 +817,19 @@ void init_array(py::module_& m) {
return floor_divide(b, a); return floor_divide(b, a);
}, },
"other"_a) "other"_a)
.def(
"__rdiv__",
[](const array& a, const ScalarOrArray v) {
return divide(to_array(v, a.dtype()), a);
},
"other"_a)
.def( .def(
"__mod__", "__mod__",
[](const array& a, const ScalarOrArray v) { [](const array& a, const ScalarOrArray v) {
return remainder(a, to_array(v, a.dtype())); return remainder(a, to_array(v, a.dtype()));
}, },
"other"_a) "other"_a)
.def(
"__imod__",
[](array& a, const ScalarOrArray v) {
a.overwrite_descriptor(remainder(a, to_array(v, a.dtype())));
return a;
},
"other"_a)
.def( .def(
"__rmod__", "__rmod__",
[](const array& a, const ScalarOrArray v) { [](const array& a, const ScalarOrArray v) {
@ -840,23 +885,102 @@ void init_array(py::module_& m) {
return os.str(); return os.str();
}) })
.def( .def(
"__matmul__", [](array& a, array& other) { return matmul(a, other); }) "__matmul__",
[](const array& a, array& other) { return matmul(a, other); },
"other"_a)
.def(
"__imatmul__",
[](array& a, array& other) {
a.overwrite_descriptor(matmul(a, other));
return a;
},
"other"_a)
.def( .def(
"__pow__", "__pow__",
[](const array& a, const ScalarOrArray v) { [](const array& a, const ScalarOrArray v) {
return power(a, to_array(v, a.dtype())); return power(a, to_array(v, a.dtype()));
}, },
"other"_a) "other"_a)
.def(
"__ipow__",
[](array& a, const ScalarOrArray v) {
a.overwrite_descriptor(power(a, to_array(v, a.dtype())));
return a;
},
"other"_a)
.def(
"__invert__",
[](const array& a) {
if (is_floating_point(a.dtype())) {
throw std::invalid_argument(
"Floating point types not allowed with or bitwise inversion.");
}
if (a.dtype() != bool_) {
throw std::invalid_argument(
"Bitwise inversion not yet supported for integer types.");
}
return logical_not(a);
})
.def( .def(
"__and__", "__and__",
[](const array& a, const ScalarOrArray v) { [](const array& a, const ScalarOrArray v) {
return logical_and(a, to_array(v, a.dtype())); auto b = to_array(v, a.dtype());
if (is_floating_point(a.dtype()) || is_floating_point(b.dtype())) {
throw std::invalid_argument(
"Floating point types not allowed with bitwise and.");
}
if (a.dtype() != bool_ && b.dtype() != bool_) {
throw std::invalid_argument(
"Bitwise and not yet supported for integer types.");
}
return logical_and(a, b);
},
"other"_a)
.def(
"__iand__",
[](array& a, const ScalarOrArray v) {
auto b = to_array(v, a.dtype());
if (is_floating_point(a.dtype()) || is_floating_point(b.dtype())) {
throw std::invalid_argument(
"Floating point types not allowed with bitwise and.");
}
if (a.dtype() != bool_ && b.dtype() != bool_) {
throw std::invalid_argument(
"Bitwise and not yet supported for integer types.");
}
a.overwrite_descriptor(logical_and(a, b));
return a;
}, },
"other"_a) "other"_a)
.def( .def(
"__or__", "__or__",
[](const array& a, const ScalarOrArray v) { [](const array& a, const ScalarOrArray v) {
return logical_or(a, to_array(v, a.dtype())); auto b = to_array(v, a.dtype());
if (is_floating_point(a.dtype()) || is_floating_point(b.dtype())) {
throw std::invalid_argument(
"Floating point types not allowed with or bitwise or.");
}
if (a.dtype() != bool_ && b.dtype() != bool_) {
throw std::invalid_argument(
"Bitwise or not yet supported for integer types.");
}
return logical_or(a, b);
},
"other"_a)
.def(
"__ior__",
[](array& a, const ScalarOrArray v) {
auto b = to_array(v, a.dtype());
if (is_floating_point(a.dtype()) || is_floating_point(b.dtype())) {
throw std::invalid_argument(
"Floating point types not allowed with or bitwise or.");
}
if (a.dtype() != bool_ && b.dtype() != bool_) {
throw std::invalid_argument(
"Bitwise or not yet supported for integer types.");
}
a.overwrite_descriptor(logical_or(a, b));
return a;
}, },
"other"_a) "other"_a)
.def( .def(

View File

@ -1313,6 +1313,65 @@ class TestArray(mlx_tests.MLXTestCase):
rtol=0, rtol=0,
) )
def test_logical_overloads(self):
with self.assertRaises(ValueError):
mx.array(1.0) & mx.array(1)
with self.assertRaises(ValueError):
mx.array(1.0) | mx.array(1)
self.assertEqual((mx.array(True) & True).item(), True)
self.assertEqual((mx.array(True) & False).item(), False)
self.assertEqual((mx.array(True) | False).item(), True)
self.assertEqual((mx.array(False) | False).item(), False)
self.assertEqual((~mx.array(False)).item(), True)
def test_inplace(self):
iops = [
"__iadd__",
"__isub__",
"__imul__",
"__ifloordiv__",
"__imod__",
"__ipow__",
]
for op in iops:
a = mx.array([1, 2, 3])
a_np = np.array(a)
b = a
b = getattr(a, op)(3)
self.assertTrue(mx.array_equal(a, b))
out_np = getattr(a_np, op)(3)
self.assertTrue(np.array_equal(out_np, a))
with self.assertRaises(ValueError):
a = mx.array([1])
a /= 1
a = mx.array([2.0])
b = a
b /= 2
self.assertEqual(b.item(), 1.0)
self.assertEqual(b.item(), a.item())
a = mx.array(True)
b = a
b &= False
self.assertEqual(b.item(), False)
self.assertEqual(b.item(), a.item())
a = mx.array(False)
b = a
b |= True
self.assertEqual(b.item(), True)
self.assertEqual(b.item(), a.item())
# In-place matmul on its own
a = mx.array([[1.0, 2.0], [3.0, 4.0]])
b = a
b @= a
self.assertTrue(mx.array_equal(a, b))
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()