This commit is contained in:
Awni Hannun 2025-05-02 20:40:22 -07:00
parent 01a29b51c8
commit 9eb5fa764c
2 changed files with 3 additions and 3 deletions

View File

@@ -5,8 +5,8 @@ template <typename T, typename Op, int N = WorkPerThread<T>::n>
device const bool* a,
device const T* b,
device const T* c,
constant uint& size,
device T* d,
constant uint& size,
uint index [[thread_position_in_grid]]) {
index *= N;
for (int i = 0; i < N && (index + i) < size; ++i) {

View File

@@ -113,10 +113,10 @@ void ternary_op_gpu_inplace(
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
MTL::Size grid_dims;
if (large) {
compute_encoder.set_bytes<int64_t>(out.data_size(), 2);
compute_encoder.set_bytes<int64_t>(out.data_size(), 4);
grid_dims = get_2d_grid_dims(out.shape(), out.strides(), work_per_thread);
} else {
compute_encoder.set_bytes<int>(out.data_size(), 2);
compute_encoder.set_bytes<int>(out.data_size(), 4);
grid_dims = MTL::Size(nthreads, 1, 1);
}
compute_encoder.dispatch_threads(grid_dims, group_dims);