mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-29 22:01:17 +08:00
fix (#1566)
This commit is contained in:
parent
726dbd9267
commit
9a3842a2d9
@ -75,7 +75,7 @@ void CustomKernel::eval_gpu(
|
||||
MTL::Size group_dims = MTL::Size(tx, ty, tz);
|
||||
const auto [gx, gy, gz] = grid_;
|
||||
MTL::Size grid_dims = MTL::Size(gx, gy, gz);
|
||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
|
||||
d.add_temporaries(std::move(copies), s.index);
|
||||
}
|
||||
|
@ -738,7 +738,7 @@ void fft_op(
|
||||
auto group_dims = MTL::Size(1, threadgroup_batch_size, threads_per_fft);
|
||||
auto grid_dims =
|
||||
MTL::Size(batch_size, threadgroup_batch_size, threads_per_fft);
|
||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
d.add_temporaries(std::move(copies), s.index);
|
||||
|
@ -144,7 +144,7 @@ void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
MTL::Size group_dims = MTL::Size(1, threads_per, 1);
|
||||
MTL::Size grid_dims = MTL::Size(batch_size, threads_per, 1);
|
||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
};
|
||||
|
||||
if (m > 1) {
|
||||
|
@ -139,7 +139,7 @@ void sdpa_full_self_attention_metal(
|
||||
MTL::Size grid_dims = MTL::Size(1, tm, batch_size_out);
|
||||
MTL::Size group_dims = MTL::Size(32, wm, wn);
|
||||
|
||||
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
|
||||
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
void sdpa_vector(
|
||||
|
Loading…
Reference in New Issue
Block a user