mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Change the segments to be more general
This commit is contained in:
@@ -77,7 +77,10 @@ inline void segmented_mm(
|
||||
int32_t N = b_copy[ndim - 1];
|
||||
int32_t k_start = 0;
|
||||
for (int i = 0; i < num_segments; i++) {
|
||||
int32_t k_end = segments[elem_to_loc(i, segments_shape, segments_strides)];
|
||||
int32_t k_start =
|
||||
segments[elem_to_loc(2 * i, segments_shape, segments_strides)];
|
||||
int32_t k_end =
|
||||
segments[elem_to_loc(2 * i + 1, segments_shape, segments_strides)];
|
||||
a_copy[ndim - 1] = k_end - k_start;
|
||||
b_copy[ndim - 2] = k_end - k_start;
|
||||
matmul<T>(
|
||||
@@ -96,7 +99,6 @@ inline void segmented_mm(
|
||||
a_strides,
|
||||
b_copy,
|
||||
b_strides);
|
||||
k_start = k_end;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -537,7 +539,7 @@ void SegmentedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
a.strides(),
|
||||
b.shape(),
|
||||
b.strides(),
|
||||
segments.size(),
|
||||
segments.size() / 2,
|
||||
segments.shape(),
|
||||
segments.strides());
|
||||
break;
|
||||
@@ -555,7 +557,7 @@ void SegmentedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
a.strides(),
|
||||
b.shape(),
|
||||
b.strides(),
|
||||
segments.size(),
|
||||
segments.size() / 2,
|
||||
segments.shape(),
|
||||
segments.strides());
|
||||
break;
|
||||
@@ -573,7 +575,7 @@ void SegmentedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
a.strides(),
|
||||
b.shape(),
|
||||
b.strides(),
|
||||
segments.size(),
|
||||
segments.size() / 2,
|
||||
segments.shape(),
|
||||
segments.strides());
|
||||
break;
|
||||
@@ -591,7 +593,7 @@ void SegmentedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
a.strides(),
|
||||
b.shape(),
|
||||
b.strides(),
|
||||
segments.size(),
|
||||
segments.size() / 2,
|
||||
segments.shape(),
|
||||
segments.strides());
|
||||
break;
|
||||
|
||||
Reference in New Issue
Block a user