mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Added support for atleast_1d, atleast_2d, atleast_3d (#694)
This commit is contained in:
parent
e1bdf6a8d9
commit
f883fcede0
@ -10,8 +10,9 @@ MLX was developed with contributions from the following individuals:
|
|||||||
- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops.
|
- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops.
|
||||||
- Juarez Bochi: Fixed bug in cross attention.
|
- Juarez Bochi: Fixed bug in cross attention.
|
||||||
- Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
|
- Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
|
||||||
- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream` and safetensor support
|
- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream` and safetensor support.
|
||||||
- Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer. Implemented ``MaxPool1d``, ``MaxPool2d``, ``AvgPool1d``, ``AvgPool2d``.
|
- Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer. Implemented ``MaxPool1d``, ``MaxPool2d``, ``AvgPool1d``, ``AvgPool2d``.
|
||||||
|
- Hinrik Snær Guðmundsson: Added `atleast_1d`, `atleast_2d`, `atleast_3d` ops.
|
||||||
|
|
||||||
<a href="https://github.com/ml-explore/mlx/graphs/contributors">
|
<a href="https://github.com/ml-explore/mlx/graphs/contributors">
|
||||||
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
|
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
|
||||||
|
@ -25,6 +25,9 @@ Operations
|
|||||||
argpartition
|
argpartition
|
||||||
argsort
|
argsort
|
||||||
array_equal
|
array_equal
|
||||||
|
atleast_1d
|
||||||
|
atleast_2d
|
||||||
|
atleast_3d
|
||||||
broadcast_to
|
broadcast_to
|
||||||
ceil
|
ceil
|
||||||
clip
|
clip
|
||||||
|
30
mlx/ops.cpp
30
mlx/ops.cpp
@ -3381,4 +3381,34 @@ std::vector<array> depends(
|
|||||||
shapes, dtypes, std::make_shared<Depends>(to_stream(s)), all_inputs);
|
shapes, dtypes, std::make_shared<Depends>(to_stream(s)), all_inputs);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
array atleast_1d(const array& a, StreamOrDevice s /* = {} */) {
|
||||||
|
if (a.ndim() == 0) {
|
||||||
|
return reshape(a, {1}, s);
|
||||||
|
}
|
||||||
|
return a;
|
||||||
|
}
|
||||||
|
|
||||||
|
array atleast_2d(const array& a, StreamOrDevice s /* = {} */) {
|
||||||
|
switch (a.ndim()) {
|
||||||
|
case 0:
|
||||||
|
return reshape(a, {1, 1}, s);
|
||||||
|
case 1:
|
||||||
|
return reshape(a, {1, static_cast<int>(a.size())}, s);
|
||||||
|
default:
|
||||||
|
return a;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
array atleast_3d(const array& a, StreamOrDevice s /* = {} */) {
|
||||||
|
switch (a.ndim()) {
|
||||||
|
case 0:
|
||||||
|
return reshape(a, {1, 1, 1}, s);
|
||||||
|
case 1:
|
||||||
|
return reshape(a, {1, static_cast<int>(a.size()), 1}, s);
|
||||||
|
case 2:
|
||||||
|
return reshape(a, {a.shape(0), a.shape(1), 1}, s);
|
||||||
|
default:
|
||||||
|
return a;
|
||||||
|
}
|
||||||
|
}
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@ -1121,4 +1121,9 @@ std::vector<array> depends(
|
|||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
const std::vector<array>& dependencies);
|
const std::vector<array>& dependencies);
|
||||||
|
|
||||||
|
/** convert an array to an atleast ndim array */
|
||||||
|
array atleast_1d(const array& a, StreamOrDevice s = {});
|
||||||
|
array atleast_2d(const array& a, StreamOrDevice s = {});
|
||||||
|
array atleast_3d(const array& a, StreamOrDevice s = {});
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@ -3636,4 +3636,64 @@ void init_ops(py::module_& m) {
|
|||||||
Returns:
|
Returns:
|
||||||
array: The extracted diagonal or the constructed diagonal matrix.
|
array: The extracted diagonal or the constructed diagonal matrix.
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
|
m.def(
|
||||||
|
"atleast_1d",
|
||||||
|
&atleast_1d,
|
||||||
|
"a"_a,
|
||||||
|
py::pos_only(),
|
||||||
|
py::kw_only(),
|
||||||
|
"stream"_a = none,
|
||||||
|
R"pbdoc(
|
||||||
|
atleast_1d(a: array, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
|
Convert array to have at least one dimension.
|
||||||
|
|
||||||
|
args:
|
||||||
|
a (array): Input array
|
||||||
|
stream (Union[None, Stream, Device], optional): The stream to execute the operation on.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
array: An array with at least one dimension.
|
||||||
|
|
||||||
|
)pbdoc");
|
||||||
|
m.def(
|
||||||
|
"atleast_2d",
|
||||||
|
&atleast_2d,
|
||||||
|
"a"_a,
|
||||||
|
py::pos_only(),
|
||||||
|
py::kw_only(),
|
||||||
|
"stream"_a = none,
|
||||||
|
R"pbdoc(
|
||||||
|
atleast_2d(a: array, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
|
Convert array to have at least two dimensions.
|
||||||
|
|
||||||
|
args:
|
||||||
|
a (array): Input array
|
||||||
|
stream (Union[None, Stream, Device], optional): The stream to execute the operation on.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
array: An array with at least two dimensions.
|
||||||
|
|
||||||
|
)pbdoc");
|
||||||
|
m.def(
|
||||||
|
"atleast_3d",
|
||||||
|
&atleast_3d,
|
||||||
|
"a"_a,
|
||||||
|
py::pos_only(),
|
||||||
|
py::kw_only(),
|
||||||
|
"stream"_a = none,
|
||||||
|
R"pbdoc(
|
||||||
|
atleast_3d(a: array, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
|
Convert array to have at least three dimensions.
|
||||||
|
|
||||||
|
args:
|
||||||
|
a (array): Input array
|
||||||
|
stream (Union[None, Stream, Device], optional): The stream to execute the operation on.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
array: An array with at least three dimensions.
|
||||||
|
|
||||||
|
)pbdoc");
|
||||||
}
|
}
|
||||||
|
@ -1883,6 +1883,96 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
expected = mx.array(np.diag(x, k=-1))
|
expected = mx.array(np.diag(x, k=-1))
|
||||||
self.assertTrue(mx.array_equal(result, expected))
|
self.assertTrue(mx.array_equal(result, expected))
|
||||||
|
|
||||||
|
def test_atleast_1d(self):
|
||||||
|
def compare_nested_lists(x, y):
|
||||||
|
if isinstance(x, list) and isinstance(y, list):
|
||||||
|
if len(x) != len(y):
|
||||||
|
return False
|
||||||
|
for i in range(len(x)):
|
||||||
|
if not compare_nested_lists(x[i], y[i]):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return x == y
|
||||||
|
|
||||||
|
# Test 1D input
|
||||||
|
arrays = [
|
||||||
|
[1],
|
||||||
|
[1, 2, 3],
|
||||||
|
[1, 2, 3, 4],
|
||||||
|
[[1], [2], [3]],
|
||||||
|
[[1, 2], [3, 4]],
|
||||||
|
[[1, 2, 3], [4, 5, 6]],
|
||||||
|
[[[[1]], [[2]], [[3]]]],
|
||||||
|
]
|
||||||
|
|
||||||
|
for array in arrays:
|
||||||
|
mx_res = mx.atleast_1d(mx.array(array))
|
||||||
|
np_res = np.atleast_1d(np.array(array))
|
||||||
|
self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist()))
|
||||||
|
self.assertEqual(mx_res.shape, np_res.shape)
|
||||||
|
self.assertEqual(mx_res.ndim, np_res.ndim)
|
||||||
|
|
||||||
|
def test_atleast_2d(self):
|
||||||
|
def compare_nested_lists(x, y):
|
||||||
|
if isinstance(x, list) and isinstance(y, list):
|
||||||
|
if len(x) != len(y):
|
||||||
|
return False
|
||||||
|
for i in range(len(x)):
|
||||||
|
if not compare_nested_lists(x[i], y[i]):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return x == y
|
||||||
|
|
||||||
|
# Test 1D input
|
||||||
|
arrays = [
|
||||||
|
[1],
|
||||||
|
[1, 2, 3],
|
||||||
|
[1, 2, 3, 4],
|
||||||
|
[[1], [2], [3]],
|
||||||
|
[[1, 2], [3, 4]],
|
||||||
|
[[1, 2, 3], [4, 5, 6]],
|
||||||
|
[[[[1]], [[2]], [[3]]]],
|
||||||
|
]
|
||||||
|
|
||||||
|
for array in arrays:
|
||||||
|
mx_res = mx.atleast_2d(mx.array(array))
|
||||||
|
np_res = np.atleast_2d(np.array(array))
|
||||||
|
self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist()))
|
||||||
|
self.assertEqual(mx_res.shape, np_res.shape)
|
||||||
|
self.assertEqual(mx_res.ndim, np_res.ndim)
|
||||||
|
|
||||||
|
def test_atleast_3d(self):
|
||||||
|
def compare_nested_lists(x, y):
|
||||||
|
if isinstance(x, list) and isinstance(y, list):
|
||||||
|
if len(x) != len(y):
|
||||||
|
return False
|
||||||
|
for i in range(len(x)):
|
||||||
|
if not compare_nested_lists(x[i], y[i]):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return x == y
|
||||||
|
|
||||||
|
# Test 1D input
|
||||||
|
arrays = [
|
||||||
|
[1],
|
||||||
|
[1, 2, 3],
|
||||||
|
[1, 2, 3, 4],
|
||||||
|
[[1], [2], [3]],
|
||||||
|
[[1, 2], [3, 4]],
|
||||||
|
[[1, 2, 3], [4, 5, 6]],
|
||||||
|
[[[[1]], [[2]], [[3]]]],
|
||||||
|
]
|
||||||
|
|
||||||
|
for array in arrays:
|
||||||
|
mx_res = mx.atleast_3d(mx.array(array))
|
||||||
|
np_res = np.atleast_3d(np.array(array))
|
||||||
|
self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist()))
|
||||||
|
self.assertEqual(mx_res.shape, np_res.shape)
|
||||||
|
self.assertEqual(mx_res.ndim, np_res.ndim)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
@ -2716,3 +2716,54 @@ TEST_CASE("test diag") {
|
|||||||
out = diag(x, -1);
|
out = diag(x, -1);
|
||||||
CHECK(array_equal(out, array({3, 7}, {2})).item<bool>());
|
CHECK(array_equal(out, array({3, 7}, {2})).item<bool>());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test atleast_1d") {
|
||||||
|
auto x = array(1);
|
||||||
|
auto out = atleast_1d(x);
|
||||||
|
CHECK_EQ(out.ndim(), 1);
|
||||||
|
CHECK_EQ(out.shape(), std::vector<int>{1});
|
||||||
|
|
||||||
|
x = array({1, 2, 3}, {3});
|
||||||
|
out = atleast_1d(x);
|
||||||
|
CHECK_EQ(out.ndim(), 1);
|
||||||
|
CHECK_EQ(out.shape(), std::vector<int>{3});
|
||||||
|
|
||||||
|
x = array({1, 2, 3}, {3, 1});
|
||||||
|
out = atleast_1d(x);
|
||||||
|
CHECK_EQ(out.ndim(), 2);
|
||||||
|
CHECK_EQ(out.shape(), std::vector<int>{3, 1});
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test atleast_2d") {
|
||||||
|
auto x = array(1);
|
||||||
|
auto out = atleast_2d(x);
|
||||||
|
CHECK_EQ(out.ndim(), 2);
|
||||||
|
CHECK_EQ(out.shape(), std::vector<int>{1, 1});
|
||||||
|
|
||||||
|
x = array({1, 2, 3}, {3});
|
||||||
|
out = atleast_2d(x);
|
||||||
|
CHECK_EQ(out.ndim(), 2);
|
||||||
|
CHECK_EQ(out.shape(), std::vector<int>{1, 3});
|
||||||
|
|
||||||
|
x = array({1, 2, 3}, {3, 1});
|
||||||
|
out = atleast_2d(x);
|
||||||
|
CHECK_EQ(out.ndim(), 2);
|
||||||
|
CHECK_EQ(out.shape(), std::vector<int>{3, 1});
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test atleast_3d") {
|
||||||
|
auto x = array(1);
|
||||||
|
auto out = atleast_3d(x);
|
||||||
|
CHECK_EQ(out.ndim(), 3);
|
||||||
|
CHECK_EQ(out.shape(), std::vector<int>{1, 1, 1});
|
||||||
|
|
||||||
|
x = array({1, 2, 3}, {3});
|
||||||
|
out = atleast_3d(x);
|
||||||
|
CHECK_EQ(out.ndim(), 3);
|
||||||
|
CHECK_EQ(out.shape(), std::vector<int>{1, 3, 1});
|
||||||
|
|
||||||
|
x = array({1, 2, 3}, {3, 1});
|
||||||
|
out = atleast_3d(x);
|
||||||
|
CHECK_EQ(out.ndim(), 3);
|
||||||
|
CHECK_EQ(out.shape(), std::vector<int>{3, 1, 1});
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user