mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Fix concatenate vmap (#1600)
This commit is contained in:
parent
2af7e8a9a6
commit
5e89aace9b
@ -836,31 +836,43 @@ std::vector<array> Concatenate::jvp(
|
||||
std::pair<std::vector<array>, std::vector<int>> Concatenate::vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
std::vector<array> t_inputs;
|
||||
int out_ax = -1;
|
||||
int first_vmap = -1;
|
||||
|
||||
// Find the first vmapped input
|
||||
int i = 0;
|
||||
for (; i < axes.size(); i++) {
|
||||
t_inputs.push_back(inputs[i]);
|
||||
for (int i = 0; i < axes.size(); i++) {
|
||||
if (axes[i] >= 0) {
|
||||
out_ax = axes[i];
|
||||
first_vmap = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (out_ax >= 0) {
|
||||
// Advance to the next input
|
||||
i++;
|
||||
|
||||
// Move vmap axes to the same spot.
|
||||
for (; i < axes.size(); ++i) {
|
||||
if (out_ax != axes[i] && axes[i] >= 0) {
|
||||
// No vmap, should we even be in here?
|
||||
if (out_ax < 0) {
|
||||
return {{concatenate(inputs, axis_, stream())}, {out_ax}};
|
||||
}
|
||||
|
||||
// Make sure vmapped arrays have all vmapped axes in the same location and
|
||||
// expand non-vmapped arrays to be compatible with the vmapped ones.
|
||||
std::vector<array> t_inputs;
|
||||
int N = inputs[first_vmap].shape(out_ax);
|
||||
int axis = axis_ + (axis_ >= out_ax);
|
||||
auto cat_shape = inputs[first_vmap].shape();
|
||||
for (int i = 0; i < axes.size(); i++) {
|
||||
if (axes[i] >= 0) {
|
||||
if (out_ax != axes[i]) {
|
||||
t_inputs.push_back(moveaxis(inputs[i], axes[i], out_ax, stream()));
|
||||
} else {
|
||||
t_inputs.push_back(inputs[i]);
|
||||
}
|
||||
} else {
|
||||
cat_shape[axis] = inputs[i].shape(axis_);
|
||||
t_inputs.push_back(broadcast_to(
|
||||
expand_dims(inputs[i], out_ax, stream()), cat_shape, stream()));
|
||||
}
|
||||
}
|
||||
auto axis = axis_ + (out_ax >= 0 && axis_ >= out_ax);
|
||||
|
||||
return {{concatenate(t_inputs, axis, stream())}, {out_ax}};
|
||||
}
|
||||
|
||||
|
@ -527,6 +527,26 @@ class TestVmap(mlx_tests.MLXTestCase):
|
||||
with self.assertRaises(ValueError):
|
||||
out = mx.vmap(const_func, in_axes=(0, 0))(a, b)
|
||||
|
||||
def test_vmap_concatenate(self):
|
||||
x = mx.random.uniform(shape=(2, 2, 2))
|
||||
|
||||
def cat_fun(x, y):
|
||||
return mx.concatenate([x, y], axis=1)
|
||||
|
||||
def cat_constant(x):
|
||||
y = mx.ones((2, 1))
|
||||
return mx.concatenate([x, y], 1)
|
||||
|
||||
out = mx.vmap(cat_fun, in_axes=(0, 2))(x, x)
|
||||
target = mx.stack(
|
||||
[mx.concatenate([x[i], x[:, :, i]], axis=1) for i in range(2)]
|
||||
)
|
||||
self.assertTrue(mx.array_equal(out, target))
|
||||
|
||||
out = mx.vmap(cat_constant)(x)
|
||||
target = mx.concatenate([x, mx.ones((2, 2, 1))], axis=2)
|
||||
self.assertTrue(mx.array_equal(out, target))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Loading…
Reference in New Issue
Block a user