mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 09:51:17 +08:00
in place ops behave in place, fix some overloads (#411)
This commit is contained in:
parent
961435a243
commit
1d90a76d63
@ -711,6 +711,13 @@ void init_array(py::module_& m) {
|
|||||||
return add(a, to_array(v, a.dtype()));
|
return add(a, to_array(v, a.dtype()));
|
||||||
},
|
},
|
||||||
"other"_a)
|
"other"_a)
|
||||||
|
.def(
|
||||||
|
"__iadd__",
|
||||||
|
[](array& a, const ScalarOrArray v) {
|
||||||
|
a.overwrite_descriptor(add(a, to_array(v, a.dtype())));
|
||||||
|
return a;
|
||||||
|
},
|
||||||
|
"other"_a)
|
||||||
.def(
|
.def(
|
||||||
"__radd__",
|
"__radd__",
|
||||||
[](const array& a, const ScalarOrArray v) {
|
[](const array& a, const ScalarOrArray v) {
|
||||||
@ -723,6 +730,13 @@ void init_array(py::module_& m) {
|
|||||||
return subtract(a, to_array(v, a.dtype()));
|
return subtract(a, to_array(v, a.dtype()));
|
||||||
},
|
},
|
||||||
"other"_a)
|
"other"_a)
|
||||||
|
.def(
|
||||||
|
"__isub__",
|
||||||
|
[](array& a, const ScalarOrArray v) {
|
||||||
|
a.overwrite_descriptor(subtract(a, to_array(v, a.dtype())));
|
||||||
|
return a;
|
||||||
|
},
|
||||||
|
"other"_a)
|
||||||
.def(
|
.def(
|
||||||
"__rsub__",
|
"__rsub__",
|
||||||
[](const array& a, const ScalarOrArray v) {
|
[](const array& a, const ScalarOrArray v) {
|
||||||
@ -735,6 +749,13 @@ void init_array(py::module_& m) {
|
|||||||
return multiply(a, to_array(v, a.dtype()));
|
return multiply(a, to_array(v, a.dtype()));
|
||||||
},
|
},
|
||||||
"other"_a)
|
"other"_a)
|
||||||
|
.def(
|
||||||
|
"__imul__",
|
||||||
|
[](array& a, const ScalarOrArray v) {
|
||||||
|
a.overwrite_descriptor(multiply(a, to_array(v, a.dtype())));
|
||||||
|
return a;
|
||||||
|
},
|
||||||
|
"other"_a)
|
||||||
.def(
|
.def(
|
||||||
"__rmul__",
|
"__rmul__",
|
||||||
[](const array& a, const ScalarOrArray v) {
|
[](const array& a, const ScalarOrArray v) {
|
||||||
@ -748,16 +769,14 @@ void init_array(py::module_& m) {
|
|||||||
},
|
},
|
||||||
"other"_a)
|
"other"_a)
|
||||||
.def(
|
.def(
|
||||||
"__div__",
|
"__itruediv__",
|
||||||
[](const array& a, const ScalarOrArray v) {
|
[](array& a, const ScalarOrArray v) {
|
||||||
return divide(a, to_array(v, a.dtype()));
|
if (!is_floating_point(a.dtype())) {
|
||||||
},
|
throw std::invalid_argument(
|
||||||
"other"_a)
|
"In place division cannot cast to non-floating point type.");
|
||||||
.def(
|
}
|
||||||
"__floordiv__",
|
a.overwrite_descriptor(divide(a, to_array(v, a.dtype())));
|
||||||
[](const array& a, const ScalarOrArray v) {
|
return a;
|
||||||
auto b = to_array(v, a.dtype());
|
|
||||||
return floor_divide(a, b);
|
|
||||||
},
|
},
|
||||||
"other"_a)
|
"other"_a)
|
||||||
.def(
|
.def(
|
||||||
@ -766,6 +785,31 @@ void init_array(py::module_& m) {
|
|||||||
return divide(to_array(v, a.dtype()), a);
|
return divide(to_array(v, a.dtype()), a);
|
||||||
},
|
},
|
||||||
"other"_a)
|
"other"_a)
|
||||||
|
.def(
|
||||||
|
"__div__",
|
||||||
|
[](const array& a, const ScalarOrArray v) {
|
||||||
|
return divide(a, to_array(v, a.dtype()));
|
||||||
|
},
|
||||||
|
"other"_a)
|
||||||
|
.def(
|
||||||
|
"__rdiv__",
|
||||||
|
[](const array& a, const ScalarOrArray v) {
|
||||||
|
return divide(to_array(v, a.dtype()), a);
|
||||||
|
},
|
||||||
|
"other"_a)
|
||||||
|
.def(
|
||||||
|
"__floordiv__",
|
||||||
|
[](const array& a, const ScalarOrArray v) {
|
||||||
|
return floor_divide(a, to_array(v, a.dtype()));
|
||||||
|
},
|
||||||
|
"other"_a)
|
||||||
|
.def(
|
||||||
|
"__ifloordiv__",
|
||||||
|
[](array& a, const ScalarOrArray v) {
|
||||||
|
a.overwrite_descriptor(floor_divide(a, to_array(v, a.dtype())));
|
||||||
|
return a;
|
||||||
|
},
|
||||||
|
"other"_a)
|
||||||
.def(
|
.def(
|
||||||
"__rfloordiv__",
|
"__rfloordiv__",
|
||||||
[](const array& a, const ScalarOrArray v) {
|
[](const array& a, const ScalarOrArray v) {
|
||||||
@ -773,18 +817,19 @@ void init_array(py::module_& m) {
|
|||||||
return floor_divide(b, a);
|
return floor_divide(b, a);
|
||||||
},
|
},
|
||||||
"other"_a)
|
"other"_a)
|
||||||
.def(
|
|
||||||
"__rdiv__",
|
|
||||||
[](const array& a, const ScalarOrArray v) {
|
|
||||||
return divide(to_array(v, a.dtype()), a);
|
|
||||||
},
|
|
||||||
"other"_a)
|
|
||||||
.def(
|
.def(
|
||||||
"__mod__",
|
"__mod__",
|
||||||
[](const array& a, const ScalarOrArray v) {
|
[](const array& a, const ScalarOrArray v) {
|
||||||
return remainder(a, to_array(v, a.dtype()));
|
return remainder(a, to_array(v, a.dtype()));
|
||||||
},
|
},
|
||||||
"other"_a)
|
"other"_a)
|
||||||
|
.def(
|
||||||
|
"__imod__",
|
||||||
|
[](array& a, const ScalarOrArray v) {
|
||||||
|
a.overwrite_descriptor(remainder(a, to_array(v, a.dtype())));
|
||||||
|
return a;
|
||||||
|
},
|
||||||
|
"other"_a)
|
||||||
.def(
|
.def(
|
||||||
"__rmod__",
|
"__rmod__",
|
||||||
[](const array& a, const ScalarOrArray v) {
|
[](const array& a, const ScalarOrArray v) {
|
||||||
@ -840,23 +885,102 @@ void init_array(py::module_& m) {
|
|||||||
return os.str();
|
return os.str();
|
||||||
})
|
})
|
||||||
.def(
|
.def(
|
||||||
"__matmul__", [](array& a, array& other) { return matmul(a, other); })
|
"__matmul__",
|
||||||
|
[](const array& a, array& other) { return matmul(a, other); },
|
||||||
|
"other"_a)
|
||||||
|
.def(
|
||||||
|
"__imatmul__",
|
||||||
|
[](array& a, array& other) {
|
||||||
|
a.overwrite_descriptor(matmul(a, other));
|
||||||
|
return a;
|
||||||
|
},
|
||||||
|
"other"_a)
|
||||||
.def(
|
.def(
|
||||||
"__pow__",
|
"__pow__",
|
||||||
[](const array& a, const ScalarOrArray v) {
|
[](const array& a, const ScalarOrArray v) {
|
||||||
return power(a, to_array(v, a.dtype()));
|
return power(a, to_array(v, a.dtype()));
|
||||||
},
|
},
|
||||||
"other"_a)
|
"other"_a)
|
||||||
|
.def(
|
||||||
|
"__ipow__",
|
||||||
|
[](array& a, const ScalarOrArray v) {
|
||||||
|
a.overwrite_descriptor(power(a, to_array(v, a.dtype())));
|
||||||
|
return a;
|
||||||
|
},
|
||||||
|
"other"_a)
|
||||||
|
.def(
|
||||||
|
"__invert__",
|
||||||
|
[](const array& a) {
|
||||||
|
if (is_floating_point(a.dtype())) {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"Floating point types not allowed with or bitwise inversion.");
|
||||||
|
}
|
||||||
|
if (a.dtype() != bool_) {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"Bitwise inversion not yet supported for integer types.");
|
||||||
|
}
|
||||||
|
return logical_not(a);
|
||||||
|
})
|
||||||
.def(
|
.def(
|
||||||
"__and__",
|
"__and__",
|
||||||
[](const array& a, const ScalarOrArray v) {
|
[](const array& a, const ScalarOrArray v) {
|
||||||
return logical_and(a, to_array(v, a.dtype()));
|
auto b = to_array(v, a.dtype());
|
||||||
|
if (is_floating_point(a.dtype()) || is_floating_point(b.dtype())) {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"Floating point types not allowed with bitwise and.");
|
||||||
|
}
|
||||||
|
if (a.dtype() != bool_ && b.dtype() != bool_) {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"Bitwise and not yet supported for integer types.");
|
||||||
|
}
|
||||||
|
return logical_and(a, b);
|
||||||
|
},
|
||||||
|
"other"_a)
|
||||||
|
.def(
|
||||||
|
"__iand__",
|
||||||
|
[](array& a, const ScalarOrArray v) {
|
||||||
|
auto b = to_array(v, a.dtype());
|
||||||
|
if (is_floating_point(a.dtype()) || is_floating_point(b.dtype())) {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"Floating point types not allowed with bitwise and.");
|
||||||
|
}
|
||||||
|
if (a.dtype() != bool_ && b.dtype() != bool_) {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"Bitwise and not yet supported for integer types.");
|
||||||
|
}
|
||||||
|
a.overwrite_descriptor(logical_and(a, b));
|
||||||
|
return a;
|
||||||
},
|
},
|
||||||
"other"_a)
|
"other"_a)
|
||||||
.def(
|
.def(
|
||||||
"__or__",
|
"__or__",
|
||||||
[](const array& a, const ScalarOrArray v) {
|
[](const array& a, const ScalarOrArray v) {
|
||||||
return logical_or(a, to_array(v, a.dtype()));
|
auto b = to_array(v, a.dtype());
|
||||||
|
if (is_floating_point(a.dtype()) || is_floating_point(b.dtype())) {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"Floating point types not allowed with or bitwise or.");
|
||||||
|
}
|
||||||
|
if (a.dtype() != bool_ && b.dtype() != bool_) {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"Bitwise or not yet supported for integer types.");
|
||||||
|
}
|
||||||
|
return logical_or(a, b);
|
||||||
|
},
|
||||||
|
"other"_a)
|
||||||
|
.def(
|
||||||
|
"__ior__",
|
||||||
|
[](array& a, const ScalarOrArray v) {
|
||||||
|
auto b = to_array(v, a.dtype());
|
||||||
|
if (is_floating_point(a.dtype()) || is_floating_point(b.dtype())) {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"Floating point types not allowed with or bitwise or.");
|
||||||
|
}
|
||||||
|
if (a.dtype() != bool_ && b.dtype() != bool_) {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"Bitwise or not yet supported for integer types.");
|
||||||
|
}
|
||||||
|
a.overwrite_descriptor(logical_or(a, b));
|
||||||
|
return a;
|
||||||
},
|
},
|
||||||
"other"_a)
|
"other"_a)
|
||||||
.def(
|
.def(
|
||||||
|
@ -1313,6 +1313,65 @@ class TestArray(mlx_tests.MLXTestCase):
|
|||||||
rtol=0,
|
rtol=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_logical_overloads(self):
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
mx.array(1.0) & mx.array(1)
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
mx.array(1.0) | mx.array(1)
|
||||||
|
|
||||||
|
self.assertEqual((mx.array(True) & True).item(), True)
|
||||||
|
self.assertEqual((mx.array(True) & False).item(), False)
|
||||||
|
self.assertEqual((mx.array(True) | False).item(), True)
|
||||||
|
self.assertEqual((mx.array(False) | False).item(), False)
|
||||||
|
self.assertEqual((~mx.array(False)).item(), True)
|
||||||
|
|
||||||
|
def test_inplace(self):
|
||||||
|
iops = [
|
||||||
|
"__iadd__",
|
||||||
|
"__isub__",
|
||||||
|
"__imul__",
|
||||||
|
"__ifloordiv__",
|
||||||
|
"__imod__",
|
||||||
|
"__ipow__",
|
||||||
|
]
|
||||||
|
|
||||||
|
for op in iops:
|
||||||
|
a = mx.array([1, 2, 3])
|
||||||
|
a_np = np.array(a)
|
||||||
|
b = a
|
||||||
|
b = getattr(a, op)(3)
|
||||||
|
self.assertTrue(mx.array_equal(a, b))
|
||||||
|
out_np = getattr(a_np, op)(3)
|
||||||
|
self.assertTrue(np.array_equal(out_np, a))
|
||||||
|
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
a = mx.array([1])
|
||||||
|
a /= 1
|
||||||
|
|
||||||
|
a = mx.array([2.0])
|
||||||
|
b = a
|
||||||
|
b /= 2
|
||||||
|
self.assertEqual(b.item(), 1.0)
|
||||||
|
self.assertEqual(b.item(), a.item())
|
||||||
|
|
||||||
|
a = mx.array(True)
|
||||||
|
b = a
|
||||||
|
b &= False
|
||||||
|
self.assertEqual(b.item(), False)
|
||||||
|
self.assertEqual(b.item(), a.item())
|
||||||
|
|
||||||
|
a = mx.array(False)
|
||||||
|
b = a
|
||||||
|
b |= True
|
||||||
|
self.assertEqual(b.item(), True)
|
||||||
|
self.assertEqual(b.item(), a.item())
|
||||||
|
|
||||||
|
# In-place matmul on its own
|
||||||
|
a = mx.array([[1.0, 2.0], [3.0, 4.0]])
|
||||||
|
b = a
|
||||||
|
b @= a
|
||||||
|
self.assertTrue(mx.array_equal(a, b))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user