added tri / tril / triu (#170)

* added tri / tril / triu

* fixed tests

* ctest tests

* tri overload and simplified tests

* changes from comment

* more tests for m

* ensure assert if not 2-D

* remove broadcast_to

* minor tweaks

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Diogo
2023-12-15 20:30:34 -05:00
committed by GitHub
parent 2e02acdc83
commit dc2edc762c
9 changed files with 207 additions and 12 deletions

View File

@@ -1410,6 +1410,72 @@ void init_ops(py::module_& m) {
Returns:
array: An identity matrix of size n x n.
)pbdoc");
m.def(
"tri",
[](int n, std::optional<int> m, int k, Dtype dtype, StreamOrDevice s) {
return tri(n, m.value_or(n), k, float32, s);
},
"n"_a,
"m"_a = none,
"k"_a = 0,
"dtype"_a = float32,
py::kw_only(),
"stream"_a = none,
R"pbdoc(
tri(n: int, m: int, k: int, dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array
An array with ones at and below the given diagonal and zeros elsewhere.
Args:
n (int): The number of rows in the output.
m (int, optional): The number of cols in the output. Defaults to ``None``.
k (int, optional): The diagonal of the 2-D array. Defaults to ``0``.
dtype (Dtype, optional): Data type of the output array. Defaults to ``float32``.
stream (Stream, optional): Stream or device. Defaults to ``None``.
Returns:
array: Array with its lower triangle filled with ones and zeros elsewhere
)pbdoc");
m.def(
"tril",
&tril,
"x"_a,
"k"_a = 0,
py::kw_only(),
"stream"_a = none,
R"pbdoc(
tril(x: array, k: int, *, stream: Union[None, Stream, Device] = None) -> array
Zeros the array above the given diagonal.
Args:
x (array): input array.
k (int, optional): The diagonal of the 2-D array. Defaults to ``0``.
stream (Stream, optional): Stream or device. Defaults to ``None``.
Returns:
array: Array zeroed above the given diagonal
)pbdoc");
m.def(
"triu",
&triu,
"x"_a,
"k"_a = 0,
py::kw_only(),
"stream"_a = none,
R"pbdoc(
triu(x: array, k: int, *, stream: Union[None, Stream, Device] = None) -> array
Zeros the array below the given diagonal.
Args:
x (array): input array.
k (int, optional): The diagonal of the 2-D array. Defaults to ``0``.
stream (Stream, optional): Stream or device. Defaults to ``None``.
Returns:
array: Array zeroed below the given diagonal
)pbdoc");
m.def(
"allclose",
&allclose,
@@ -2254,7 +2320,7 @@ void init_ops(py::module_& m) {
Args:
arrays (list(array)): A list of arrays to stack.
axis (int, optional): The axis in the result array along which the
input arrays are stacked. Defaults to ``0``.
input arrays are stacked. Defaults to ``0``.
stream (Stream, optional): Stream or device. Defaults to ``None``.
Returns:

View File

@@ -2,7 +2,6 @@
import os
import unittest
from typing import Callable, List, Tuple, Union
import mlx.core as mx
import numpy as np
@@ -21,13 +20,16 @@ class MLXTestCase(unittest.TestCase):
def assertEqualArray(
self,
args: List[Union[mx.array, float, int]],
mlx_func: Callable[..., mx.array],
mx_res: mx.array,
expected: mx.array,
atol=1e-2,
rtol=1e-2,
**kwargs,
):
mx_res = mlx_func(*args)
assert tuple(mx_res.shape) == tuple(expected.shape), "shape mismatch"
assert mx_res.dtype == expected.dtype, "dtype mismatch"
assert tuple(mx_res.shape) == tuple(
expected.shape
), f"shape mismatch expected={expected.shape} got={mx_res.shape}"
assert (
mx_res.dtype == expected.dtype
), f"dtype mismatch expected={expected.dtype} got={mx_res.dtype}"
np.testing.assert_allclose(mx_res, expected, rtol=rtol, atol=atol)

View File

@@ -451,15 +451,13 @@ class TestNN(mlx_tests.MLXTestCase):
def test_prelu(self):
self.assertEqualArray(
[mx.array([1.0, -1.0, 0.0, 0.5])],
nn.PReLU(),
nn.PReLU()(mx.array([1.0, -1.0, 0.0, 0.5])),
mx.array([1.0, -0.25, 0.0, 0.5]),
)
def test_mish(self):
self.assertEqualArray(
[mx.array([1.0, -1.0, 0.0, 0.5])],
nn.Mish(),
nn.Mish()(mx.array([1.0, -1.0, 0.0, 0.5])),
mx.array([0.8651, -0.3034, 0.0000, 0.3752]),
)

View File

@@ -320,6 +320,30 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertFalse(mx.array_equal(x, y))
self.assertTrue(mx.array_equal(x, y, equal_nan=True))
def test_tri(self):
for shape in [[4], [4, 4], [2, 10]]:
for diag in [-1, 0, 1, -2]:
self.assertEqualArray(
mx.tri(*shape, k=diag), mx.array(np.tri(*shape, k=diag))
)
def test_tril(self):
mt = mx.random.normal((10, 10))
nt = np.array(mt)
for diag in [-1, 0, 1, -2]:
self.assertEqualArray(mx.tril(mt, diag), mx.array(np.tril(nt, diag)))
with self.assertRaises(Exception):
mx.tril(mx.zeros((1)))
def test_triu(self):
mt = mx.random.normal((10, 10))
nt = np.array(mt)
for diag in [-1, 0, 1, -2]:
self.assertEqualArray(mx.triu(mt, diag), mx.array(np.triu(nt, diag)))
with self.assertRaises(Exception):
mx.triu(mx.zeros((1)))
def test_minimum(self):
x = mx.array([0.0, -5, 10.0])
y = mx.array([1.0, -7.0, 3.0])