mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Add tile op (#438)
This commit is contained in:
parent
1b71487e1f
commit
2e29d0815b
@ -10,7 +10,7 @@ MLX was developed with contributions from the following individuals:
|
|||||||
- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops.
|
- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops.
|
||||||
- Juarez Bochi: Fixed bug in cross attention.
|
- Juarez Bochi: Fixed bug in cross attention.
|
||||||
- Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
|
- Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
|
||||||
- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer` and safetensor support
|
- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile` and safetensor support
|
||||||
- Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer.
|
- Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer.
|
||||||
|
|
||||||
<a href="https://github.com/ml-explore/mlx/graphs/contributors">
|
<a href="https://github.com/ml-explore/mlx/graphs/contributors">
|
||||||
|
30
mlx/ops.cpp
30
mlx/ops.cpp
@ -753,6 +753,36 @@ array repeat(const array& arr, int repeats, StreamOrDevice s) {
|
|||||||
return repeat(flatten(arr, s), repeats, 0, s);
|
return repeat(flatten(arr, s), repeats, 0, s);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** Construct an array by repeating `arr` along each axis the number of
 * times given by `reps`, mirroring `np.tile`: when `reps` has fewer
 * entries than `arr` has dimensions it is left-padded with 1s; when it
 * has more, the array's shape is left-padded with 1s. */
array tile(
    const array& arr,
    std::vector<int> reps,
    StreamOrDevice s /* = {} */) {
  auto shape = arr.shape();
  // Align the lengths of `reps` and `shape` by left-padding the shorter
  // one with 1s (numpy semantics).
  if (reps.size() < shape.size()) {
    reps.insert(reps.begin(), shape.size() - reps.size(), 1);
  }
  if (reps.size() > shape.size()) {
    shape.insert(shape.begin(), reps.size() - shape.size(), 1);
  }

  // Build three shapes:
  //  - expand_shape: `shape` with a leading 1 inserted before each tiled
  //    axis so broadcasting can replicate along it;
  //  - broad_shape: the same positions carrying the repeat counts;
  //  - final_shape: the result shape, reps[i] * shape[i] per axis.
  std::vector<int> expand_shape;
  std::vector<int> broad_shape;
  std::vector<int> final_shape;
  // size_t index avoids the signed/unsigned comparison against
  // shape.size() that `int i` would produce.
  for (size_t i = 0; i < shape.size(); i++) {
    if (reps[i] != 1) {
      expand_shape.push_back(1);
      broad_shape.push_back(reps[i]);
    }
    expand_shape.push_back(shape[i]);
    broad_shape.push_back(shape[i]);
    final_shape.push_back(reps[i] * shape[i]);
  }

  // reshape -> broadcast -> reshape implements the tiling without any
  // data movement beyond the final materialization.
  auto x = reshape(arr, expand_shape, s);
  x = broadcast_to(x, broad_shape, s);
  return reshape(x, final_shape, s);
}
|
||||||
|
|
||||||
/** Pad an array with a constant value */
|
/** Pad an array with a constant value */
|
||||||
array pad(
|
array pad(
|
||||||
const array& a,
|
const array& a,
|
||||||
|
@ -218,6 +218,8 @@ array stack(const std::vector<array>& arrays, StreamOrDevice s = {});
|
|||||||
array repeat(const array& arr, int repeats, int axis, StreamOrDevice s = {});
|
array repeat(const array& arr, int repeats, int axis, StreamOrDevice s = {});
|
||||||
array repeat(const array& arr, int repeats, StreamOrDevice s = {});
|
array repeat(const array& arr, int repeats, StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Construct an array by repeating `arr` the number of times given by `reps`. */
array tile(const array& arr, std::vector<int> reps, StreamOrDevice s = {});
|
||||||
|
|
||||||
/** Permutes the dimensions according to the given axes. */
|
/** Permutes the dimensions according to the given axes. */
|
||||||
array transpose(const array& a, std::vector<int> axes, StreamOrDevice s = {});
|
array transpose(const array& a, std::vector<int> axes, StreamOrDevice s = {});
|
||||||
inline array transpose(
|
inline array transpose(
|
||||||
|
@ -3394,4 +3394,30 @@ void init_ops(py::module_& m) {
|
|||||||
Returns:
|
Returns:
|
||||||
result (array): The outer product.
|
result (array): The outer product.
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
|
m.def(
|
||||||
|
"tile",
|
||||||
|
[](const array& a, const IntOrVec& reps, StreamOrDevice s) {
|
||||||
|
if (auto pv = std::get_if<int>(&reps); pv) {
|
||||||
|
return tile(a, {*pv}, s);
|
||||||
|
} else {
|
||||||
|
return tile(a, std::get<std::vector<int>>(reps), s);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"a"_a,
|
||||||
|
"reps"_a,
|
||||||
|
py::pos_only(),
|
||||||
|
py::kw_only(),
|
||||||
|
"stream"_a = none,
|
||||||
|
R"pbdoc(
|
||||||
|
tile(a: array, reps: Union[int, List[int]], /, *, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
|
Construct an array by repeating ``a`` the number of times given by ``reps``.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
a (array): Input array
|
||||||
|
reps (int or list(int)): The number of times to repeat ``a`` along each axis.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
result (array): The tiled array.
|
||||||
|
)pbdoc");
|
||||||
}
|
}
|
||||||
|
@ -24,9 +24,10 @@ class MLXTestCase(unittest.TestCase):
|
|||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
mx.set_default_device(self.default)
|
mx.set_default_device(self.default)
|
||||||
|
|
||||||
|
# Note if a tuple is passed into args, it will be considered a shape request and convert to a mx.random.normal with the shape matching the tuple
|
||||||
def assertCmpNumpy(
|
def assertCmpNumpy(
|
||||||
self,
|
self,
|
||||||
shape: List[Union[Tuple[int], Any]],
|
args: List[Union[Tuple[int], Any]],
|
||||||
mx_fn: Callable[..., mx.array],
|
mx_fn: Callable[..., mx.array],
|
||||||
np_fn: Callable[..., np.array],
|
np_fn: Callable[..., np.array],
|
||||||
atol=1e-2,
|
atol=1e-2,
|
||||||
@ -37,7 +38,7 @@ class MLXTestCase(unittest.TestCase):
|
|||||||
assert dtype != mx.bfloat16, "numpy does not support bfloat16"
|
assert dtype != mx.bfloat16, "numpy does not support bfloat16"
|
||||||
args = [
|
args = [
|
||||||
mx.random.normal(s, dtype=dtype) if isinstance(s, Tuple) else s
|
mx.random.normal(s, dtype=dtype) if isinstance(s, Tuple) else s
|
||||||
for s in shape
|
for s in args
|
||||||
]
|
]
|
||||||
mx_res = mx_fn(*args, **kwargs)
|
mx_res = mx_fn(*args, **kwargs)
|
||||||
np_res = np_fn(
|
np_res = np_fn(
|
||||||
|
@ -1634,6 +1634,23 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
np.allclose(np_out[0], mx_out[0]), msg=f"Shapes {s1} {s2}, Type {t}"
|
np.allclose(np_out[0], mx_out[0]), msg=f"Shapes {s1} {s2}, Type {t}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_tile(self):
    # Compare mx.tile against np.tile. A tuple in the args list is
    # expanded to a random array of that shape by assertCmpNumpy.
    self.assertCmpNumpy([(2,), [2]], mx.tile, np.tile)
    # reps shorter than ndim: left-padded with 1s.
    self.assertCmpNumpy([(2, 3, 4), [2]], mx.tile, np.tile)
    self.assertCmpNumpy([(2, 3, 4), [2, 1]], mx.tile, np.tile)
    self.assertCmpNumpy([(2, 3, 4), [2, 2]], mx.tile, np.tile)
    # reps longer than ndim: array shape left-padded with 1s.
    self.assertCmpNumpy([(3,), [2, 2, 2]], mx.tile, np.tile)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
@ -2343,6 +2343,32 @@ TEST_CASE("test repeat") {
|
|||||||
CHECK_THROWS_AS(repeat(data_3, -3, 0), std::invalid_argument);
|
CHECK_THROWS_AS(repeat(data_3, -3, 0), std::invalid_argument);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_CASE("tile") {
  // 1-D input, single repeat count.
  auto x = array({1, 2, 3}, {3});
  auto y = tile(x, {2});
  auto expected = array({1, 2, 3, 1, 2, 3}, {6});
  CHECK(array_equal(y, expected).item<bool>());

  // 2-D input with reps shorter than ndim: reps is left-padded with 1s,
  // so only the last axis is tiled.
  x = array({1, 2, 3, 4}, {2, 2});
  y = tile(x, {2});
  expected = array({1, 2, 1, 2, 3, 4, 3, 4}, {2, 4});
  CHECK(array_equal(y, expected).item<bool>());

  // Tile only the leading axis.
  x = array({1, 2, 3, 4}, {2, 2});
  y = tile(x, {4, 1});
  expected = array({1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4}, {8, 2});
  CHECK(array_equal(y, expected).item<bool>());

  // Tile both axes.
  x = array({1, 2, 3, 4}, {2, 2});
  y = tile(x, {2, 2});
  expected = array({1, 2, 1, 2, 3, 4, 3, 4, 1, 2, 1, 2, 3, 4, 3, 4}, {4, 4});
  CHECK(array_equal(y, expected).item<bool>());

  // reps longer than ndim: result gains leading dimensions.
  x = array({1, 2, 3}, {3});
  y = tile(x, {2, 2, 2});
  expected = array(
      {1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3},
      {2, 2, 6});
  CHECK(array_equal(y, expected).item<bool>());
}
|
||||||
|
|
||||||
TEST_CASE("tensordot") {
|
TEST_CASE("tensordot") {
|
||||||
auto x = reshape(arange(60.), {3, 4, 5});
|
auto x = reshape(arange(60.), {3, 4, 5});
|
||||||
auto y = reshape(arange(24.), {4, 3, 2});
|
auto y = reshape(arange(24.), {4, 3, 2});
|
||||||
|
Loading…
Reference in New Issue
Block a user