mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
fix
This commit is contained in:
parent
55d6edcaa3
commit
01a29b51c8
@ -145,10 +145,10 @@ void binary_op_gpu_inplace(
|
|||||||
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
|
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
|
||||||
MTL::Size grid_dims;
|
MTL::Size grid_dims;
|
||||||
if (large) {
|
if (large) {
|
||||||
compute_encoder.set_bytes<int64_t>(a.size(), arg_idx++);
|
compute_encoder.set_bytes<int64_t>(out.data_size(), arg_idx++);
|
||||||
grid_dims = get_2d_grid_dims(out.shape(), out.strides(), work_per_thread);
|
grid_dims = get_2d_grid_dims(out.shape(), out.strides(), work_per_thread);
|
||||||
} else {
|
} else {
|
||||||
compute_encoder.set_bytes<int>(a.size(), arg_idx++);
|
compute_encoder.set_bytes<int>(out.data_size(), arg_idx++);
|
||||||
grid_dims = MTL::Size(nthreads, 1, 1);
|
grid_dims = MTL::Size(nthreads, 1, 1);
|
||||||
}
|
}
|
||||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||||
|
Loading…
Reference in New Issue
Block a user