fix custom kernel test (#2510)

This commit is contained in:
Awni Hannun
2025-08-18 06:45:59 -07:00
committed by GitHub
parent 1df9887998
commit c5fcd5b61b
3 changed files with 4 additions and 4 deletions

View File

@@ -128,6 +128,7 @@ relying on a copy from ``ensure_row_contiguous``:
input_names=["inp"],
output_names=["out"],
source=source
ensure_row_contiguous=False,
)
def exp_elementwise(a: mx.array):
@@ -138,7 +139,6 @@ relying on a copy from ``ensure_row_contiguous``:
threadgroup=(256, 1, 1),
output_shapes=[a.shape],
output_dtypes=[a.dtype],
ensure_row_contiguous=False,
)
return outputs[0]