Implements divide for integer types and adds floor_divide op (#228)

* Add floor_divide
* Add floor_divide to the tests
* Add floor_divide to the docs
This commit is contained in:
Angelos Katharopoulos
2023-12-19 20:12:19 -08:00
committed by GitHub
parent de892cb66c
commit 2807c6aff0
8 changed files with 67 additions and 14 deletions

View File

@@ -115,6 +115,7 @@ class TestOps(mlx_tests.MLXTestCase):
"subtract",
"multiply",
"divide",
"floor_divide",
"remainder",
"equal",
"not_equal",
@@ -1096,6 +1097,7 @@ class TestOps(mlx_tests.MLXTestCase):
"subtract",
"multiply",
"divide",
"floor_divide",
"maximum",
"minimum",
"power",
@@ -1111,19 +1113,21 @@ class TestOps(mlx_tests.MLXTestCase):
"uint32",
"uint64",
]
float_dtypes = ["float16", "float32"]
dtypes = (
float_dtypes
if op in ("divide", "power")
else (int_dtypes + float_dtypes)
)
dtypes = {
"divide": float_dtypes,
"power": float_dtypes,
"floor_divide": ["float32"] + int_dtypes,
}
dtypes = dtypes.get(op, int_dtypes + float_dtypes)
for dtype in dtypes:
atol = 1e-3 if dtype == "float16" else 1e-6
with self.subTest(dtype=dtype):
x1_ = x1.astype(getattr(np, dtype))
x2_ = x2.astype(getattr(np, dtype))
m = 10 if dtype in int_dtypes else 1
x1_ = (x1 * m).astype(getattr(np, dtype))
x2_ = (x2 * m).astype(getattr(np, dtype))
y1_ = mx.array(x1_)
y2_ = mx.array(x2_)
test_ops(