Change the segments to be more general

Angelos Katharopoulos
2025-07-02 14:23:27 -07:00
parent 6020ad6363
commit 3104c3eb14
3 changed files with 147 additions and 7 deletions
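Previously, segments was a flat list of end offsets along K, with each segment implicitly starting where the previous one ended; with this change every segment is an explicit (start, end) pair, so segments carries a trailing axis of size 2 and segments no longer have to tile K back to back. As an illustration only (values invented, not part of the commit), the two layouts for three segments over K = 10 look like:

// Old layout: one end offset per segment; starts are implicit (0, 3, 7).
int32_t old_segments[3] = {3, 7, 10};
// New layout: explicit (start, end) pairs; shape is (3, 2).
int32_t new_segments[3][2] = {{0, 3}, {3, 7}, {7, 10}};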


@@ -77,7 +77,10 @@ inline void segmented_mm(
   int32_t N = b_copy[ndim - 1];
-  int32_t k_start = 0;
   for (int i = 0; i < num_segments; i++) {
-    int32_t k_end = segments[elem_to_loc(i, segments_shape, segments_strides)];
+    int32_t k_start =
+        segments[elem_to_loc(2 * i, segments_shape, segments_strides)];
+    int32_t k_end =
+        segments[elem_to_loc(2 * i + 1, segments_shape, segments_strides)];
     a_copy[ndim - 1] = k_end - k_start;
     b_copy[ndim - 2] = k_end - k_start;
     matmul<T>(
@@ -96,7 +99,6 @@ inline void segmented_mm(
         a_strides,
         b_copy,
         b_strides);
-    k_start = k_end;
   }
 }
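As an illustration (not part of the commit), here is a minimal sketch of the new per-segment indexing for the simple case of a row-contiguous segments buffer, where elem_to_loc reduces to the flat index; the real kernel above goes through elem_to_loc and reuses matmul<T> on strided views instead of the naive loops used here:

#include <cstdint>

// Sketch only: a is (M, K) row-major, b is (K, N) row-major,
// segments is (num_segments, 2) stored as flat (start, end) pairs,
// out is (num_segments, M, N).
template <typename T>
void segmented_mm_sketch(
    const T* a,
    const T* b,
    const int32_t* segments,
    T* out,
    int M,
    int N,
    int K,
    int num_segments) {
  for (int i = 0; i < num_segments; i++) {
    int32_t k_start = segments[2 * i];     // previously implicit (carried over)
    int32_t k_end = segments[2 * i + 1];   // previously segments[i]
    T* out_i = out + static_cast<int64_t>(i) * M * N;
    for (int m = 0; m < M; m++) {
      for (int n = 0; n < N; n++) {
        T acc = T(0);
        // Accumulate only over this segment's slice of K.
        for (int k = k_start; k < k_end; k++) {
          acc += a[m * K + k] * b[k * N + n];
        }
        out_i[m * N + n] = acc;
      }
    }
  }
}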
@@ -537,7 +539,7 @@ void SegmentedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
           a.strides(),
           b.shape(),
           b.strides(),
-          segments.size(),
+          segments.size() / 2,
           segments.shape(),
           segments.strides());
       break;
@@ -555,7 +557,7 @@ void SegmentedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
           a.strides(),
           b.shape(),
           b.strides(),
-          segments.size(),
+          segments.size() / 2,
           segments.shape(),
           segments.strides());
       break;
@@ -573,7 +575,7 @@ void SegmentedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
           a.strides(),
           b.shape(),
           b.strides(),
-          segments.size(),
+          segments.size() / 2,
           segments.shape(),
           segments.strides());
       break;
@@ -591,7 +593,7 @@ void SegmentedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
           a.strides(),
           b.shape(),
           b.strides(),
-          segments.size(),
+          segments.size() / 2,
           segments.shape(),
           segments.strides());
       break;
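Since each segment now occupies two entries, the count passed down to the kernel becomes segments.size() / 2. For example (values assumed for illustration), a segments array of shape (4, 2) has size() == 8 and describes 4 segments:

// Illustration only: flat storage of a (4, 2) segments array.
int32_t segments_data[] = {0, 2, 2, 5, 5, 9, 9, 12};
size_t num_segments = (sizeof(segments_data) / sizeof(segments_data[0])) / 2;  // == 4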


@@ -1864,8 +1864,138 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   gather_mm(a, b, lhs_indices, rhs_indices, out, M, N, K, d, s);
 }
 
+void segmented_mm(
+    const array& a_,
+    const array& b_,
+    const array& segments_,
+    array& out,
+    int M,
+    int N,
+    int K,
+    metal::Device& d,
+    const Stream& s) {
+  // Copy if needed
+  std::vector<array> copies;
+  auto [transpose_a, lda, a] = check_transpose(copies, s, a_, false);
+  auto [transpose_b, ldb, b] = check_transpose(copies, s, b_, false);
+  auto segments = ensure_row_contiguous(segments_, d, s);
+  d.add_temporaries(std::move(copies), s.index);
+
+  // Determine dispatch kernel
+  int bm = 64, bn = 64, bk = 16;
+  int wm = 2, wn = 2;
+  size_t batch_size_out = out.size() / M / N;
+
+  char devc = d.get_architecture().back();
+  GEMM_TPARAM_MACRO(devc)
+
+  const bool align_M = (M % bm) == 0;
+  const bool align_N = (N % bn) == 0;
+
+  // Define the kernel name
+  std::string base_name;
+  base_name.reserve(128);
+  concatenate(
+      base_name,
+      "steel_segmented_mm_",
+      transpose_a ? 't' : 'n',
+      transpose_b ? 't' : 'n',
+      "_",
+      type_to_name(a),
+      "_",
+      type_to_name(out),
+      "_bm",
+      bm,
+      "_bn",
+      bn,
+      "_bk",
+      bk,
+      "_wm",
+      wm,
+      "_wn",
+      wn);
+
+  metal::MTLFCList func_consts = {
+      {&align_M, MTL::DataType::DataTypeBool, 200},
+      {&align_N, MTL::DataType::DataTypeBool, 201},
+  };
+
+  // And the kernel hash that includes the function constants
+  std::string hash_name;
+  hash_name.reserve(128);
+  concatenate(
+      hash_name,
+      base_name,
+      "_align_M_",
+      align_M ? 't' : 'n',
+      "_align_N_",
+      align_N ? 't' : 'n');
+
+  // Get and set the kernel
+  auto& compute_encoder = d.get_command_encoder(s.index);
+  auto kernel = get_steel_gemm_gather_kernel(
+      d,
+      base_name,
+      hash_name,
+      func_consts,
+      out,
+      transpose_a,
+      transpose_b,
+      bm,
+      bn,
+      bk,
+      wm,
+      wn,
+      false);
+  compute_encoder.set_compute_pipeline_state(kernel);
+
+  // Prepare the matmul params
+  steel::GEMMParams params{
+      /* const int M = */ M,
+      /* const int N = */ N,
+      /* const int K = */ K,
+      /* const int lda = */ static_cast<int>(lda),
+      /* const int ldb = */ static_cast<int>(ldb),
+      /* const int ldd = */ N,
+      /* const int tiles_n = */ (N + bn - 1) / bn,
+      /* const int tiles_m = */ (M + bm - 1) / bm,
+      /* const int64_t batch_stride_a = */ 0,
+      /* const int64_t batch_stride_b = */ 0,
+      /* const int64_t batch_stride_d = */ M * N,
+      /* const int swizzle_log = */ 0,
+      /* const int gemm_k_iterations_aligned = */ 0,
+      /* const int batch_ndim = */ 0};
+
+  // Prepare the grid
+  MTL::Size group_dims = MTL::Size(32, wn, wm);
+  MTL::Size grid_dims =
+      MTL::Size(params.tiles_n, params.tiles_m, batch_size_out);
+
+  // Launch kernel
+  compute_encoder.set_input_array(a, 0);
+  compute_encoder.set_input_array(b, 1);
+  compute_encoder.set_input_array(segments, 2);
+  compute_encoder.set_output_array(out, 3);
+  compute_encoder.set_bytes(params, 4);
+  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
+}
+
 void SegmentedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   throw std::invalid_argument("NYI");
+
+  auto& s = stream();
+  auto& d = metal::device(s.device);
+
+  auto& a = inputs[0];
+  auto& b = inputs[1];
+  auto& segments = inputs[2];
+
+  out.set_data(allocator::malloc(out.nbytes()));
+
+  // Extract shapes from inputs.
+  int M = a.shape(-2);
+  int N = b.shape(-1);
+  int K = a.shape(-1);
+
+  segmented_mm(a, b, segments, out, M, N, K, d, s);
 }
 
 } // namespace mlx::core
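As an illustration (sizes assumed, not taken from the commit), the grid launched above has one z-slice per output batch element, i.e. per segment, and each slice is tiled into bm x bn blocks of the output:

// Assumed sizes for the sake of the example.
int M = 128, N = 192;
int bm = 64, bn = 64, wm = 2, wn = 2;
size_t num_segments = 5;                   // segments has shape (5, 2)
size_t out_size = num_segments * M * N;    // out has shape (5, M, N)

size_t batch_size_out = out_size / M / N;  // == 5, one grid slice per segment
int tiles_n = (N + bn - 1) / bn;           // == 3
int tiles_m = (M + bm - 1) / bm;           // == 2
// grid_dims  == MTL::Size(tiles_n, tiles_m, batch_size_out) -> (3, 2, 5)
// group_dims == MTL::Size(32, wn, wm)                       -> (32, 2, 2)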


@@ -4658,6 +4658,13 @@ array segmented_mm(
     throw std::invalid_argument("[segmented_mm] Batched matmul not supported");
   }
+  if (segments.ndim() < 1 || segments.shape().back() != 2) {
+    std::ostringstream msg;
+    msg << "[segmented_mm] The segments should have shape (..., 2) but "
+        << segments.shape() << " was provided.";
+    throw std::invalid_argument(msg.str());
+  }
 
   // Type promotion
   auto out_type = result_type(a, b);
   if (!issubdtype(out_type, floating)) {
@@ -4673,6 +4680,7 @@ array segmented_mm(
   b = astype(b, out_type, s);
 
   Shape out_shape = segments.shape();
+  out_shape.pop_back();
   out_shape.push_back(a.shape(0));
   out_shape.push_back(b.shape(1));
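As an illustration (the Shape alias and all values are assumed), the output shape is the segments shape with its trailing 2 replaced by (M, N); e.g. a of shape (128, 256), b of shape (256, 64) and segments of shape (5, 2) give an output of shape (5, 128, 64):

#include <cstdint>
#include <vector>

using Shape = std::vector<int32_t>;  // stand-in for the shape type used above

// Sketch of the output-shape rule applied by the code above.
Shape segmented_mm_out_shape(
    const Shape& a_shape,      // (M, K)
    const Shape& b_shape,      // (K, N)
    const Shape& seg_shape) {  // (..., 2)
  Shape out_shape = seg_shape;      // e.g. (5, 2)
  out_shape.pop_back();             // drop the trailing 2 -> (5,)
  out_shape.push_back(a_shape[0]);  // append M            -> (5, 128)
  out_shape.push_back(b_shape[1]);  // append N            -> (5, 128, 64)
  return out_shape;
}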