mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 00:04:41 +08:00
Implements divide for integer types and adds floor_divide op (#228)
* Add floor_divide * Add floor_divide to the tests * Add floor_divide to the docs
This commit is contained in:

committed by
GitHub

parent
de892cb66c
commit
2807c6aff0
@@ -115,6 +115,7 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
"subtract",
|
||||
"multiply",
|
||||
"divide",
|
||||
"floor_divide",
|
||||
"remainder",
|
||||
"equal",
|
||||
"not_equal",
|
||||
@@ -1096,6 +1097,7 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
"subtract",
|
||||
"multiply",
|
||||
"divide",
|
||||
"floor_divide",
|
||||
"maximum",
|
||||
"minimum",
|
||||
"power",
|
||||
@@ -1111,19 +1113,21 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
"uint32",
|
||||
"uint64",
|
||||
]
|
||||
|
||||
float_dtypes = ["float16", "float32"]
|
||||
|
||||
dtypes = (
|
||||
float_dtypes
|
||||
if op in ("divide", "power")
|
||||
else (int_dtypes + float_dtypes)
|
||||
)
|
||||
dtypes = {
|
||||
"divide": float_dtypes,
|
||||
"power": float_dtypes,
|
||||
"floor_divide": ["float32"] + int_dtypes,
|
||||
}
|
||||
dtypes = dtypes.get(op, int_dtypes + float_dtypes)
|
||||
|
||||
for dtype in dtypes:
|
||||
atol = 1e-3 if dtype == "float16" else 1e-6
|
||||
with self.subTest(dtype=dtype):
|
||||
x1_ = x1.astype(getattr(np, dtype))
|
||||
x2_ = x2.astype(getattr(np, dtype))
|
||||
m = 10 if dtype in int_dtypes else 1
|
||||
x1_ = (x1 * m).astype(getattr(np, dtype))
|
||||
x2_ = (x2 * m).astype(getattr(np, dtype))
|
||||
y1_ = mx.array(x1_)
|
||||
y2_ = mx.array(x2_)
|
||||
test_ops(
|
||||
|
Reference in New Issue
Block a user