mirror of https://github.com/ml-explore/mlx.git
Use int64 stride everywhere (#1671)
* use int64 stride everywhere
* fix ext
* fix ext
* more shape + cleanup
* one more
* few more
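The motivation for signed 64-bit strides: with unsigned size_t, stride arithmetic that should produce a negative value (reversed views, backward steps, offset differences) silently wraps, and sign tests can never fire. A minimal sketch of the hazard follows; the values are illustrative and not taken from MLX.

#include <cstddef>
#include <cstdint>
#include <cstdio>

int main() {
  size_t ustride = static_cast<size_t>(-4); // wraps to 2^64 - 4
  int64_t sstride = -4;                     // stays -4

  // Sign checks: an unsigned stride can never test negative.
  std::printf("unsigned negative? %d\n", ustride < 0); // always 0
  std::printf("signed negative?   %d\n", sstride < 0); // 1

  // Dividing a "negative" stride goes wrong when it is unsigned.
  std::printf("unsigned half: %zu\n", ustride / 2);               // huge value
  std::printf("signed half:   %lld\n", (long long)(sstride / 2)); // -2
  return 0;
}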
@@ -50,17 +50,17 @@ void sdpa_full_self_attention_metal(
   std::ostringstream kname;
   // clang-format off
   kname << "steel_attention_"
         << type_to_name(q)
         << "_bq" << bq
         << "_bk" << bk
         << "_bd" << bd
         << "_wm" << wm << "_wn" << wn; // clang-format on

   std::string base_name = kname.str();

   // clang-format off
   kname << "_align_Q_" << (align_Q ? 't' : 'n')
         << "_align_K_" << (align_K ? 't' : 'n'); // clang-format on

   std::string hash_name = kname.str();
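For orientation (this part is unchanged by the commit): the stream builds two names, a base_name identifying the kernel template and a longer hash_name that additionally encodes the runtime specializations. A rough sketch of that pattern with a hypothetical cache, not MLX's actual kernel API:

#include <sstream>
#include <string>
#include <unordered_map>

// Hypothetical pipeline cache keyed by the fully specialized name.
static std::unordered_map<std::string, int> pipeline_cache;

int get_attention_pipeline(bool align_Q, bool align_K) {
  std::ostringstream kname;
  kname << "steel_attention_float16_bq32_bk16_bd64_wm4_wn1"; // illustrative
  std::string base_name = kname.str(); // selects the kernel template

  // Boolean specializations extend only the cache key.
  kname << "_align_Q_" << (align_Q ? 't' : 'n')
        << "_align_K_" << (align_K ? 't' : 'n');
  std::string hash_name = kname.str();

  auto [it, inserted] = pipeline_cache.try_emplace(hash_name, 0);
  if (inserted) {
    // Compile `base_name` with align_Q/align_K as function constants here.
  }
  return it->second;
}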
@@ -92,10 +92,10 @@ void sdpa_full_self_attention_metal(
       /* int NQ_aligned = */ NQ_aligned,
       /* int NK_aligned = */ NK_aligned,
-      /* size_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)},
-      /* size_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)},
-      /* size_t V_strides[3] = */ {v.strides(0), v.strides(1), v.strides(2)},
-      /* size_t O_strides[3] = */ {o.strides(0), o.strides(1), o.strides(2)}};
+      /* int64_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)},
+      /* int64_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)},
+      /* int64_t V_strides[3] = */ {v.strides(0), v.strides(1), v.strides(2)},
+      /* int64_t O_strides[3] = */ {o.strides(0), o.strides(1), o.strides(2)}};

   compute_encoder.set_input_array(q, 0);
   compute_encoder.set_input_array(k, 1);
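The comment-annotated initializers above fill the parameter block passed to the Metal kernel; this commit widens its stride members to int64_t. A sketch of the fields visible in this hunk; the struct name and full layout are assumptions, not MLX's header:

#include <cstdint>

// Assumed shape of the attention parameter block; only the fields that
// appear in the diff are shown, and the name is hypothetical.
struct AttnParamsSketch {
  int NQ_aligned;
  int NK_aligned;
  int64_t Q_strides[3]; // per the diff's /* int64_t Q_strides[3] = */ comment
  int64_t K_strides[3];
  int64_t V_strides[3];
  int64_t O_strides[3];
};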
@@ -175,13 +175,13 @@ void sdpa_vector_2pass(
   int N = k.shape(2);
   int blocks = 32;
   int B = q.shape(0) * q.shape(1);
-  size_t k_stride = k.strides()[1];
-  size_t v_stride = v.strides()[1];
+  auto k_stride = k.strides()[1];
+  auto v_stride = v.strides()[1];
   MTL::Size group_dims(8 * 32, 1, 1);
   MTL::Size grid_dims(1, B, blocks);

   // Allocate the intermediates
-  std::vector<int> intermediate_shape;
+  Shape intermediate_shape;
   intermediate_shape.reserve(out.ndim() + 1);
   intermediate_shape.insert(
       intermediate_shape.end(), out.shape().begin(), out.shape().end() - 1);
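Two details in this hunk: auto lets k_stride and v_stride track whatever element type strides() now returns, and the intermediate shape switches from a raw std::vector<int> to the library's Shape alias. A minimal illustration of the same pattern; the aliases below are assumptions modeled on (not copied from) MLX:

#include <cstdint>
#include <vector>

// Assumed aliases: 32-bit extents for shapes, 64-bit strides.
using Shape = std::vector<int32_t>;
using Strides = std::vector<int64_t>;

int main() {
  Strides strides = {4096, 64, 1};
  auto k_stride = strides[1]; // deduced as int64_t, follows the alias

  Shape out_shape = {2, 8, 128, 64};
  Shape intermediate_shape;
  intermediate_shape.reserve(out_shape.size() + 1);
  // Copy all but the last axis, as the diff does with out.shape().
  intermediate_shape.insert(
      intermediate_shape.end(), out_shape.begin(), out_shape.end() - 1);
  return 0;
}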
@@ -324,10 +324,10 @@ void ScaledDotProductAttention::eval_gpu(
   const auto& k = copy_unless(is_matrix_contiguous, k_pre);
   const auto& v = copy_unless(is_matrix_contiguous, v_pre);

-  size_t str_oD = 1;
-  size_t str_oH = o.shape(3);
-  size_t str_oL = o.shape(1) * str_oH;
-  size_t str_oB = o.shape(2) * str_oL;
+  int64_t str_oD = 1;
+  int64_t str_oH = o.shape(3);
+  int64_t str_oL = o.shape(1) * str_oH;
+  int64_t str_oB = o.shape(2) * str_oL;
   size_t data_size = o.shape(0) * str_oB;

   array::Flags flags{
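To make the final stride recurrence concrete: each stride is built from the previous one, and the fact that str_oH equals o.shape(3) suggests an output viewed as (B, H, L, D) but laid out with H and D adjacent, i.e. (B, L, H, D) in memory. A worked example with assumed dimensions, not values from the diff:

#include <cstdint>
#include <cstdio>

int main() {
  // Assumed dimensions for illustration only.
  int64_t B = 2, H = 8, L = 128, D = 64;

  // Same recurrence as the diff, with o.shape(1) = H, o.shape(2) = L,
  // o.shape(3) = D.
  int64_t str_oD = 1;
  int64_t str_oH = D;          // o.shape(3)
  int64_t str_oL = H * str_oH; // o.shape(1) * str_oH
  int64_t str_oB = L * str_oL; // o.shape(2) * str_oL
  int64_t data_size = B * str_oB;

  std::printf("strides: B=%lld L=%lld H=%lld D=%lld, elements=%lld\n",
              (long long)str_oB, (long long)str_oL, (long long)str_oH,
              (long long)str_oD, (long long)data_size);
  return 0;
}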