Compile stride bug (#812)

* fix compile stride bug

* revert sdpa fix

* fix cpu

* fix bug with simplifying outputs
This commit is contained in:
Awni Hannun
2024-03-11 06:31:31 -07:00
committed by GitHub
parent a4d290adb9
commit 7c441600fe
9 changed files with 58 additions and 12 deletions

View File

@@ -97,6 +97,8 @@ void sdpa_metal(
set_array_buffer(compute_encoder, p_lse, 6);
set_array_buffer(compute_encoder, p_rowmaxes, 7);
constexpr const uint tgroupMemorySize = 32768;
compute_encoder->setThreadgroupMemoryLength(tgroupMemorySize, 0);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
{