diff --git a/python/mlx/nn/layers/activations.py b/python/mlx/nn/layers/activations.py
index 88b3cc2b5..25d1c5268 100644
--- a/python/mlx/nn/layers/activations.py
+++ b/python/mlx/nn/layers/activations.py
@@ -176,6 +176,8 @@ def selu(x):
     See also :func:`elu`.
     """
     return elu(x, 1.67326) * 1.0507
+
+
 def prelu(x: mx.array, alpha: mx.array) -> mx.array:
     r"""Applies the element-wise function:
 
diff --git a/python/src/array.cpp b/python/src/array.cpp
index 6a3ffbf96..af0afef4d 100644
--- a/python/src/array.cpp
+++ b/python/src/array.cpp
@@ -623,25 +623,43 @@ void init_array(py::module_& m) {
       .def(
           "__truediv__",
           [](const array& a, const ScalarOrArray v) {
-            return divide(a, to_array(v, float32));
+            return divide(a, to_array(v, a.dtype()));
           },
           "other"_a)
       .def(
           "__div__",
           [](const array& a, const ScalarOrArray v) {
-            return divide(a, to_array(v, float32));
+            return divide(a, to_array(v, a.dtype()));
+          },
+          "other"_a)
+      .def(
+          "__floordiv__",
+          [](const array& a, const ScalarOrArray v) {
+            auto b = to_array(v, a.dtype());
+            auto t = promote_types(a.dtype(), b.dtype());
+            // Floor before casting: a bare cast truncates toward zero.
+            return astype(floor(divide(a, b)), t);
           },
           "other"_a)
       .def(
           "__rtruediv__",
           [](const array& a, const ScalarOrArray v) {
-            return divide(to_array(v, float32), a);
+            return divide(to_array(v, a.dtype()), a);
+          },
+          "other"_a)
+      .def(
+          "__rfloordiv__",
+          [](const array& a, const ScalarOrArray v) {
+            auto b = to_array(v, a.dtype());
+            auto t = promote_types(a.dtype(), b.dtype());
+            // Reflected operand order (b // a); floor before casting.
+            return astype(floor(divide(b, a)), t);
           },
           "other"_a)
       .def(
           "__rdiv__",
           [](const array& a, const ScalarOrArray v) {
-            return divide(to_array(v, float32), a);
+            return divide(to_array(v, a.dtype()), a);
           },
           "other"_a)
       .def(
diff --git a/python/tests/mlx_tests.py b/python/tests/mlx_tests.py
index 9400950a1..0950c0bfc 100644
--- a/python/tests/mlx_tests.py
+++ b/python/tests/mlx_tests.py
@@ -2,7 +2,7 @@
 
 import os
 import unittest
-from typing import Callable, List, Tuple
+from typing import Callable, List, Tuple, Union
 
 import mlx.core as mx
 import numpy as np
@@ -21,7 +21,7 @@ class MLXTestCase(unittest.TestCase):
 
     def assertEqualArray(
         self,
-        args: List[mx.array | float | int],
+        args: List[Union[mx.array, float, int]],
         mlx_func: Callable[..., mx.array],
         expected: mx.array,
         atol=1e-2,
diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py
index 0ee234cf4..d0ca54365 100644
--- a/python/tests/test_ops.py
+++ b/python/tests/test_ops.py
@@ -236,6 +236,33 @@ class TestOps(mlx_tests.MLXTestCase):
         self.assertEqual(z.dtype, mx.float32)
         self.assertEqual(z.item(), 0.5)
 
+        x = x.astype(mx.float16)
+        z = x / 4.0
+        self.assertEqual(z.dtype, mx.float16)
+
+        x = x.astype(mx.float16)
+        z = 4.0 / x
+        self.assertEqual(z.dtype, mx.float16)
+
+        x = mx.array(5)
+        y = mx.array(2)
+        z = x / y
+        self.assertEqual(z.dtype, mx.float32)
+        self.assertEqual(z.item(), 2.5)
+
+        z = x // y
+        self.assertEqual(z.dtype, mx.int32)
+        self.assertEqual(z.item(), 2)
+
+        # Floor division rounds toward negative infinity, like Python's `//`.
+        z = mx.array(-5) // y
+        self.assertEqual(z.dtype, mx.int32)
+        self.assertEqual(z.item(), -3)
+
+        z = mx.array(5.0) // mx.array(2.0)
+        self.assertEqual(z.dtype, mx.float32)
+        self.assertEqual(z.item(), 2.0)
+
     def test_remainder(self):
         for dt in [mx.int32, mx.float32]:
             x = mx.array(2, dtype=dt)