mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-23 22:18:13 +08:00
Fix qvm splitk (#2415)
This commit is contained in:
@@ -265,9 +265,15 @@ void qvm_split_k(
|
||||
MTL::Size group_dims = MTL::Size(bk, 2, 1);
|
||||
MTL::Size grid_dims = MTL::Size(M, N / bn, B);
|
||||
|
||||
int x_batch_ndims = x.ndim() - 2;
|
||||
auto x_shape = x.shape();
|
||||
auto x_strides = x.strides();
|
||||
if (x_shape.size() == 1) {
|
||||
x_shape.insert(x_shape.begin(), 1);
|
||||
x_strides.insert(x_strides.begin(), 0);
|
||||
}
|
||||
|
||||
int x_ndim = x_shape.size();
|
||||
int x_batch_ndims = x_ndim - 2;
|
||||
int w_batch_ndims = w.ndim() - 2;
|
||||
auto w_shape = w.shape();
|
||||
auto w_strides = w.strides();
|
||||
@@ -278,7 +284,7 @@ void qvm_split_k(
|
||||
x_shape.insert(x_shape.end() - 2, split_k);
|
||||
x_shape.back() /= split_k;
|
||||
x_strides.insert(x_strides.end() - 2, split_D);
|
||||
x_strides[x.ndim() - 1] = split_D;
|
||||
x_strides[x_ndim - 1] = split_D;
|
||||
x_batch_ndims += 1;
|
||||
|
||||
w_shape.insert(w_shape.end() - 2, split_k);
|
||||
@@ -291,6 +297,9 @@ void qvm_split_k(
|
||||
int final_block_size = K - (split_k - 1) * split_D;
|
||||
|
||||
auto temp_shape = out.shape();
|
||||
if (temp_shape.size() == 1) {
|
||||
temp_shape.insert(temp_shape.begin(), 1);
|
||||
}
|
||||
temp_shape.insert(temp_shape.end() - 2, split_k);
|
||||
array intermediate(temp_shape, x.dtype(), nullptr, {});
|
||||
intermediate.set_data(allocator::malloc(intermediate.nbytes()));
|
||||
|
Reference in New Issue
Block a user