Use int64 stride everywhere (#1671)

* use int64 stride everywhere

* fix ext

* fix ext

* more shape + cleanup

* one more

* few more
This commit is contained in:
Awni Hannun
2024-12-09 11:09:02 -08:00
committed by GitHub
parent 35b412c099
commit 40c62c1321
102 changed files with 1262 additions and 1705 deletions

View File

@@ -1014,7 +1014,7 @@ std::string write_signature(
}
if (shape_infos[i].strides) {
kernel_source +=
(" const constant size_t* " + name + "_strides [[buffer(" +
(" const constant int64_t* " + name + "_strides [[buffer(" +
std::to_string(index) + ")]],\n");
index++;
}
@@ -1144,7 +1144,7 @@ MetalKernelFunction metal_kernel(
shape_infos = std::move(shape_infos),
attributes = std::move(attributes)](
const std::vector<array>& inputs,
const std::vector<std::vector<int>>& output_shapes,
const std::vector<Shape>& output_shapes,
const std::vector<Dtype>& output_dtypes,
std::tuple<int, int, int> grid,
std::tuple<int, int, int> threadgroup,