Fix concatenate vmap (#1600)

This commit is contained in:
Angelos Katharopoulos 2024-11-19 10:44:04 -08:00 committed by GitHub
parent 2af7e8a9a6
commit 5e89aace9b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 43 additions and 11 deletions

View File

@ -836,31 +836,43 @@ std::vector<array> Concatenate::jvp(
std::pair<std::vector<array>, std::vector<int>> Concatenate::vmap( std::pair<std::vector<array>, std::vector<int>> Concatenate::vmap(
const std::vector<array>& inputs, const std::vector<array>& inputs,
const std::vector<int>& axes) { const std::vector<int>& axes) {
std::vector<array> t_inputs;
int out_ax = -1; int out_ax = -1;
int first_vmap = -1;
// Find the first vmapped input // Find the first vmapped input
int i = 0; for (int i = 0; i < axes.size(); i++) {
for (; i < axes.size(); i++) {
t_inputs.push_back(inputs[i]);
if (axes[i] >= 0) { if (axes[i] >= 0) {
out_ax = axes[i]; out_ax = axes[i];
first_vmap = i;
break; break;
} }
} }
if (out_ax >= 0) {
// Advance to the next input
i++;
// Move vmap axes to the same spot. // No vmap, should we even be in here?
for (; i < axes.size(); ++i) { if (out_ax < 0) {
if (out_ax != axes[i] && axes[i] >= 0) { return {{concatenate(inputs, axis_, stream())}, {out_ax}};
}
// Make sure vmapped arrays have all vmapped axes in the same location and
// expand non-vmapped arrays to be compatible with the vmapped ones.
std::vector<array> t_inputs;
int N = inputs[first_vmap].shape(out_ax);
int axis = axis_ + (axis_ >= out_ax);
auto cat_shape = inputs[first_vmap].shape();
for (int i = 0; i < axes.size(); i++) {
if (axes[i] >= 0) {
if (out_ax != axes[i]) {
t_inputs.push_back(moveaxis(inputs[i], axes[i], out_ax, stream())); t_inputs.push_back(moveaxis(inputs[i], axes[i], out_ax, stream()));
} else { } else {
t_inputs.push_back(inputs[i]); t_inputs.push_back(inputs[i]);
} }
} else {
cat_shape[axis] = inputs[i].shape(axis_);
t_inputs.push_back(broadcast_to(
expand_dims(inputs[i], out_ax, stream()), cat_shape, stream()));
} }
} }
auto axis = axis_ + (out_ax >= 0 && axis_ >= out_ax);
return {{concatenate(t_inputs, axis, stream())}, {out_ax}}; return {{concatenate(t_inputs, axis, stream())}, {out_ax}};
} }

View File

@ -527,6 +527,26 @@ class TestVmap(mlx_tests.MLXTestCase):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
out = mx.vmap(const_func, in_axes=(0, 0))(a, b) out = mx.vmap(const_func, in_axes=(0, 0))(a, b)
def test_vmap_concatenate(self):
x = mx.random.uniform(shape=(2, 2, 2))
def cat_fun(x, y):
return mx.concatenate([x, y], axis=1)
def cat_constant(x):
y = mx.ones((2, 1))
return mx.concatenate([x, y], 1)
out = mx.vmap(cat_fun, in_axes=(0, 2))(x, x)
target = mx.stack(
[mx.concatenate([x[i], x[:, :, i]], axis=1) for i in range(2)]
)
self.assertTrue(mx.array_equal(out, target))
out = mx.vmap(cat_constant)(x)
target = mx.concatenate([x, mx.ones((2, 2, 1))], axis=2)
self.assertTrue(mx.array_equal(out, target))
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()