mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
fix dispatch threads for a few kernels (#1594)
This commit is contained in:
@@ -52,13 +52,14 @@ void explicit_gemm_conv_ND_gpu(
|
||||
compute_encoder.set_bytes(conv_params, 2);
|
||||
|
||||
// Launch unfolding kernel
|
||||
int tgp_x = std::min(conv_params.C, 64);
|
||||
size_t tgp_x = std::min(conv_params.C, 64);
|
||||
tgp_x = 32 * ((tgp_x + 32 - 1) / 32);
|
||||
int tgp_y = 256 / tgp_x;
|
||||
size_t tgp_y = 256 / tgp_x;
|
||||
|
||||
MTL::Size group_dims = MTL::Size(tgp_x, tgp_y, 1);
|
||||
MTL::Size grid_dims = MTL::Size(
|
||||
conv_params.C, unfolded_shape[1] / conv_params.C, unfolded_shape[0]);
|
||||
MTL::Size group_dims = MTL::Size(
|
||||
std::min(tgp_x, grid_dims.width), std::min(tgp_y, grid_dims.height), 1);
|
||||
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
|
||||
@@ -130,13 +131,14 @@ void explicit_gemm_conv_group_ND_gpu(
|
||||
compute_encoder.set_bytes(conv_params, 2);
|
||||
|
||||
// Launch unfolding kernel
|
||||
int tgp_x = std::min(conv_params.C, 64);
|
||||
size_t tgp_x = std::min(conv_params.C, 64);
|
||||
tgp_x = 32 * ((tgp_x + 32 - 1) / 32);
|
||||
int tgp_y = 256 / tgp_x;
|
||||
size_t tgp_y = 256 / tgp_x;
|
||||
|
||||
MTL::Size group_dims = MTL::Size(tgp_x, tgp_y, 1);
|
||||
MTL::Size grid_dims = MTL::Size(
|
||||
conv_params.C, unfolded_shape[1] / conv_params.C, unfolded_shape[0]);
|
||||
MTL::Size group_dims = MTL::Size(
|
||||
std::min(tgp_x, grid_dims.width), std::min(tgp_y, grid_dims.height), 1);
|
||||
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user