diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 8da07ad46..4f9dfac79 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/pre-commit/mirrors-clang-format - rev: v14.0.6 + rev: v17.0.6 hooks: - id: clang-format # Using this mirror lets us use mypyc-compiled black, which is about 2x faster diff --git a/docs/src/python/ops.rst b/docs/src/python/ops.rst index c235d3b64..75d54be5a 100644 --- a/docs/src/python/ops.rst +++ b/docs/src/python/ops.rst @@ -95,6 +95,9 @@ Operations tan tanh transpose + tri + tril + triu var where zeros diff --git a/mlx/ops.cpp b/mlx/ops.cpp index dc852370b..1e61745eb 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -218,6 +218,28 @@ array identity(int n, Dtype dtype, StreamOrDevice s /* = {} */) { return eye(n, n, 0, dtype, s); } +array tri(int n, int m, int k, Dtype type, StreamOrDevice s /* = {} */) { + auto l = expand_dims(arange(n, s), 1, s); + auto r = expand_dims(arange(-k, m - k, s), 0, s); + return astype(greater_equal(l, r, s), type, s); +} + +array tril(array x, int k, StreamOrDevice s /* = {} */) { + if (x.ndim() < 2) { + throw std::invalid_argument("[tril] array must be atleast 2-D"); + } + auto mask = tri(x.shape(-2), x.shape(-1), k, x.dtype(), s); + return where(mask, x, zeros_like(x, s), s); +} + +array triu(array x, int k, StreamOrDevice s /* = {} */) { + if (x.ndim() < 2) { + throw std::invalid_argument("[triu] array must be atleast 2-D"); + } + auto mask = tri(x.shape(-2), x.shape(-1), k - 1, x.dtype(), s); + return where(mask, zeros_like(x, s), x, s); +} + array reshape( const array& a, std::vector shape, diff --git a/mlx/ops.h b/mlx/ops.h index 6b081ad9f..e5bdcf358 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -110,6 +110,14 @@ inline array identity(int n, StreamOrDevice s = {}) { return identity(n, float32, s); } +array tri(int n, int m, int k, Dtype type, StreamOrDevice s = {}); +inline array tri(int n, Dtype type, StreamOrDevice s = {}) { + return tri(n, n, 0, type, s); +} + +array tril(array x, int k, StreamOrDevice s = {}); +array triu(array x, int k, StreamOrDevice s = {}); + /** array manipulation */ /** Reshape an array to the given shape. */ diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 6d4d80b97..4891052b6 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -1410,6 +1410,72 @@ void init_ops(py::module_& m) { Returns: array: An identity matrix of size n x n. )pbdoc"); + m.def( + "tri", + [](int n, std::optional m, int k, Dtype dtype, StreamOrDevice s) { + return tri(n, m.value_or(n), k, float32, s); + }, + "n"_a, + "m"_a = none, + "k"_a = 0, + "dtype"_a = float32, + py::kw_only(), + "stream"_a = none, + R"pbdoc( + tri(n: int, m: int, k: int, dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array + + An array with ones at and below the given diagonal and zeros elsewhere. + + Args: + n (int): The number of rows in the output. + m (int, optional): The number of cols in the output. Defaults to ``None``. + k (int, optional): The diagonal of the 2-D array. Defaults to ``0``. + dtype (Dtype, optional): Data type of the output array. Defaults to ``float32``. + stream (Stream, optional): Stream or device. Defaults to ``None``. + + Returns: + array: Array with its lower triangle filled with ones and zeros elsewhere + )pbdoc"); + m.def( + "tril", + &tril, + "x"_a, + "k"_a = 0, + py::kw_only(), + "stream"_a = none, + R"pbdoc( + tril(x: array, k: int, *, stream: Union[None, Stream, Device] = None) -> array + + Zeros the array above the given diagonal. + + Args: + x (array): input array. + k (int, optional): The diagonal of the 2-D array. Defaults to ``0``. + stream (Stream, optional): Stream or device. Defaults to ``None``. + + Returns: + array: Array zeroed above the given diagonal + )pbdoc"); + m.def( + "triu", + &triu, + "x"_a, + "k"_a = 0, + py::kw_only(), + "stream"_a = none, + R"pbdoc( + triu(x: array, k: int, *, stream: Union[None, Stream, Device] = None) -> array + + Zeros the array below the given diagonal. + + Args: + x (array): input array. + k (int, optional): The diagonal of the 2-D array. Defaults to ``0``. + stream (Stream, optional): Stream or device. Defaults to ``None``. + + Returns: + array: Array zeroed below the given diagonal + )pbdoc"); m.def( "allclose", &allclose, @@ -2254,7 +2320,7 @@ void init_ops(py::module_& m) { Args: arrays (list(array)): A list of arrays to stack. axis (int, optional): The axis in the result array along which the - input arrays are stacked. Defaults to ``0``. + input arrays are stacked. Defaults to ``0``. stream (Stream, optional): Stream or device. Defaults to ``None``. Returns: diff --git a/python/tests/mlx_tests.py b/python/tests/mlx_tests.py index 0950c0bfc..d9a485885 100644 --- a/python/tests/mlx_tests.py +++ b/python/tests/mlx_tests.py @@ -2,7 +2,6 @@ import os import unittest -from typing import Callable, List, Tuple, Union import mlx.core as mx import numpy as np @@ -21,13 +20,16 @@ class MLXTestCase(unittest.TestCase): def assertEqualArray( self, - args: List[Union[mx.array, float, int]], - mlx_func: Callable[..., mx.array], + mx_res: mx.array, expected: mx.array, atol=1e-2, rtol=1e-2, + **kwargs, ): - mx_res = mlx_func(*args) - assert tuple(mx_res.shape) == tuple(expected.shape), "shape mismatch" - assert mx_res.dtype == expected.dtype, "dtype mismatch" + assert tuple(mx_res.shape) == tuple( + expected.shape + ), f"shape mismatch expected={expected.shape} got={mx_res.shape}" + assert ( + mx_res.dtype == expected.dtype + ), f"dtype mismatch expected={expected.dtype} got={mx_res.dtype}" np.testing.assert_allclose(mx_res, expected, rtol=rtol, atol=atol) diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py index 259795654..852816c20 100644 --- a/python/tests/test_nn.py +++ b/python/tests/test_nn.py @@ -451,15 +451,13 @@ class TestNN(mlx_tests.MLXTestCase): def test_prelu(self): self.assertEqualArray( - [mx.array([1.0, -1.0, 0.0, 0.5])], - nn.PReLU(), + nn.PReLU()(mx.array([1.0, -1.0, 0.0, 0.5])), mx.array([1.0, -0.25, 0.0, 0.5]), ) def test_mish(self): self.assertEqualArray( - [mx.array([1.0, -1.0, 0.0, 0.5])], - nn.Mish(), + nn.Mish()(mx.array([1.0, -1.0, 0.0, 0.5])), mx.array([0.8651, -0.3034, 0.0000, 0.3752]), ) diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index db1830e16..9ea0e80db 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -320,6 +320,30 @@ class TestOps(mlx_tests.MLXTestCase): self.assertFalse(mx.array_equal(x, y)) self.assertTrue(mx.array_equal(x, y, equal_nan=True)) + def test_tri(self): + for shape in [[4], [4, 4], [2, 10]]: + for diag in [-1, 0, 1, -2]: + self.assertEqualArray( + mx.tri(*shape, k=diag), mx.array(np.tri(*shape, k=diag)) + ) + + def test_tril(self): + mt = mx.random.normal((10, 10)) + nt = np.array(mt) + for diag in [-1, 0, 1, -2]: + self.assertEqualArray(mx.tril(mt, diag), mx.array(np.tril(nt, diag))) + + with self.assertRaises(Exception): + mx.tril(mx.zeros((1))) + + def test_triu(self): + mt = mx.random.normal((10, 10)) + nt = np.array(mt) + for diag in [-1, 0, 1, -2]: + self.assertEqualArray(mx.triu(mt, diag), mx.array(np.triu(nt, diag))) + with self.assertRaises(Exception): + mx.triu(mx.zeros((1))) + def test_minimum(self): x = mx.array([0.0, -5, 10.0]) y = mx.array([1.0, -7.0, 3.0]) diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index 0916eeafe..921afc31d 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -2031,6 +2031,78 @@ TEST_CASE("test eye") { CHECK(array_equal(eye_3x2, expected_eye_3x2).item()); } +TEST_CASE("test tri") { + auto _tri = tri(4, 4, 0, float32); + CHECK_EQ(_tri.shape(), std::vector{4, 4}); + auto expected_tri = array( + {1.0f, + 0.0f, + 0.0f, + 0.0f, + 1.0f, + 1.0f, + 0.0f, + 0.0f, + 1.0f, + 1.0f, + 1.0f, + 0.0f, + 1.0f, + 1.0f, + 1.0f, + 1.0f}, + {4, 4}); + CHECK(array_equal(_tri, expected_tri).item()); +} + +TEST_CASE("test tril") { + auto _tril = tril(full(std::vector{4, 4}, 2.0f, float32), 0); + CHECK_EQ(_tril.shape(), std::vector{4, 4}); + auto expected_tri = array( + {2.0f, + 0.0f, + 0.0f, + 0.0f, + 2.0f, + 2.0f, + 0.0f, + 0.0f, + 2.0f, + 2.0f, + 2.0f, + 0.0f, + 2.0f, + 2.0f, + 2.0f, + 2.0f}, + {4, 4}); + CHECK(array_equal(_tril, expected_tri).item()); +} + +TEST_CASE("test triu") { + auto _triu = triu(full(std::vector{4, 4}, 2.0f, float32), 0); + CHECK_EQ(_triu.shape(), std::vector{4, 4}); + auto expected_tri = array( + {2.0f, + 2.0f, + 2.0f, + 2.0f, + 0.0f, + 2.0f, + 2.0f, + 2.0f, + 0.0f, + 0.0f, + 2.0f, + 2.0f, + 0.0f, + 0.0f, + 0.0f, + 2.0f}, + {4, 4}); + CHECK(array_equal(_triu, expected_tri).item()); +} + TEST_CASE("test identity") { auto id_4 = identity(4); CHECK_EQ(id_4.shape(), std::vector{4, 4});