Use int64 stride everywhere (#1671)

* use int64 stride everywhere

* fix ext

* fix ext

* more shape + cleanup

* one more

* few more
This commit is contained in:
Awni Hannun
2024-12-09 11:09:02 -08:00
committed by GitHub
parent 35b412c099
commit 40c62c1321
102 changed files with 1262 additions and 1705 deletions

View File

@@ -1219,12 +1219,12 @@ METAL_FUNC void adjust_matrix_offsets(
int output_stride,
const constant int& x_batch_ndims,
const constant int* x_shape,
const constant size_t* x_strides,
const constant int64_t* x_strides,
const constant int& w_batch_ndims,
const constant int* w_shape,
const constant size_t* w_strides,
const constant size_t* s_strides,
const constant size_t* b_strides,
const constant int64_t* w_strides,
const constant int64_t* s_strides,
const constant int64_t* b_strides,
uint3 tid [[threadgroup_position_in_grid]]) {
// Set the input/output matrices
uint32_t x_idx = tid.z;
@@ -1260,16 +1260,16 @@ METAL_FUNC void adjust_matrix_offsets(
int output_stride,
const constant int& batch_ndims,
const constant int* batch_shape,
const constant size_t* lhs_strides,
const constant size_t* rhs_strides,
const constant int64_t* lhs_strides,
const constant int64_t* rhs_strides,
const constant int& x_batch_ndims,
const constant int* x_shape,
const constant size_t* x_strides,
const constant int64_t* x_strides,
const constant int& w_batch_ndims,
const constant int* w_shape,
const constant size_t* w_strides,
const constant size_t* s_strides,
const constant size_t* b_strides,
const constant int64_t* w_strides,
const constant int64_t* s_strides,
const constant int64_t* b_strides,
uint3 tid [[threadgroup_position_in_grid]]) {
// Set the input/output matrices
uint32_t x_idx;
@@ -1313,12 +1313,12 @@ template <typename T, int group_size, int bits, int D, bool batched>
const constant int& out_vec_size [[buffer(6)]],
const constant int& x_batch_ndims [[buffer(7)]],
const constant int* x_shape [[buffer(8)]],
const constant size_t* x_strides [[buffer(9)]],
const constant int64_t* x_strides [[buffer(9)]],
const constant int& w_batch_ndims [[buffer(10)]],
const constant int* w_shape [[buffer(11)]],
const constant size_t* w_strides [[buffer(12)]],
const constant size_t* s_strides [[buffer(13)]],
const constant size_t* b_strides [[buffer(14)]],
const constant int64_t* w_strides [[buffer(12)]],
const constant int64_t* s_strides [[buffer(13)]],
const constant int64_t* b_strides [[buffer(14)]],
uint3 tid [[threadgroup_position_in_grid]],
uint quad_gid [[quadgroup_index_in_threadgroup]],
uint quad_lid [[thread_index_in_quadgroup]]) {
@@ -1364,12 +1364,12 @@ template <typename T, int group_size, int bits, bool batched>
const constant int& out_vec_size [[buffer(6)]],
const constant int& x_batch_ndims [[buffer(7)]],
const constant int* x_shape [[buffer(8)]],
const constant size_t* x_strides [[buffer(9)]],
const constant int64_t* x_strides [[buffer(9)]],
const constant int& w_batch_ndims [[buffer(10)]],
const constant int* w_shape [[buffer(11)]],
const constant size_t* w_strides [[buffer(12)]],
const constant size_t* s_strides [[buffer(13)]],
const constant size_t* b_strides [[buffer(14)]],
const constant int64_t* w_strides [[buffer(12)]],
const constant int64_t* s_strides [[buffer(13)]],
const constant int64_t* b_strides [[buffer(14)]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
@@ -1415,12 +1415,12 @@ template <typename T, const int group_size, const int bits, bool batched>
const constant int& out_vec_size [[buffer(6)]],
const constant int& x_batch_ndims [[buffer(7)]],
const constant int* x_shape [[buffer(8)]],
const constant size_t* x_strides [[buffer(9)]],
const constant int64_t* x_strides [[buffer(9)]],
const constant int& w_batch_ndims [[buffer(10)]],
const constant int* w_shape [[buffer(11)]],
const constant size_t* w_strides [[buffer(12)]],
const constant size_t* s_strides [[buffer(13)]],
const constant size_t* b_strides [[buffer(14)]],
const constant int64_t* w_strides [[buffer(12)]],
const constant int64_t* s_strides [[buffer(13)]],
const constant int64_t* b_strides [[buffer(14)]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
@@ -1466,12 +1466,12 @@ template <typename T, const int group_size, const int bits, bool batched>
const constant int& out_vec_size [[buffer(6)]],
const constant int& x_batch_ndims [[buffer(7)]],
const constant int* x_shape [[buffer(8)]],
const constant size_t* x_strides [[buffer(9)]],
const constant int64_t* x_strides [[buffer(9)]],
const constant int& w_batch_ndims [[buffer(10)]],
const constant int* w_shape [[buffer(11)]],
const constant size_t* w_strides [[buffer(12)]],
const constant size_t* s_strides [[buffer(13)]],
const constant size_t* b_strides [[buffer(14)]],
const constant int64_t* w_strides [[buffer(12)]],
const constant int64_t* s_strides [[buffer(13)]],
const constant int64_t* b_strides [[buffer(14)]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
@@ -1517,12 +1517,12 @@ template <typename T, const int group_size, const int bits, int split_k = 32>
const constant int& out_vec_size [[buffer(6)]],
const constant int& x_batch_ndims [[buffer(7)]],
const constant int* x_shape [[buffer(8)]],
const constant size_t* x_strides [[buffer(9)]],
const constant int64_t* x_strides [[buffer(9)]],
const constant int& w_batch_ndims [[buffer(10)]],
const constant int* w_shape [[buffer(11)]],
const constant size_t* w_strides [[buffer(12)]],
const constant size_t* s_strides [[buffer(13)]],
const constant size_t* b_strides [[buffer(14)]],
const constant int64_t* w_strides [[buffer(12)]],
const constant int64_t* s_strides [[buffer(13)]],
const constant int64_t* b_strides [[buffer(14)]],
const constant int& final_block_size [[buffer(15)]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
@@ -1581,12 +1581,12 @@ template <
const constant int& M [[buffer(7)]],
const constant int& x_batch_ndims [[buffer(8)]],
const constant int* x_shape [[buffer(9)]],
const constant size_t* x_strides [[buffer(10)]],
const constant int64_t* x_strides [[buffer(10)]],
const constant int& w_batch_ndims [[buffer(11)]],
const constant int* w_shape [[buffer(12)]],
const constant size_t* w_strides [[buffer(13)]],
const constant size_t* s_strides [[buffer(14)]],
const constant size_t* b_strides [[buffer(15)]],
const constant int64_t* w_strides [[buffer(13)]],
const constant int64_t* s_strides [[buffer(14)]],
const constant int64_t* b_strides [[buffer(15)]],
uint3 tid [[threadgroup_position_in_grid]],
uint lid [[thread_index_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
@@ -1639,12 +1639,12 @@ template <
const constant int& M [[buffer(7)]],
const constant int& x_batch_ndims [[buffer(8)]],
const constant int* x_shape [[buffer(9)]],
const constant size_t* x_strides [[buffer(10)]],
const constant int64_t* x_strides [[buffer(10)]],
const constant int& w_batch_ndims [[buffer(11)]],
const constant int* w_shape [[buffer(12)]],
const constant size_t* w_strides [[buffer(13)]],
const constant size_t* s_strides [[buffer(14)]],
const constant size_t* b_strides [[buffer(15)]],
const constant int64_t* w_strides [[buffer(13)]],
const constant int64_t* s_strides [[buffer(14)]],
const constant int64_t* b_strides [[buffer(15)]],
uint3 tid [[threadgroup_position_in_grid]],
uint lid [[thread_index_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
@@ -1691,18 +1691,18 @@ template <typename T, int group_size, int bits>
const constant int& out_vec_size [[buffer(6)]],
const constant int& x_batch_ndims [[buffer(7)]],
const constant int* x_shape [[buffer(8)]],
const constant size_t* x_strides [[buffer(9)]],
const constant int64_t* x_strides [[buffer(9)]],
const constant int& w_batch_ndims [[buffer(10)]],
const constant int* w_shape [[buffer(11)]],
const constant size_t* w_strides [[buffer(12)]],
const constant size_t* s_strides [[buffer(13)]],
const constant size_t* b_strides [[buffer(14)]],
const constant int64_t* w_strides [[buffer(12)]],
const constant int64_t* s_strides [[buffer(13)]],
const constant int64_t* b_strides [[buffer(14)]],
const constant int& batch_ndims [[buffer(15)]],
const constant int* batch_shape [[buffer(16)]],
const device uint32_t* lhs_indices [[buffer(17)]],
const device uint32_t* rhs_indices [[buffer(18)]],
const constant size_t* lhs_strides [[buffer(19)]],
const constant size_t* rhs_strides [[buffer(20)]],
const constant int64_t* lhs_strides [[buffer(19)]],
const constant int64_t* rhs_strides [[buffer(20)]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
@@ -1752,18 +1752,18 @@ template <typename T, int group_size, int bits>
const constant int& out_vec_size [[buffer(6)]],
const constant int& x_batch_ndims [[buffer(7)]],
const constant int* x_shape [[buffer(8)]],
const constant size_t* x_strides [[buffer(9)]],
const constant int64_t* x_strides [[buffer(9)]],
const constant int& w_batch_ndims [[buffer(10)]],
const constant int* w_shape [[buffer(11)]],
const constant size_t* w_strides [[buffer(12)]],
const constant size_t* s_strides [[buffer(13)]],
const constant size_t* b_strides [[buffer(14)]],
const constant int64_t* w_strides [[buffer(12)]],
const constant int64_t* s_strides [[buffer(13)]],
const constant int64_t* b_strides [[buffer(14)]],
const constant int& batch_ndims [[buffer(15)]],
const constant int* batch_shape [[buffer(16)]],
const device uint32_t* lhs_indices [[buffer(17)]],
const device uint32_t* rhs_indices [[buffer(18)]],
const constant size_t* lhs_strides [[buffer(19)]],
const constant size_t* rhs_strides [[buffer(20)]],
const constant int64_t* lhs_strides [[buffer(19)]],
const constant int64_t* rhs_strides [[buffer(20)]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
@@ -1813,18 +1813,18 @@ template <typename T, int group_size, int bits>
const constant int& out_vec_size [[buffer(6)]],
const constant int& x_batch_ndims [[buffer(7)]],
const constant int* x_shape [[buffer(8)]],
const constant size_t* x_strides [[buffer(9)]],
const constant int64_t* x_strides [[buffer(9)]],
const constant int& w_batch_ndims [[buffer(10)]],
const constant int* w_shape [[buffer(11)]],
const constant size_t* w_strides [[buffer(12)]],
const constant size_t* s_strides [[buffer(13)]],
const constant size_t* b_strides [[buffer(14)]],
const constant int64_t* w_strides [[buffer(12)]],
const constant int64_t* s_strides [[buffer(13)]],
const constant int64_t* b_strides [[buffer(14)]],
const constant int& batch_ndims [[buffer(15)]],
const constant int* batch_shape [[buffer(16)]],
const device uint32_t* lhs_indices [[buffer(17)]],
const device uint32_t* rhs_indices [[buffer(18)]],
const constant size_t* lhs_strides [[buffer(19)]],
const constant size_t* rhs_strides [[buffer(20)]],
const constant int64_t* lhs_strides [[buffer(19)]],
const constant int64_t* rhs_strides [[buffer(20)]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
@@ -1882,18 +1882,18 @@ template <
const constant int& M [[buffer(7)]],
const constant int& x_batch_ndims [[buffer(8)]],
const constant int* x_shape [[buffer(9)]],
const constant size_t* x_strides [[buffer(10)]],
const constant int64_t* x_strides [[buffer(10)]],
const constant int& w_batch_ndims [[buffer(11)]],
const constant int* w_shape [[buffer(12)]],
const constant size_t* w_strides [[buffer(13)]],
const constant size_t* s_strides [[buffer(14)]],
const constant size_t* b_strides [[buffer(15)]],
const constant int64_t* w_strides [[buffer(13)]],
const constant int64_t* s_strides [[buffer(14)]],
const constant int64_t* b_strides [[buffer(15)]],
const constant int& batch_ndims [[buffer(16)]],
const constant int* batch_shape [[buffer(17)]],
const device uint32_t* lhs_indices [[buffer(18)]],
const device uint32_t* rhs_indices [[buffer(19)]],
const constant size_t* lhs_strides [[buffer(20)]],
const constant size_t* rhs_strides [[buffer(21)]],
const constant int64_t* lhs_strides [[buffer(20)]],
const constant int64_t* rhs_strides [[buffer(21)]],
uint3 tid [[threadgroup_position_in_grid]],
uint lid [[thread_index_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
@@ -1949,18 +1949,18 @@ template <
const constant int& M [[buffer(7)]],
const constant int& x_batch_ndims [[buffer(8)]],
const constant int* x_shape [[buffer(9)]],
const constant size_t* x_strides [[buffer(10)]],
const constant int64_t* x_strides [[buffer(10)]],
const constant int& w_batch_ndims [[buffer(11)]],
const constant int* w_shape [[buffer(12)]],
const constant size_t* w_strides [[buffer(13)]],
const constant size_t* s_strides [[buffer(14)]],
const constant size_t* b_strides [[buffer(15)]],
const constant int64_t* w_strides [[buffer(13)]],
const constant int64_t* s_strides [[buffer(14)]],
const constant int64_t* b_strides [[buffer(15)]],
const constant int& batch_ndims [[buffer(16)]],
const constant int* batch_shape [[buffer(17)]],
const device uint32_t* lhs_indices [[buffer(18)]],
const device uint32_t* rhs_indices [[buffer(19)]],
const constant size_t* lhs_strides [[buffer(20)]],
const constant size_t* rhs_strides [[buffer(21)]],
const constant int64_t* lhs_strides [[buffer(20)]],
const constant int64_t* rhs_strides [[buffer(21)]],
uint3 tid [[threadgroup_position_in_grid]],
uint lid [[thread_index_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],