Fix divide types + floor divide (//) (#138)

* divide types

* fix black + test
This commit is contained in:
Awni Hannun 2023-12-11 20:20:58 -08:00 committed by GitHub
parent 02de234ef0
commit 25f70d4ca4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 42 additions and 6 deletions

View File

@ -176,6 +176,8 @@ def selu(x):
See also :func:`elu`.
"""
return elu(x, 1.67326) * 1.0507
def prelu(x: mx.array, alpha: mx.array) -> mx.array:
r"""Applies the element-wise function:

View File

@ -623,25 +623,41 @@ void init_array(py::module_& m) {
.def(
"__truediv__",
[](const array& a, const ScalarOrArray v) {
return divide(a, to_array(v, float32));
return divide(a, to_array(v, a.dtype()));
},
"other"_a)
.def(
"__div__",
[](const array& a, const ScalarOrArray v) {
return divide(a, to_array(v, float32));
return divide(a, to_array(v, a.dtype()));
},
"other"_a)
.def(
"__floordiv__",
[](const array& a, const ScalarOrArray v) {
auto b = to_array(v, a.dtype());
auto t = promote_types(a.dtype(), b.dtype());
return astype(divide(a, b), t);
},
"other"_a)
.def(
"__rtruediv__",
[](const array& a, const ScalarOrArray v) {
return divide(to_array(v, float32), a);
return divide(to_array(v, a.dtype()), a);
},
"other"_a)
.def(
"__rfloordiv__",
[](const array& a, const ScalarOrArray v) {
auto b = to_array(v, a.dtype());
auto t = promote_types(a.dtype(), b.dtype());
return astype(divide(b, a), t);
},
"other"_a)
.def(
"__rdiv__",
[](const array& a, const ScalarOrArray v) {
return divide(to_array(v, float32), a);
return divide(to_array(v, a.dtype()), a);
},
"other"_a)
.def(

View File

@ -2,7 +2,7 @@
import os
import unittest
from typing import Callable, List, Tuple
from typing import Callable, List, Tuple, Union
import mlx.core as mx
import numpy as np
@ -21,7 +21,7 @@ class MLXTestCase(unittest.TestCase):
def assertEqualArray(
self,
args: List[mx.array | float | int],
args: List[Union[mx.array, float, int]],
mlx_func: Callable[..., mx.array],
expected: mx.array,
atol=1e-2,

View File

@ -236,6 +236,24 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertEqual(z.dtype, mx.float32)
self.assertEqual(z.item(), 0.5)
x = x.astype(mx.float16)
z = x / 4.0
self.assertEqual(z.dtype, mx.float16)
x = x.astype(mx.float16)
z = 4.0 / x
self.assertEqual(z.dtype, mx.float16)
x = mx.array(5)
y = mx.array(2)
z = x / y
self.assertEqual(z.dtype, mx.float32)
self.assertEqual(z.item(), 2.5)
z = x // y
self.assertEqual(z.dtype, mx.int32)
self.assertEqual(z.item(), 2)
def test_remainder(self):
for dt in [mx.int32, mx.float32]:
x = mx.array(2, dtype=dt)