added tri / tril / triu (#170)

* added tri / tril / triu

* fixed tests

* ctest tests

* tri overload and simplified tests

* changes from comment

* more tests for m

* ensure assert if not 2-D

* remove broadcast_to

* minor tweaks

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Diogo 2023-12-15 20:30:34 -05:00 committed by GitHub
parent 2e02acdc83
commit dc2edc762c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 207 additions and 12 deletions

View File

@ -1,6 +1,6 @@
repos: repos:
- repo: https://github.com/pre-commit/mirrors-clang-format - repo: https://github.com/pre-commit/mirrors-clang-format
rev: v14.0.6 rev: v17.0.6
hooks: hooks:
- id: clang-format - id: clang-format
# Using this mirror lets us use mypyc-compiled black, which is about 2x faster # Using this mirror lets us use mypyc-compiled black, which is about 2x faster

View File

@ -95,6 +95,9 @@ Operations
tan tan
tanh tanh
transpose transpose
tri
tril
triu
var var
where where
zeros zeros

View File

@ -218,6 +218,28 @@ array identity(int n, Dtype dtype, StreamOrDevice s /* = {} */) {
return eye(n, n, 0, dtype, s); return eye(n, n, 0, dtype, s);
} }
/**
 * Build a triangular array of ones and zeros.
 *
 * Entry (i, j) is one when i >= j - k and zero otherwise, i.e. ones at and
 * below the k-th diagonal. The comparison result is cast to `type`.
 *
 * @param n    Stop value of the row-index range (number of rows).
 * @param m    Number of column indices generated (number of columns).
 * @param k    Diagonal offset; the column indices are shifted by -k.
 * @param type Dtype of the returned array.
 * @param s    Stream or device forwarded to every op.
 */
array tri(int n, int m, int k, Dtype type, StreamOrDevice s /* = {} */) {
  // Row indices with a new trailing axis so they act as a column vector.
  auto row_idx = expand_dims(arange(n, s), 1, s);
  // Column indices shifted by -k, with a new leading axis (row vector).
  auto shifted_col_idx = expand_dims(arange(-k, m - k, s), 0, s);
  // Comparing the two produces the lower-triangular pattern.
  auto on_or_below = greater_equal(row_idx, shifted_col_idx, s);
  return astype(on_or_below, type, s);
}
/**
 * Keep the entries of `x` at and below the k-th diagonal and zero the rest.
 *
 * The mask is sized from x.shape(-2) and x.shape(-1).
 *
 * @param x Input array; must have at least two dimensions.
 * @param k Diagonal offset passed to tri.
 * @param s Stream or device forwarded to every op.
 * @throws std::invalid_argument if `x` has fewer than two dimensions.
 */
array tril(array x, int k, StreamOrDevice s /* = {} */) {
  if (x.ndim() < 2) {
    throw std::invalid_argument("[tril] array must be atleast 2-D");
  }
  const auto n_rows = x.shape(-2);
  const auto n_cols = x.shape(-1);
  // Ones at and below diagonal k mark the entries to keep.
  auto keep = tri(n_rows, n_cols, k, x.dtype(), s);
  auto zeros = zeros_like(x, s);
  return where(keep, x, zeros, s);
}
/**
 * Keep the entries of `x` at and above the k-th diagonal and zero the rest.
 *
 * The mask is sized from x.shape(-2) and x.shape(-1).
 *
 * @param x Input array; must have at least two dimensions.
 * @param k Diagonal offset.
 * @param s Stream or device forwarded to every op.
 * @throws std::invalid_argument if `x` has fewer than two dimensions.
 */
array triu(array x, int k, StreamOrDevice s /* = {} */) {
  if (x.ndim() < 2) {
    throw std::invalid_argument("[triu] array must be atleast 2-D");
  }
  const auto n_rows = x.shape(-2);
  const auto n_cols = x.shape(-1);
  // tri with offset k - 1 marks everything strictly below diagonal k; those
  // are the entries to zero out.
  auto strictly_below = tri(n_rows, n_cols, k - 1, x.dtype(), s);
  auto zeros = zeros_like(x, s);
  return where(strictly_below, zeros, x, s);
}
array reshape( array reshape(
const array& a, const array& a,
std::vector<int> shape, std::vector<int> shape,

View File

@ -110,6 +110,14 @@ inline array identity(int n, StreamOrDevice s = {}) {
return identity(n, float32, s); return identity(n, float32, s);
} }
/**
 * Array of the given dtype with ones at and below the k-th diagonal and
 * zeros elsewhere; `n` is the number of rows and `m` the number of columns.
 */
array tri(int n, int m, int k, Dtype type, StreamOrDevice s = {});
/** Convenience overload: n x n output using the main diagonal (k = 0). */
inline array tri(int n, Dtype type, StreamOrDevice s = {}) {
return tri(n, n, 0, type, s);
}
/**
 * Zeros the array above the given diagonal `k`. Throws
 * std::invalid_argument when `x` has fewer than two dimensions.
 */
array tril(array x, int k, StreamOrDevice s = {});
/**
 * Zeros the array below the given diagonal `k`. Throws
 * std::invalid_argument when `x` has fewer than two dimensions.
 */
array triu(array x, int k, StreamOrDevice s = {});
/** array manipulation */ /** array manipulation */
/** Reshape an array to the given shape. */ /** Reshape an array to the given shape. */

View File

@ -1410,6 +1410,72 @@ void init_ops(py::module_& m) {
Returns: Returns:
array: An identity matrix of size n x n. array: An identity matrix of size n x n.
)pbdoc"); )pbdoc");
// Python binding for tri. `m` is optional and falls back to `n`, giving a
// square output; `stream` is keyword-only.
m.def(
    "tri",
    [](int n, std::optional<int> m, int k, Dtype dtype, StreamOrDevice s) {
      // Forward the caller's dtype; hard-coding float32 here would silently
      // ignore the `dtype` argument.
      return tri(n, m.value_or(n), k, dtype, s);
    },
    "n"_a,
    "m"_a = none,
    "k"_a = 0,
    "dtype"_a = float32,
    py::kw_only(),
    "stream"_a = none,
    R"pbdoc(
      tri(n: int, m: Optional[int] = None, k: int = 0, dtype: Dtype = float32, *, stream: Union[None, Stream, Device] = None) -> array

      An array with ones at and below the given diagonal and zeros elsewhere.

      Args:
        n (int): The number of rows in the output.
        m (int, optional): The number of cols in the output. Defaults to ``None``, in which case ``n`` is used.
        k (int, optional): The diagonal of the 2-D array. Defaults to ``0``.
        dtype (Dtype, optional): Data type of the output array. Defaults to ``float32``.
        stream (Stream, optional): Stream or device. Defaults to ``None``.

      Returns:
        array: Array with its lower triangle filled with ones and zeros elsewhere
    )pbdoc");
// Python binding for tril: binds the C++ tril function directly. `k` defaults
// to 0 and `stream` is keyword-only (it follows py::kw_only()).
m.def(
"tril",
&tril,
"x"_a,
"k"_a = 0,
py::kw_only(),
"stream"_a = none,
R"pbdoc(
tril(x: array, k: int, *, stream: Union[None, Stream, Device] = None) -> array
Zeros the array above the given diagonal.
Args:
x (array): input array.
k (int, optional): The diagonal of the 2-D array. Defaults to ``0``.
stream (Stream, optional): Stream or device. Defaults to ``None``.
Returns:
array: Array zeroed above the given diagonal
)pbdoc");
// Python binding for triu: binds the C++ triu function directly. `k` defaults
// to 0 and `stream` is keyword-only (it follows py::kw_only()).
m.def(
"triu",
&triu,
"x"_a,
"k"_a = 0,
py::kw_only(),
"stream"_a = none,
R"pbdoc(
triu(x: array, k: int, *, stream: Union[None, Stream, Device] = None) -> array
Zeros the array below the given diagonal.
Args:
x (array): input array.
k (int, optional): The diagonal of the 2-D array. Defaults to ``0``.
stream (Stream, optional): Stream or device. Defaults to ``None``.
Returns:
array: Array zeroed below the given diagonal
)pbdoc");
m.def( m.def(
"allclose", "allclose",
&allclose, &allclose,

View File

@ -2,7 +2,6 @@
import os import os
import unittest import unittest
from typing import Callable, List, Tuple, Union
import mlx.core as mx import mlx.core as mx
import numpy as np import numpy as np
@ -21,13 +20,16 @@ class MLXTestCase(unittest.TestCase):
def assertEqualArray( def assertEqualArray(
self, self,
args: List[Union[mx.array, float, int]], mx_res: mx.array,
mlx_func: Callable[..., mx.array],
expected: mx.array, expected: mx.array,
atol=1e-2, atol=1e-2,
rtol=1e-2, rtol=1e-2,
**kwargs,
): ):
mx_res = mlx_func(*args) assert tuple(mx_res.shape) == tuple(
assert tuple(mx_res.shape) == tuple(expected.shape), "shape mismatch" expected.shape
assert mx_res.dtype == expected.dtype, "dtype mismatch" ), f"shape mismatch expected={expected.shape} got={mx_res.shape}"
assert (
mx_res.dtype == expected.dtype
), f"dtype mismatch expected={expected.dtype} got={mx_res.dtype}"
np.testing.assert_allclose(mx_res, expected, rtol=rtol, atol=atol) np.testing.assert_allclose(mx_res, expected, rtol=rtol, atol=atol)

View File

@ -451,15 +451,13 @@ class TestNN(mlx_tests.MLXTestCase):
def test_prelu(self): def test_prelu(self):
self.assertEqualArray( self.assertEqualArray(
[mx.array([1.0, -1.0, 0.0, 0.5])], nn.PReLU()(mx.array([1.0, -1.0, 0.0, 0.5])),
nn.PReLU(),
mx.array([1.0, -0.25, 0.0, 0.5]), mx.array([1.0, -0.25, 0.0, 0.5]),
) )
def test_mish(self): def test_mish(self):
self.assertEqualArray( self.assertEqualArray(
[mx.array([1.0, -1.0, 0.0, 0.5])], nn.Mish()(mx.array([1.0, -1.0, 0.0, 0.5])),
nn.Mish(),
mx.array([0.8651, -0.3034, 0.0000, 0.3752]), mx.array([0.8651, -0.3034, 0.0000, 0.3752]),
) )

View File

@ -320,6 +320,30 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertFalse(mx.array_equal(x, y)) self.assertFalse(mx.array_equal(x, y))
self.assertTrue(mx.array_equal(x, y, equal_nan=True)) self.assertTrue(mx.array_equal(x, y, equal_nan=True))
def test_tri(self):
    # mx.tri must match np.tri for the single-argument (square), explicit
    # square and rectangular forms, across several diagonal offsets.
    dims_to_check = ([4], [4, 4], [2, 10])
    offsets = (-1, 0, 1, -2)
    for dims in dims_to_check:
        for offset in offsets:
            reference = mx.array(np.tri(*dims, k=offset))
            self.assertEqualArray(mx.tri(*dims, k=offset), reference)
def test_tril(self):
    # Compare against np.tril on a random 10x10 matrix for several diagonals.
    sample = mx.random.normal((10, 10))
    sample_np = np.array(sample)
    for offset in (-1, 0, 1, -2):
        reference = mx.array(np.tril(sample_np, offset))
        self.assertEqualArray(mx.tril(sample, offset), reference)

    # An input with fewer than two dimensions must be rejected.
    with self.assertRaises(Exception):
        mx.tril(mx.zeros((1)))
def test_triu(self):
    # Compare against np.triu on a random 10x10 matrix for several diagonals.
    sample = mx.random.normal((10, 10))
    sample_np = np.array(sample)
    for offset in (-1, 0, 1, -2):
        reference = mx.array(np.triu(sample_np, offset))
        self.assertEqualArray(mx.triu(sample, offset), reference)

    # An input with fewer than two dimensions must be rejected.
    with self.assertRaises(Exception):
        mx.triu(mx.zeros((1)))
def test_minimum(self): def test_minimum(self):
x = mx.array([0.0, -5, 10.0]) x = mx.array([0.0, -5, 10.0])
y = mx.array([1.0, -7.0, 3.0]) y = mx.array([1.0, -7.0, 3.0])

View File

@ -2031,6 +2031,78 @@ TEST_CASE("test eye") {
CHECK(array_equal(eye_3x2, expected_eye_3x2).item<bool>()); CHECK(array_equal(eye_3x2, expected_eye_3x2).item<bool>());
} }
TEST_CASE("test tri") {
  // 4x4 float32 mask with ones at and below the main diagonal (k = 0).
  auto result = tri(4, 4, 0, float32);
  CHECK_EQ(result.shape(), std::vector<int>{4, 4});

  // Expected values laid out one matrix row per line.
  // clang-format off
  auto expected = array(
      {1.0f, 0.0f, 0.0f, 0.0f,
       1.0f, 1.0f, 0.0f, 0.0f,
       1.0f, 1.0f, 1.0f, 0.0f,
       1.0f, 1.0f, 1.0f, 1.0f},
      {4, 4});
  // clang-format on
  CHECK(array_equal(result, expected).item<bool>());
}
TEST_CASE("test tril") {
  // Lower triangle (k = 0) of a 4x4 array filled with 2.0f.
  auto input = full(std::vector<int>{4, 4}, 2.0f, float32);
  auto result = tril(input, 0);
  CHECK_EQ(result.shape(), std::vector<int>{4, 4});

  // Expected values laid out one matrix row per line.
  // clang-format off
  auto expected = array(
      {2.0f, 0.0f, 0.0f, 0.0f,
       2.0f, 2.0f, 0.0f, 0.0f,
       2.0f, 2.0f, 2.0f, 0.0f,
       2.0f, 2.0f, 2.0f, 2.0f},
      {4, 4});
  // clang-format on
  CHECK(array_equal(result, expected).item<bool>());
}
TEST_CASE("test triu") {
  // Upper triangle (k = 0) of a 4x4 array filled with 2.0f.
  auto input = full(std::vector<int>{4, 4}, 2.0f, float32);
  auto result = triu(input, 0);
  CHECK_EQ(result.shape(), std::vector<int>{4, 4});

  // Expected values laid out one matrix row per line.
  // clang-format off
  auto expected = array(
      {2.0f, 2.0f, 2.0f, 2.0f,
       0.0f, 2.0f, 2.0f, 2.0f,
       0.0f, 0.0f, 2.0f, 2.0f,
       0.0f, 0.0f, 0.0f, 2.0f},
      {4, 4});
  // clang-format on
  CHECK(array_equal(result, expected).item<bool>());
}
TEST_CASE("test identity") { TEST_CASE("test identity") {
auto id_4 = identity(4); auto id_4 = identity(4);
CHECK_EQ(id_4.shape(), std::vector<int>{4, 4}); CHECK_EQ(id_4.shape(), std::vector<int>{4, 4});