Fix divide types + floor divide (//) (#138)

* divide types

* fix black + test
This commit is contained in:
Awni Hannun 2023-12-11 20:20:58 -08:00 committed by GitHub
parent 02de234ef0
commit 25f70d4ca4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 42 additions and 6 deletions

View File

@ -176,6 +176,8 @@ def selu(x):
See also :func:`elu`. See also :func:`elu`.
""" """
return elu(x, 1.67326) * 1.0507 return elu(x, 1.67326) * 1.0507
def prelu(x: mx.array, alpha: mx.array) -> mx.array: def prelu(x: mx.array, alpha: mx.array) -> mx.array:
r"""Applies the element-wise function: r"""Applies the element-wise function:

View File

@ -623,25 +623,41 @@ void init_array(py::module_& m) {
.def( .def(
"__truediv__", "__truediv__",
[](const array& a, const ScalarOrArray v) { [](const array& a, const ScalarOrArray v) {
return divide(a, to_array(v, float32)); return divide(a, to_array(v, a.dtype()));
}, },
"other"_a) "other"_a)
.def( .def(
"__div__", "__div__",
[](const array& a, const ScalarOrArray v) { [](const array& a, const ScalarOrArray v) {
return divide(a, to_array(v, float32)); return divide(a, to_array(v, a.dtype()));
},
"other"_a)
.def(
"__floordiv__",
[](const array& a, const ScalarOrArray v) {
auto b = to_array(v, a.dtype());
auto t = promote_types(a.dtype(), b.dtype());
return astype(divide(a, b), t);
}, },
"other"_a) "other"_a)
.def( .def(
"__rtruediv__", "__rtruediv__",
[](const array& a, const ScalarOrArray v) { [](const array& a, const ScalarOrArray v) {
return divide(to_array(v, float32), a); return divide(to_array(v, a.dtype()), a);
},
"other"_a)
.def(
"__rfloordiv__",
[](const array& a, const ScalarOrArray v) {
auto b = to_array(v, a.dtype());
auto t = promote_types(a.dtype(), b.dtype());
return astype(divide(b, a), t);
}, },
"other"_a) "other"_a)
.def( .def(
"__rdiv__", "__rdiv__",
[](const array& a, const ScalarOrArray v) { [](const array& a, const ScalarOrArray v) {
return divide(to_array(v, float32), a); return divide(to_array(v, a.dtype()), a);
}, },
"other"_a) "other"_a)
.def( .def(

View File

@ -2,7 +2,7 @@
import os import os
import unittest import unittest
from typing import Callable, List, Tuple from typing import Callable, List, Tuple, Union
import mlx.core as mx import mlx.core as mx
import numpy as np import numpy as np
@ -21,7 +21,7 @@ class MLXTestCase(unittest.TestCase):
def assertEqualArray( def assertEqualArray(
self, self,
args: List[mx.array | float | int], args: List[Union[mx.array, float, int]],
mlx_func: Callable[..., mx.array], mlx_func: Callable[..., mx.array],
expected: mx.array, expected: mx.array,
atol=1e-2, atol=1e-2,

View File

@ -236,6 +236,24 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertEqual(z.dtype, mx.float32) self.assertEqual(z.dtype, mx.float32)
self.assertEqual(z.item(), 0.5) self.assertEqual(z.item(), 0.5)
x = x.astype(mx.float16)
z = x / 4.0
self.assertEqual(z.dtype, mx.float16)
x = x.astype(mx.float16)
z = 4.0 / x
self.assertEqual(z.dtype, mx.float16)
x = mx.array(5)
y = mx.array(2)
z = x / y
self.assertEqual(z.dtype, mx.float32)
self.assertEqual(z.item(), 2.5)
z = x // y
self.assertEqual(z.dtype, mx.int32)
self.assertEqual(z.item(), 2)
def test_remainder(self): def test_remainder(self):
for dt in [mx.int32, mx.float32]: for dt in [mx.int32, mx.float32]:
x = mx.array(2, dtype=dt) x = mx.array(2, dtype=dt)