Address floating-point exception on Linux BLAS

This commit is contained in:
Angelos Katharopoulos
2025-07-04 13:16:54 -07:00
parent 22f9b8a6fc
commit bd0622c4d9

View File

@@ -80,6 +80,9 @@ inline void segmented_mm(
segments[elem_to_loc(2 * i, segments_shape, segments_strides)]; segments[elem_to_loc(2 * i, segments_shape, segments_strides)];
uint32_t k_end = uint32_t k_end =
segments[elem_to_loc(2 * i + 1, segments_shape, segments_strides)]; segments[elem_to_loc(2 * i + 1, segments_shape, segments_strides)];
if (k_end <= k_start) {
continue;
}
a_copy[ndim - 1] = k_end - k_start; a_copy[ndim - 1] = k_end - k_start;
b_copy[ndim - 2] = k_end - k_start; b_copy[ndim - 2] = k_end - k_start;
matmul<T>( matmul<T>(