mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-21 16:51:15 +08:00
Fix divide types + floor divide (//) (#138)
* divide types * fix black + test
This commit is contained in:
parent
02de234ef0
commit
25f70d4ca4
@ -176,6 +176,8 @@ def selu(x):
|
|||||||
See also :func:`elu`.
|
See also :func:`elu`.
|
||||||
"""
|
"""
|
||||||
return elu(x, 1.67326) * 1.0507
|
return elu(x, 1.67326) * 1.0507
|
||||||
|
|
||||||
|
|
||||||
def prelu(x: mx.array, alpha: mx.array) -> mx.array:
|
def prelu(x: mx.array, alpha: mx.array) -> mx.array:
|
||||||
r"""Applies the element-wise function:
|
r"""Applies the element-wise function:
|
||||||
|
|
||||||
|
@ -623,25 +623,41 @@ void init_array(py::module_& m) {
|
|||||||
.def(
|
.def(
|
||||||
"__truediv__",
|
"__truediv__",
|
||||||
[](const array& a, const ScalarOrArray v) {
|
[](const array& a, const ScalarOrArray v) {
|
||||||
return divide(a, to_array(v, float32));
|
return divide(a, to_array(v, a.dtype()));
|
||||||
},
|
},
|
||||||
"other"_a)
|
"other"_a)
|
||||||
.def(
|
.def(
|
||||||
"__div__",
|
"__div__",
|
||||||
[](const array& a, const ScalarOrArray v) {
|
[](const array& a, const ScalarOrArray v) {
|
||||||
return divide(a, to_array(v, float32));
|
return divide(a, to_array(v, a.dtype()));
|
||||||
|
},
|
||||||
|
"other"_a)
|
||||||
|
.def(
|
||||||
|
"__floordiv__",
|
||||||
|
[](const array& a, const ScalarOrArray v) {
|
||||||
|
auto b = to_array(v, a.dtype());
|
||||||
|
auto t = promote_types(a.dtype(), b.dtype());
|
||||||
|
return astype(divide(a, b), t);
|
||||||
},
|
},
|
||||||
"other"_a)
|
"other"_a)
|
||||||
.def(
|
.def(
|
||||||
"__rtruediv__",
|
"__rtruediv__",
|
||||||
[](const array& a, const ScalarOrArray v) {
|
[](const array& a, const ScalarOrArray v) {
|
||||||
return divide(to_array(v, float32), a);
|
return divide(to_array(v, a.dtype()), a);
|
||||||
|
},
|
||||||
|
"other"_a)
|
||||||
|
.def(
|
||||||
|
"__rfloordiv__",
|
||||||
|
[](const array& a, const ScalarOrArray v) {
|
||||||
|
auto b = to_array(v, a.dtype());
|
||||||
|
auto t = promote_types(a.dtype(), b.dtype());
|
||||||
|
return astype(divide(b, a), t);
|
||||||
},
|
},
|
||||||
"other"_a)
|
"other"_a)
|
||||||
.def(
|
.def(
|
||||||
"__rdiv__",
|
"__rdiv__",
|
||||||
[](const array& a, const ScalarOrArray v) {
|
[](const array& a, const ScalarOrArray v) {
|
||||||
return divide(to_array(v, float32), a);
|
return divide(to_array(v, a.dtype()), a);
|
||||||
},
|
},
|
||||||
"other"_a)
|
"other"_a)
|
||||||
.def(
|
.def(
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
import unittest
|
import unittest
|
||||||
from typing import Callable, List, Tuple
|
from typing import Callable, List, Tuple, Union
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -21,7 +21,7 @@ class MLXTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
def assertEqualArray(
|
def assertEqualArray(
|
||||||
self,
|
self,
|
||||||
args: List[mx.array | float | int],
|
args: List[Union[mx.array, float, int]],
|
||||||
mlx_func: Callable[..., mx.array],
|
mlx_func: Callable[..., mx.array],
|
||||||
expected: mx.array,
|
expected: mx.array,
|
||||||
atol=1e-2,
|
atol=1e-2,
|
||||||
|
@ -236,6 +236,24 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
self.assertEqual(z.dtype, mx.float32)
|
self.assertEqual(z.dtype, mx.float32)
|
||||||
self.assertEqual(z.item(), 0.5)
|
self.assertEqual(z.item(), 0.5)
|
||||||
|
|
||||||
|
x = x.astype(mx.float16)
|
||||||
|
z = x / 4.0
|
||||||
|
self.assertEqual(z.dtype, mx.float16)
|
||||||
|
|
||||||
|
x = x.astype(mx.float16)
|
||||||
|
z = 4.0 / x
|
||||||
|
self.assertEqual(z.dtype, mx.float16)
|
||||||
|
|
||||||
|
x = mx.array(5)
|
||||||
|
y = mx.array(2)
|
||||||
|
z = x / y
|
||||||
|
self.assertEqual(z.dtype, mx.float32)
|
||||||
|
self.assertEqual(z.item(), 2.5)
|
||||||
|
|
||||||
|
z = x // y
|
||||||
|
self.assertEqual(z.dtype, mx.int32)
|
||||||
|
self.assertEqual(z.item(), 2)
|
||||||
|
|
||||||
def test_remainder(self):
|
def test_remainder(self):
|
||||||
for dt in [mx.int32, mx.float32]:
|
for dt in [mx.int32, mx.float32]:
|
||||||
x = mx.array(2, dtype=dt)
|
x = mx.array(2, dtype=dt)
|
||||||
|
Loading…
Reference in New Issue
Block a user