mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
Fix concatenate vmap (#1600)
This commit is contained in:
parent
2af7e8a9a6
commit
5e89aace9b
@ -836,31 +836,43 @@ std::vector<array> Concatenate::jvp(
|
|||||||
std::pair<std::vector<array>, std::vector<int>> Concatenate::vmap(
|
std::pair<std::vector<array>, std::vector<int>> Concatenate::vmap(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
const std::vector<int>& axes) {
|
const std::vector<int>& axes) {
|
||||||
std::vector<array> t_inputs;
|
|
||||||
int out_ax = -1;
|
int out_ax = -1;
|
||||||
|
int first_vmap = -1;
|
||||||
|
|
||||||
// Find the first vmapped input
|
// Find the first vmapped input
|
||||||
int i = 0;
|
for (int i = 0; i < axes.size(); i++) {
|
||||||
for (; i < axes.size(); i++) {
|
|
||||||
t_inputs.push_back(inputs[i]);
|
|
||||||
if (axes[i] >= 0) {
|
if (axes[i] >= 0) {
|
||||||
out_ax = axes[i];
|
out_ax = axes[i];
|
||||||
|
first_vmap = i;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (out_ax >= 0) {
|
|
||||||
// Advance to the next input
|
|
||||||
i++;
|
|
||||||
|
|
||||||
// Move vmap axes to the same spot.
|
// No vmap, should we even be in here?
|
||||||
for (; i < axes.size(); ++i) {
|
if (out_ax < 0) {
|
||||||
if (out_ax != axes[i] && axes[i] >= 0) {
|
return {{concatenate(inputs, axis_, stream())}, {out_ax}};
|
||||||
|
}
|
||||||
|
|
||||||
|
// Make sure vmapped arrays have all vmapped axes in the same location and
|
||||||
|
// expand non-vmapped arrays to be compatible with the vmapped ones.
|
||||||
|
std::vector<array> t_inputs;
|
||||||
|
int N = inputs[first_vmap].shape(out_ax);
|
||||||
|
int axis = axis_ + (axis_ >= out_ax);
|
||||||
|
auto cat_shape = inputs[first_vmap].shape();
|
||||||
|
for (int i = 0; i < axes.size(); i++) {
|
||||||
|
if (axes[i] >= 0) {
|
||||||
|
if (out_ax != axes[i]) {
|
||||||
t_inputs.push_back(moveaxis(inputs[i], axes[i], out_ax, stream()));
|
t_inputs.push_back(moveaxis(inputs[i], axes[i], out_ax, stream()));
|
||||||
} else {
|
} else {
|
||||||
t_inputs.push_back(inputs[i]);
|
t_inputs.push_back(inputs[i]);
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
cat_shape[axis] = inputs[i].shape(axis_);
|
||||||
|
t_inputs.push_back(broadcast_to(
|
||||||
|
expand_dims(inputs[i], out_ax, stream()), cat_shape, stream()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
auto axis = axis_ + (out_ax >= 0 && axis_ >= out_ax);
|
|
||||||
return {{concatenate(t_inputs, axis, stream())}, {out_ax}};
|
return {{concatenate(t_inputs, axis, stream())}, {out_ax}};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -527,6 +527,26 @@ class TestVmap(mlx_tests.MLXTestCase):
|
|||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
out = mx.vmap(const_func, in_axes=(0, 0))(a, b)
|
out = mx.vmap(const_func, in_axes=(0, 0))(a, b)
|
||||||
|
|
||||||
|
def test_vmap_concatenate(self):
|
||||||
|
x = mx.random.uniform(shape=(2, 2, 2))
|
||||||
|
|
||||||
|
def cat_fun(x, y):
|
||||||
|
return mx.concatenate([x, y], axis=1)
|
||||||
|
|
||||||
|
def cat_constant(x):
|
||||||
|
y = mx.ones((2, 1))
|
||||||
|
return mx.concatenate([x, y], 1)
|
||||||
|
|
||||||
|
out = mx.vmap(cat_fun, in_axes=(0, 2))(x, x)
|
||||||
|
target = mx.stack(
|
||||||
|
[mx.concatenate([x[i], x[:, :, i]], axis=1) for i in range(2)]
|
||||||
|
)
|
||||||
|
self.assertTrue(mx.array_equal(out, target))
|
||||||
|
|
||||||
|
out = mx.vmap(cat_constant)(x)
|
||||||
|
target = mx.concatenate([x, mx.ones((2, 2, 1))], axis=2)
|
||||||
|
self.assertTrue(mx.array_equal(out, target))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user