mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Address floating point exception on linux blas
This commit is contained in:
@@ -80,6 +80,9 @@ inline void segmented_mm(
|
|||||||
segments[elem_to_loc(2 * i, segments_shape, segments_strides)];
|
segments[elem_to_loc(2 * i, segments_shape, segments_strides)];
|
||||||
uint32_t k_end =
|
uint32_t k_end =
|
||||||
segments[elem_to_loc(2 * i + 1, segments_shape, segments_strides)];
|
segments[elem_to_loc(2 * i + 1, segments_shape, segments_strides)];
|
||||||
|
if (k_end <= k_start) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
a_copy[ndim - 1] = k_end - k_start;
|
a_copy[ndim - 1] = k_end - k_start;
|
||||||
b_copy[ndim - 2] = k_end - k_start;
|
b_copy[ndim - 2] = k_end - k_start;
|
||||||
matmul<T>(
|
matmul<T>(
|
||||||
|
|||||||
Reference in New Issue
Block a user