diff --git a/mlx/backend/metal/device.h b/mlx/backend/metal/device.h index fa87e4283..047b2735a 100644 --- a/mlx/backend/metal/device.h +++ b/mlx/backend/metal/device.h @@ -63,7 +63,7 @@ struct CommandEncoder { return enc; } - void set_input_array(const array& a, int idx, int offset = 0) { + void set_input_array(const array& a, int idx, int64_t offset = 0) { auto r_buf = static_cast(const_cast(a.buffer().ptr())); if (auto it = outputs.find(r_buf); it != outputs.end()) { @@ -80,7 +80,7 @@ struct CommandEncoder { enc->setBuffer(a_buf, base_offset, idx); } - void set_output_array(array& a, int idx, int offset = 0) { + void set_output_array(array& a, int idx, int64_t offset = 0) { // Add barriers before adding the output to the output set set_input_array(a, idx, offset); auto buf = static_cast(a.buffer().ptr());