Change the segments to be more general: each segment is now an explicit (k_start, k_end) pair

This commit is contained in:
Angelos Katharopoulos
2025-07-02 14:23:27 -07:00
parent 6020ad6363
commit 3104c3eb14
3 changed files with 147 additions and 7 deletions

View File

@@ -77,7 +77,10 @@ inline void segmented_mm(
int32_t N = b_copy[ndim - 1];
int32_t k_start = 0;
for (int i = 0; i < num_segments; i++) {
int32_t k_end = segments[elem_to_loc(i, segments_shape, segments_strides)];
int32_t k_start =
segments[elem_to_loc(2 * i, segments_shape, segments_strides)];
int32_t k_end =
segments[elem_to_loc(2 * i + 1, segments_shape, segments_strides)];
a_copy[ndim - 1] = k_end - k_start;
b_copy[ndim - 2] = k_end - k_start;
matmul<T>(
@@ -96,7 +99,6 @@ inline void segmented_mm(
a_strides,
b_copy,
b_strides);
k_start = k_end;
}
}
@@ -537,7 +539,7 @@ void SegmentedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
a.strides(),
b.shape(),
b.strides(),
segments.size(),
segments.size() / 2,
segments.shape(),
segments.strides());
break;
@@ -555,7 +557,7 @@ void SegmentedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
a.strides(),
b.shape(),
b.strides(),
segments.size(),
segments.size() / 2,
segments.shape(),
segments.strides());
break;
@@ -573,7 +575,7 @@ void SegmentedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
a.strides(),
b.shape(),
b.strides(),
segments.size(),
segments.size() / 2,
segments.shape(),
segments.strides());
break;
@@ -591,7 +593,7 @@ void SegmentedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
a.strides(),
b.shape(),
b.strides(),
segments.size(),
segments.size() / 2,
segments.shape(),
segments.strides());
break;