mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-22 11:14:32 +08:00
Compile stride bug (#812)
* fix compile stride bug * revert sdpa fix * fix cpu * fix bug with simplifying outputs
This commit is contained in:
@@ -329,7 +329,9 @@ void Compiled::eval_gpu(
|
||||
if (in.flags().row_contiguous && in.nbytes() == outputs[o].nbytes() &&
|
||||
in.is_donatable() &&
|
||||
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
|
||||
outputs[o++].move_shared_buffer(in);
|
||||
outputs[o].move_shared_buffer(
|
||||
in, outputs[o].strides(), in.flags(), in.data_size());
|
||||
o++;
|
||||
}
|
||||
}
|
||||
for (; o < outputs.size(); ++o) {
|
||||
|
@@ -13,12 +13,10 @@ template<typename T, typename T2, typename T4, uint16_t TILE_SIZE_CONST, uint16_
|
||||
device float* O_partials [[buffer(5)]],
|
||||
device float* p_lse [[buffer(6)]],
|
||||
device float* p_maxes [[buffer(7)]],
|
||||
threadgroup T* threadgroup_block [[threadgroup(0)]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
uint3 tid [[threadgroup_position_in_grid]]) {
|
||||
|
||||
threadgroup T threadgroup_block[32768 / sizeof(T)];
|
||||
|
||||
constexpr const size_t DK = 128;
|
||||
constexpr const ulong SIMDGROUP_MATRIX_LOAD_FACTOR = 8;
|
||||
constexpr const size_t THREADS_PER_SIMDGROUP = 32;
|
||||
@@ -358,6 +356,7 @@ template [[host_name("fast_inference_sdpa_compute_partials_" #itype "_" #tile_si
|
||||
device float* O_partials [[buffer(5)]], \
|
||||
device float* p_lse [[buffer(6)]], \
|
||||
device float* p_maxes [[buffer(7)]], \
|
||||
threadgroup itype *threadgroup_block [[threadgroup(0)]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]]);
|
||||
|
@@ -97,6 +97,8 @@ void sdpa_metal(
|
||||
set_array_buffer(compute_encoder, p_lse, 6);
|
||||
set_array_buffer(compute_encoder, p_rowmaxes, 7);
|
||||
|
||||
constexpr const uint tgroupMemorySize = 32768;
|
||||
compute_encoder->setThreadgroupMemoryLength(tgroupMemorySize, 0);
|
||||
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
|
||||
|
||||
{
|
||||
|
Reference in New Issue
Block a user