mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-30 05:31:15 +08:00
fix
This commit is contained in:
parent
01a29b51c8
commit
9eb5fa764c
@ -5,8 +5,8 @@ template <typename T, typename Op, int N = WorkPerThread<T>::n>
|
|||||||
device const bool* a,
|
device const bool* a,
|
||||||
device const T* b,
|
device const T* b,
|
||||||
device const T* c,
|
device const T* c,
|
||||||
constant uint& size,
|
|
||||||
device T* d,
|
device T* d,
|
||||||
|
constant uint& size,
|
||||||
uint index [[thread_position_in_grid]]) {
|
uint index [[thread_position_in_grid]]) {
|
||||||
index *= N;
|
index *= N;
|
||||||
for (int i = 0; i < N && (index + i) < size; ++i) {
|
for (int i = 0; i < N && (index + i) < size; ++i) {
|
||||||
|
@ -113,10 +113,10 @@ void ternary_op_gpu_inplace(
|
|||||||
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
|
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
|
||||||
MTL::Size grid_dims;
|
MTL::Size grid_dims;
|
||||||
if (large) {
|
if (large) {
|
||||||
compute_encoder.set_bytes<int64_t>(out.data_size(), 2);
|
compute_encoder.set_bytes<int64_t>(out.data_size(), 4);
|
||||||
grid_dims = get_2d_grid_dims(out.shape(), out.strides(), work_per_thread);
|
grid_dims = get_2d_grid_dims(out.shape(), out.strides(), work_per_thread);
|
||||||
} else {
|
} else {
|
||||||
compute_encoder.set_bytes<int>(out.data_size(), 2);
|
compute_encoder.set_bytes<int>(out.data_size(), 4);
|
||||||
grid_dims = MTL::Size(nthreads, 1, 1);
|
grid_dims = MTL::Size(nthreads, 1, 1);
|
||||||
}
|
}
|
||||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||||
|
Loading…
Reference in New Issue
Block a user