Use int64 stride everywhere (#1671)

* use int64 stride everywhere

* fix ext

* fix ext

* more shape + cleanup

* one more

* few more

Author: Awni Hannun
Date: 2024-12-09 11:09:02 -08:00
Committed by: GitHub
Parent: 35b412c099
Commit: 40c62c1321
Changes: 102 changed files with 1262 additions and 1705 deletions
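Editorial context, since the commit message itself does not spell out the motivation: size_t strides are unsigned, so stride arithmetic that should produce a negative value (for example, offsets into reversed or sliced views) silently wraps around, while a signed int64_t keeps the sign and makes negative strides representable at all. A minimal standalone sketch of the hazard (illustration only, not code from this commit):

#include <cstdint>
#include <iostream>

int main() {
  // Unsigned stride arithmetic wraps instead of going negative:
  size_t u = 4;
  std::cout << u - 8 << "\n"; // prints 18446744073709551612

  // Signed 64-bit strides keep the sign, so reversed views stay expressible:
  int64_t s = 4;
  std::cout << s - 8 << "\n"; // prints -4
  return 0;
}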

mlx/backend/metal/scaled_dot_product_attention.cpp

@@ -50,17 +50,17 @@ void sdpa_full_self_attention_metal(
   std::ostringstream kname;
   // clang-format off
-  kname << "steel_attention_"
-        << type_to_name(q)
-        << "_bq" << bq
+  kname << "steel_attention_"
+        << type_to_name(q)
+        << "_bq" << bq
         << "_bk" << bk
-        << "_bd" << bd
+        << "_bd" << bd
         << "_wm" << wm << "_wn" << wn; // clang-format on
   std::string base_name = kname.str();
   // clang-format off
-  kname << "_align_Q_" << (align_Q ? 't' : 'n')
+  kname << "_align_Q_" << (align_Q ? 't' : 'n')
         << "_align_K_" << (align_K ? 't' : 'n'); // clang-format on
   std::string hash_name = kname.str();
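For orientation: the first str() call captures a base name that identifies the kernel itself, and the second, with the alignment flags appended, is presumably the more specific key used to look up or cache the specialized variant (the align_Q/align_K booleans read like per-dispatch specializations). A standalone sketch of this two-stage naming pattern, with made-up tile sizes:

#include <iostream>
#include <sstream>

int main() {
  std::ostringstream kname;
  int bq = 32, bk = 16, bd = 128, wm = 4, wn = 1; // hypothetical values
  bool align_Q = true, align_K = false;
  kname << "steel_attention_" << "float16"
        << "_bq" << bq << "_bk" << bk << "_bd" << bd
        << "_wm" << wm << "_wn" << wn;
  std::string base_name = kname.str();
  kname << "_align_Q_" << (align_Q ? 't' : 'n')
        << "_align_K_" << (align_K ? 't' : 'n');
  std::string hash_name = kname.str();
  std::cout << base_name << "\n"; // steel_attention_float16_bq32_bk16_bd128_wm4_wn1
  std::cout << hash_name << "\n"; // ..._wm4_wn1_align_Q_t_align_K_n
  return 0;
}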
@@ -92,10 +92,10 @@ void sdpa_full_self_attention_metal(
       /* int NQ_aligned = */ NQ_aligned,
       /* int NK_aligned = */ NK_aligned,
-      /* size_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)},
-      /* size_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)},
-      /* size_t V_strides[3] = */ {v.strides(0), v.strides(1), v.strides(2)},
-      /* size_t O_strides[3] = */ {o.strides(0), o.strides(1), o.strides(2)}};
+      /* int64_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)},
+      /* int64_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)},
+      /* int64_t V_strides[3] = */ {v.strides(0), v.strides(1), v.strides(2)},
+      /* int64_t O_strides[3] = */ {o.strides(0), o.strides(1), o.strides(2)}};

   compute_encoder.set_input_array(q, 0);
   compute_encoder.set_input_array(k, 1);
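The commented field names in the initializer above imply a host-side parameter struct that is mirrored in the Metal kernel. A hypothetical fragment consistent with this hunk (the real struct surely carries more fields than the ones visible here, and its actual name is an assumption):

#include <cstdint>

// Reconstructed for illustration; names and types follow the /* ... */
// annotations in the initializer above.
struct AttnParams {
  int NQ_aligned;       // number of fully-populated Q tiles
  int NK_aligned;       // number of fully-populated K tiles
  int64_t Q_strides[3]; // was size_t[3] before this commit
  int64_t K_strides[3];
  int64_t V_strides[3];
  int64_t O_strides[3];
};

The size_t -> int64_t lines in the hunk are exactly the initializers for those four stride arrays.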
@@ -175,13 +175,13 @@ void sdpa_vector_2pass(
   int N = k.shape(2);
   int blocks = 32;
   int B = q.shape(0) * q.shape(1);
-  size_t k_stride = k.strides()[1];
-  size_t v_stride = v.strides()[1];
+  auto k_stride = k.strides()[1];
+  auto v_stride = v.strides()[1];
   MTL::Size group_dims(8 * 32, 1, 1);
   MTL::Size grid_dims(1, B, blocks);

   // Allocate the intermediates
-  std::vector<int> intermediate_shape;
+  Shape intermediate_shape;
   intermediate_shape.reserve(out.ndim() + 1);
   intermediate_shape.insert(
       intermediate_shape.end(), out.shape().begin(), out.shape().end() - 1);
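Shape here is evidently a repo-wide alias replacing ad-hoc std::vector<int>. Given the commit title, a plausible pair of definitions (assumed for illustration, not copied from the repo) would be:

#include <cstdint>
#include <vector>

using Shape = std::vector<int32_t>;   // dimension sizes stay 32-bit
using Strides = std::vector<int64_t>; // strides are signed 64-bit everywhere

Under aliases like these, the switch from "size_t k_stride" to "auto k_stride" above also pays off: the deduced type simply tracks whatever strides() returns, so these declarations never need touching if the stride type evolves again.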
@@ -324,10 +324,10 @@ void ScaledDotProductAttention::eval_gpu(
   const auto& k = copy_unless(is_matrix_contiguous, k_pre);
   const auto& v = copy_unless(is_matrix_contiguous, v_pre);

-  size_t str_oD = 1;
-  size_t str_oH = o.shape(3);
-  size_t str_oL = o.shape(1) * str_oH;
-  size_t str_oB = o.shape(2) * str_oL;
+  int64_t str_oD = 1;
+  int64_t str_oH = o.shape(3);
+  int64_t str_oL = o.shape(1) * str_oH;
+  int64_t str_oB = o.shape(2) * str_oL;
   size_t data_size = o.shape(0) * str_oB;

   array::Flags flags{
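A note on the stride arithmetic: assuming o.shape() is logically {B, H, L, D} (dims 0 through 3), the code yields str_oH = D, str_oL = H*D, and str_oB = L*H*D, i.e. a [B, L, H, D] memory layout for a logically [B, H, L, D] array; that transposed layout is presumably why the flags are built by hand right after, rather than assuming row-major defaults. A worked check with hypothetical sizes:

#include <cstdint>
#include <iostream>

int main() {
  // Hypothetical logical shape {B, H, L, D} = {2, 8, 16, 64}.
  int64_t H = 8, L = 16, D = 64;
  int64_t str_oD = 1;
  int64_t str_oH = D;          // o.shape(3)
  int64_t str_oL = H * str_oH; // o.shape(1) * str_oH = 512
  int64_t str_oB = L * str_oL; // o.shape(2) * str_oL = 8192
  // Element (b, h, l, d) then lives at:
  auto offset = [&](int64_t b, int64_t h, int64_t l, int64_t d) {
    return b * str_oB + h * str_oH + l * str_oL + d * str_oD;
  };
  std::cout << offset(1, 2, 3, 4) << "\n"; // 8192 + 128 + 1536 + 4 = 9860
  return 0;
}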