mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 09:51:17 +08:00
added tri / tril / triu (#170)
* added tri / tril / triu * fixed tests * ctest tests * tri overload and simplified tests * changes from comment * more tests for m * ensure assert if not 2-D * remove broadcast_to * minor tweaks --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
parent
2e02acdc83
commit
dc2edc762c
@ -1,6 +1,6 @@
|
|||||||
repos:
|
repos:
|
||||||
- repo: https://github.com/pre-commit/mirrors-clang-format
|
- repo: https://github.com/pre-commit/mirrors-clang-format
|
||||||
rev: v14.0.6
|
rev: v17.0.6
|
||||||
hooks:
|
hooks:
|
||||||
- id: clang-format
|
- id: clang-format
|
||||||
# Using this mirror lets us use mypyc-compiled black, which is about 2x faster
|
# Using this mirror lets us use mypyc-compiled black, which is about 2x faster
|
||||||
|
@ -95,6 +95,9 @@ Operations
|
|||||||
tan
|
tan
|
||||||
tanh
|
tanh
|
||||||
transpose
|
transpose
|
||||||
|
tri
|
||||||
|
tril
|
||||||
|
triu
|
||||||
var
|
var
|
||||||
where
|
where
|
||||||
zeros
|
zeros
|
||||||
|
22
mlx/ops.cpp
22
mlx/ops.cpp
@ -218,6 +218,28 @@ array identity(int n, Dtype dtype, StreamOrDevice s /* = {} */) {
|
|||||||
return eye(n, n, 0, dtype, s);
|
return eye(n, n, 0, dtype, s);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
array tri(int n, int m, int k, Dtype type, StreamOrDevice s /* = {} */) {
|
||||||
|
auto l = expand_dims(arange(n, s), 1, s);
|
||||||
|
auto r = expand_dims(arange(-k, m - k, s), 0, s);
|
||||||
|
return astype(greater_equal(l, r, s), type, s);
|
||||||
|
}
|
||||||
|
|
||||||
|
array tril(array x, int k, StreamOrDevice s /* = {} */) {
|
||||||
|
if (x.ndim() < 2) {
|
||||||
|
throw std::invalid_argument("[tril] array must be atleast 2-D");
|
||||||
|
}
|
||||||
|
auto mask = tri(x.shape(-2), x.shape(-1), k, x.dtype(), s);
|
||||||
|
return where(mask, x, zeros_like(x, s), s);
|
||||||
|
}
|
||||||
|
|
||||||
|
array triu(array x, int k, StreamOrDevice s /* = {} */) {
|
||||||
|
if (x.ndim() < 2) {
|
||||||
|
throw std::invalid_argument("[triu] array must be atleast 2-D");
|
||||||
|
}
|
||||||
|
auto mask = tri(x.shape(-2), x.shape(-1), k - 1, x.dtype(), s);
|
||||||
|
return where(mask, zeros_like(x, s), x, s);
|
||||||
|
}
|
||||||
|
|
||||||
array reshape(
|
array reshape(
|
||||||
const array& a,
|
const array& a,
|
||||||
std::vector<int> shape,
|
std::vector<int> shape,
|
||||||
|
@ -110,6 +110,14 @@ inline array identity(int n, StreamOrDevice s = {}) {
|
|||||||
return identity(n, float32, s);
|
return identity(n, float32, s);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
array tri(int n, int m, int k, Dtype type, StreamOrDevice s = {});
|
||||||
|
inline array tri(int n, Dtype type, StreamOrDevice s = {}) {
|
||||||
|
return tri(n, n, 0, type, s);
|
||||||
|
}
|
||||||
|
|
||||||
|
array tril(array x, int k, StreamOrDevice s = {});
|
||||||
|
array triu(array x, int k, StreamOrDevice s = {});
|
||||||
|
|
||||||
/** array manipulation */
|
/** array manipulation */
|
||||||
|
|
||||||
/** Reshape an array to the given shape. */
|
/** Reshape an array to the given shape. */
|
||||||
|
@ -1410,6 +1410,72 @@ void init_ops(py::module_& m) {
|
|||||||
Returns:
|
Returns:
|
||||||
array: An identity matrix of size n x n.
|
array: An identity matrix of size n x n.
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
|
m.def(
|
||||||
|
"tri",
|
||||||
|
[](int n, std::optional<int> m, int k, Dtype dtype, StreamOrDevice s) {
|
||||||
|
return tri(n, m.value_or(n), k, float32, s);
|
||||||
|
},
|
||||||
|
"n"_a,
|
||||||
|
"m"_a = none,
|
||||||
|
"k"_a = 0,
|
||||||
|
"dtype"_a = float32,
|
||||||
|
py::kw_only(),
|
||||||
|
"stream"_a = none,
|
||||||
|
R"pbdoc(
|
||||||
|
tri(n: int, m: int, k: int, dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
|
An array with ones at and below the given diagonal and zeros elsewhere.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
n (int): The number of rows in the output.
|
||||||
|
m (int, optional): The number of cols in the output. Defaults to ``None``.
|
||||||
|
k (int, optional): The diagonal of the 2-D array. Defaults to ``0``.
|
||||||
|
dtype (Dtype, optional): Data type of the output array. Defaults to ``float32``.
|
||||||
|
stream (Stream, optional): Stream or device. Defaults to ``None``.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
array: Array with its lower triangle filled with ones and zeros elsewhere
|
||||||
|
)pbdoc");
|
||||||
|
m.def(
|
||||||
|
"tril",
|
||||||
|
&tril,
|
||||||
|
"x"_a,
|
||||||
|
"k"_a = 0,
|
||||||
|
py::kw_only(),
|
||||||
|
"stream"_a = none,
|
||||||
|
R"pbdoc(
|
||||||
|
tril(x: array, k: int, *, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
|
Zeros the array above the given diagonal.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (array): input array.
|
||||||
|
k (int, optional): The diagonal of the 2-D array. Defaults to ``0``.
|
||||||
|
stream (Stream, optional): Stream or device. Defaults to ``None``.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
array: Array zeroed above the given diagonal
|
||||||
|
)pbdoc");
|
||||||
|
m.def(
|
||||||
|
"triu",
|
||||||
|
&triu,
|
||||||
|
"x"_a,
|
||||||
|
"k"_a = 0,
|
||||||
|
py::kw_only(),
|
||||||
|
"stream"_a = none,
|
||||||
|
R"pbdoc(
|
||||||
|
triu(x: array, k: int, *, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
|
Zeros the array below the given diagonal.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (array): input array.
|
||||||
|
k (int, optional): The diagonal of the 2-D array. Defaults to ``0``.
|
||||||
|
stream (Stream, optional): Stream or device. Defaults to ``None``.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
array: Array zeroed below the given diagonal
|
||||||
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"allclose",
|
"allclose",
|
||||||
&allclose,
|
&allclose,
|
||||||
|
@ -2,7 +2,6 @@
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
import unittest
|
import unittest
|
||||||
from typing import Callable, List, Tuple, Union
|
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -21,13 +20,16 @@ class MLXTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
def assertEqualArray(
|
def assertEqualArray(
|
||||||
self,
|
self,
|
||||||
args: List[Union[mx.array, float, int]],
|
mx_res: mx.array,
|
||||||
mlx_func: Callable[..., mx.array],
|
|
||||||
expected: mx.array,
|
expected: mx.array,
|
||||||
atol=1e-2,
|
atol=1e-2,
|
||||||
rtol=1e-2,
|
rtol=1e-2,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
mx_res = mlx_func(*args)
|
assert tuple(mx_res.shape) == tuple(
|
||||||
assert tuple(mx_res.shape) == tuple(expected.shape), "shape mismatch"
|
expected.shape
|
||||||
assert mx_res.dtype == expected.dtype, "dtype mismatch"
|
), f"shape mismatch expected={expected.shape} got={mx_res.shape}"
|
||||||
|
assert (
|
||||||
|
mx_res.dtype == expected.dtype
|
||||||
|
), f"dtype mismatch expected={expected.dtype} got={mx_res.dtype}"
|
||||||
np.testing.assert_allclose(mx_res, expected, rtol=rtol, atol=atol)
|
np.testing.assert_allclose(mx_res, expected, rtol=rtol, atol=atol)
|
||||||
|
@ -451,15 +451,13 @@ class TestNN(mlx_tests.MLXTestCase):
|
|||||||
|
|
||||||
def test_prelu(self):
|
def test_prelu(self):
|
||||||
self.assertEqualArray(
|
self.assertEqualArray(
|
||||||
[mx.array([1.0, -1.0, 0.0, 0.5])],
|
nn.PReLU()(mx.array([1.0, -1.0, 0.0, 0.5])),
|
||||||
nn.PReLU(),
|
|
||||||
mx.array([1.0, -0.25, 0.0, 0.5]),
|
mx.array([1.0, -0.25, 0.0, 0.5]),
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_mish(self):
|
def test_mish(self):
|
||||||
self.assertEqualArray(
|
self.assertEqualArray(
|
||||||
[mx.array([1.0, -1.0, 0.0, 0.5])],
|
nn.Mish()(mx.array([1.0, -1.0, 0.0, 0.5])),
|
||||||
nn.Mish(),
|
|
||||||
mx.array([0.8651, -0.3034, 0.0000, 0.3752]),
|
mx.array([0.8651, -0.3034, 0.0000, 0.3752]),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -320,6 +320,30 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
self.assertFalse(mx.array_equal(x, y))
|
self.assertFalse(mx.array_equal(x, y))
|
||||||
self.assertTrue(mx.array_equal(x, y, equal_nan=True))
|
self.assertTrue(mx.array_equal(x, y, equal_nan=True))
|
||||||
|
|
||||||
|
def test_tri(self):
|
||||||
|
for shape in [[4], [4, 4], [2, 10]]:
|
||||||
|
for diag in [-1, 0, 1, -2]:
|
||||||
|
self.assertEqualArray(
|
||||||
|
mx.tri(*shape, k=diag), mx.array(np.tri(*shape, k=diag))
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_tril(self):
|
||||||
|
mt = mx.random.normal((10, 10))
|
||||||
|
nt = np.array(mt)
|
||||||
|
for diag in [-1, 0, 1, -2]:
|
||||||
|
self.assertEqualArray(mx.tril(mt, diag), mx.array(np.tril(nt, diag)))
|
||||||
|
|
||||||
|
with self.assertRaises(Exception):
|
||||||
|
mx.tril(mx.zeros((1)))
|
||||||
|
|
||||||
|
def test_triu(self):
|
||||||
|
mt = mx.random.normal((10, 10))
|
||||||
|
nt = np.array(mt)
|
||||||
|
for diag in [-1, 0, 1, -2]:
|
||||||
|
self.assertEqualArray(mx.triu(mt, diag), mx.array(np.triu(nt, diag)))
|
||||||
|
with self.assertRaises(Exception):
|
||||||
|
mx.triu(mx.zeros((1)))
|
||||||
|
|
||||||
def test_minimum(self):
|
def test_minimum(self):
|
||||||
x = mx.array([0.0, -5, 10.0])
|
x = mx.array([0.0, -5, 10.0])
|
||||||
y = mx.array([1.0, -7.0, 3.0])
|
y = mx.array([1.0, -7.0, 3.0])
|
||||||
|
@ -2031,6 +2031,78 @@ TEST_CASE("test eye") {
|
|||||||
CHECK(array_equal(eye_3x2, expected_eye_3x2).item<bool>());
|
CHECK(array_equal(eye_3x2, expected_eye_3x2).item<bool>());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test tri") {
|
||||||
|
auto _tri = tri(4, 4, 0, float32);
|
||||||
|
CHECK_EQ(_tri.shape(), std::vector<int>{4, 4});
|
||||||
|
auto expected_tri = array(
|
||||||
|
{1.0f,
|
||||||
|
0.0f,
|
||||||
|
0.0f,
|
||||||
|
0.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
0.0f,
|
||||||
|
0.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
0.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f},
|
||||||
|
{4, 4});
|
||||||
|
CHECK(array_equal(_tri, expected_tri).item<bool>());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test tril") {
|
||||||
|
auto _tril = tril(full(std::vector<int>{4, 4}, 2.0f, float32), 0);
|
||||||
|
CHECK_EQ(_tril.shape(), std::vector<int>{4, 4});
|
||||||
|
auto expected_tri = array(
|
||||||
|
{2.0f,
|
||||||
|
0.0f,
|
||||||
|
0.0f,
|
||||||
|
0.0f,
|
||||||
|
2.0f,
|
||||||
|
2.0f,
|
||||||
|
0.0f,
|
||||||
|
0.0f,
|
||||||
|
2.0f,
|
||||||
|
2.0f,
|
||||||
|
2.0f,
|
||||||
|
0.0f,
|
||||||
|
2.0f,
|
||||||
|
2.0f,
|
||||||
|
2.0f,
|
||||||
|
2.0f},
|
||||||
|
{4, 4});
|
||||||
|
CHECK(array_equal(_tril, expected_tri).item<bool>());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test triu") {
|
||||||
|
auto _triu = triu(full(std::vector<int>{4, 4}, 2.0f, float32), 0);
|
||||||
|
CHECK_EQ(_triu.shape(), std::vector<int>{4, 4});
|
||||||
|
auto expected_tri = array(
|
||||||
|
{2.0f,
|
||||||
|
2.0f,
|
||||||
|
2.0f,
|
||||||
|
2.0f,
|
||||||
|
0.0f,
|
||||||
|
2.0f,
|
||||||
|
2.0f,
|
||||||
|
2.0f,
|
||||||
|
0.0f,
|
||||||
|
0.0f,
|
||||||
|
2.0f,
|
||||||
|
2.0f,
|
||||||
|
0.0f,
|
||||||
|
0.0f,
|
||||||
|
0.0f,
|
||||||
|
2.0f},
|
||||||
|
{4, 4});
|
||||||
|
CHECK(array_equal(_triu, expected_tri).item<bool>());
|
||||||
|
}
|
||||||
|
|
||||||
TEST_CASE("test identity") {
|
TEST_CASE("test identity") {
|
||||||
auto id_4 = identity(4);
|
auto id_4 = identity(4);
|
||||||
CHECK_EQ(id_4.shape(), std::vector<int>{4, 4});
|
CHECK_EQ(id_4.shape(), std::vector<int>{4, 4});
|
||||||
|
Loading…
Reference in New Issue
Block a user