Fix offset bug for device buffers (#1151)

* fix bug with large offsets for buffers

* add a test

* remove test as its too big for small machine
This commit is contained in:
Awni Hannun 2024-05-22 15:50:05 -07:00 committed by GitHub
parent 226748b3e7
commit e110ca11e2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -63,7 +63,7 @@ struct CommandEncoder {
return enc; return enc;
} }
void set_input_array(const array& a, int idx, int offset = 0) { void set_input_array(const array& a, int idx, int64_t offset = 0) {
auto r_buf = auto r_buf =
static_cast<MTL::Resource*>(const_cast<void*>(a.buffer().ptr())); static_cast<MTL::Resource*>(const_cast<void*>(a.buffer().ptr()));
if (auto it = outputs.find(r_buf); it != outputs.end()) { if (auto it = outputs.find(r_buf); it != outputs.end()) {
@ -80,7 +80,7 @@ struct CommandEncoder {
enc->setBuffer(a_buf, base_offset, idx); enc->setBuffer(a_buf, base_offset, idx);
} }
void set_output_array(array& a, int idx, int offset = 0) { void set_output_array(array& a, int idx, int64_t offset = 0) {
// Add barriers before adding the output to the output set // Add barriers before adding the output to the output set
set_input_array(a, idx, offset); set_input_array(a, idx, offset);
auto buf = static_cast<MTL::Resource*>(a.buffer().ptr()); auto buf = static_cast<MTL::Resource*>(a.buffer().ptr());