mirror of https://github.com/ml-explore/mlx.git
Change the segments to be more general
@@ -77,7 +77,10 @@ inline void segmented_mm(
  int32_t N = b_copy[ndim - 1];
  int32_t k_start = 0;
  for (int i = 0; i < num_segments; i++) {
    int32_t k_end = segments[elem_to_loc(i, segments_shape, segments_strides)];
    int32_t k_start =
        segments[elem_to_loc(2 * i, segments_shape, segments_strides)];
    int32_t k_end =
        segments[elem_to_loc(2 * i + 1, segments_shape, segments_strides)];
    a_copy[ndim - 1] = k_end - k_start;
    b_copy[ndim - 2] = k_end - k_start;
    matmul<T>(
@@ -96,7 +99,6 @@ inline void segmented_mm(
        a_strides,
        b_copy,
        b_strides);
    k_start = k_end;
  }
}
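The change above generalizes the segment layout: instead of a running k_start carried between iterations and a single end offset per segment, every segment now supplies its own (k_start, k_end) pair, read at 2 * i and 2 * i + 1. A self-contained reference sketch of the resulting semantics (plain C++ over row-major float buffers; it mirrors the CPU loop above but is not the MLX implementation itself, and all names are illustrative):

// Independent reference sketch of what segmented_mm computes with the new
// (start, end) segment pairs.
#include <cstdint>
#include <vector>

// a: (M, K), b: (K, N), segments: num_segments (k_start, k_end) pairs, flattened.
// out: (num_segments, M, N); entry i is a[:, start:end] @ b[start:end, :].
void segmented_mm_reference(
    const std::vector<float>& a,
    const std::vector<float>& b,
    const std::vector<int32_t>& segments,
    std::vector<float>& out,
    int M,
    int N,
    int K) {
  int num_segments = static_cast<int>(segments.size()) / 2;
  out.assign(static_cast<size_t>(num_segments) * M * N, 0.0f);
  for (int i = 0; i < num_segments; i++) {
    int32_t k_start = segments[2 * i];
    int32_t k_end = segments[2 * i + 1];
    for (int m = 0; m < M; m++) {
      for (int n = 0; n < N; n++) {
        float acc = 0.0f;
        for (int32_t k = k_start; k < k_end; k++) {
          acc += a[m * K + k] * b[k * N + n];
        }
        out[(static_cast<size_t>(i) * M + m) * N + n] = acc;
      }
    }
  }
}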
@@ -537,7 +539,7 @@ void SegmentedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
          a.strides(),
          b.shape(),
          b.strides(),
          segments.size(),
          segments.size() / 2,
          segments.shape(),
          segments.strides());
      break;
@@ -555,7 +557,7 @@ void SegmentedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
          a.strides(),
          b.shape(),
          b.strides(),
          segments.size(),
          segments.size() / 2,
          segments.shape(),
          segments.strides());
      break;
@@ -573,7 +575,7 @@ void SegmentedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
          a.strides(),
          b.shape(),
          b.strides(),
          segments.size(),
          segments.size() / 2,
          segments.shape(),
          segments.strides());
      break;
@@ -591,7 +593,7 @@ void SegmentedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
          a.strides(),
          b.shape(),
          b.strides(),
          segments.size(),
          segments.size() / 2,
          segments.shape(),
          segments.strides());
      break;
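In all four dtype branches of SegmentedMM::eval_cpu the only change is the segment count handed to the kernel: because the segments array now stores (start, end) pairs, the count is segments.size() / 2 rather than segments.size(). A small illustrative snippet (the concrete values are made up):

#include <cstdint>
#include <cstdio>

int main() {
  // Three segments stored as (start, end) pairs -> six int32 entries.
  int32_t segments_data[] = {0, 16, 16, 24, 24, 32};
  size_t segments_size = sizeof(segments_data) / sizeof(segments_data[0]); // 6
  size_t num_segments = segments_size / 2; // 3, the value passed to the kernel
  std::printf("%zu segments\n", num_segments);
}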
@@ -1864,8 +1864,138 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
  gather_mm(a, b, lhs_indices, rhs_indices, out, M, N, K, d, s);
}

void segmented_mm(
    const array& a_,
    const array& b_,
    const array& segments_,
    array& out,
    int M,
    int N,
    int K,
    metal::Device& d,
    const Stream& s) {
  // Copy if needed
  std::vector<array> copies;
  auto [transpose_a, lda, a] = check_transpose(copies, s, a_, false);
  auto [transpose_b, ldb, b] = check_transpose(copies, s, b_, false);
  auto segments = ensure_row_contiguous(segments_, d, s);
  d.add_temporaries(std::move(copies), s.index);

  // Determine dispatch kernel
  int bm = 64, bn = 64, bk = 16;
  int wm = 2, wn = 2;
  size_t batch_size_out = out.size() / M / N;

  char devc = d.get_architecture().back();
  GEMM_TPARAM_MACRO(devc)

  const bool align_M = (M % bm) == 0;
  const bool align_N = (N % bn) == 0;

  // Define the kernel name
  std::string base_name;
  base_name.reserve(128);
  concatenate(
      base_name,
      "steel_segmented_mm_",
      transpose_a ? 't' : 'n',
      transpose_b ? 't' : 'n',
      "_",
      type_to_name(a),
      "_",
      type_to_name(out),
      "_bm",
      bm,
      "_bn",
      bn,
      "_bk",
      bk,
      "_wm",
      wm,
      "_wn",
      wn);

  metal::MTLFCList func_consts = {
      {&align_M, MTL::DataType::DataTypeBool, 200},
      {&align_N, MTL::DataType::DataTypeBool, 201},
  };

  // And the kernel hash that includes the function constants
  std::string hash_name;
  hash_name.reserve(128);
  concatenate(
      hash_name,
      base_name,
      "_align_M_",
      align_M ? 't' : 'n',
      "_align_N_",
      align_N ? 't' : 'n');

  // Get and set the kernel
  auto& compute_encoder = d.get_command_encoder(s.index);
  auto kernel = get_steel_gemm_gather_kernel(
      d,
      base_name,
      hash_name,
      func_consts,
      out,
      transpose_a,
      transpose_b,
      bm,
      bn,
      bk,
      wm,
      wn,
      false);
  compute_encoder.set_compute_pipeline_state(kernel);

  // Prepare the matmul params
  steel::GEMMParams params{
      /* const int M = */ M,
      /* const int N = */ N,
      /* const int K = */ K,
      /* const int lda = */ static_cast<int>(lda),
      /* const int ldb = */ static_cast<int>(ldb),
      /* const int ldd = */ N,
      /* const int tiles_n = */ (N + bn - 1) / bn,
      /* const int tiles_m = */ (M + bm - 1) / bm,
      /* const int64_t batch_stride_a = */ 0,
      /* const int64_t batch_stride_b = */ 0,
      /* const int64_t batch_stride_d = */ M * N,
      /* const int swizzle_log = */ 0,
      /* const int gemm_k_iterations_aligned = */ 0,
      /* const int batch_ndim = */ 0};

  // Prepare the grid
  MTL::Size group_dims = MTL::Size(32, wn, wm);
  MTL::Size grid_dims =
      MTL::Size(params.tiles_n, params.tiles_m, batch_size_out);

  // Launch kernel
  compute_encoder.set_input_array(a, 0);
  compute_encoder.set_input_array(b, 1);
  compute_encoder.set_input_array(segments, 2);
  compute_encoder.set_output_array(out, 3);
  compute_encoder.set_bytes(params, 4);
  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}

void SegmentedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
  throw std::invalid_argument("NYI");
  auto& s = stream();
  auto& d = metal::device(s.device);

  auto& a = inputs[0];
  auto& b = inputs[1];
  auto& segments = inputs[2];

  out.set_data(allocator::malloc(out.nbytes()));

  // Extract shapes from inputs.
  int M = a.shape(-2);
  int N = b.shape(-1);
  int K = a.shape(-1);

  segmented_mm(a, b, segments, out, M, N, K, d, s);
}

} // namespace mlx::core
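In the Metal path above, the number of segments enters the dispatch only through batch_size_out = out.size() / M / N, which becomes the third grid dimension. A worked example of the resulting launch geometry (the concrete numbers are made up for illustration; bm, bn, wm, and wn may be overridden by GEMM_TPARAM_MACRO on some architectures):

#include <cstdio>

int main() {
  int M = 128, N = 256;            // per-segment output is (M, N)
  int bm = 64, bn = 64;            // block tile sizes chosen above
  int wm = 2, wn = 2;              // simdgroups per threadgroup
  int num_segments = 3;            // batch_size_out = out.size() / (M * N)

  int tiles_m = (M + bm - 1) / bm; // 2
  int tiles_n = (N + bn - 1) / bn; // 4

  // grid_dims = (tiles_n, tiles_m, num_segments); group_dims = (32, wn, wm)
  std::printf("grid = (%d, %d, %d), group = (32, %d, %d)\n",
              tiles_n, tiles_m, num_segments, wn, wm);
}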
|
||||
|
||||
@@ -4658,6 +4658,13 @@ array segmented_mm(
    throw std::invalid_argument("[segmented_mm] Batched matmul not supported");
  }

  if (segments.ndim() < 1 || segments.shape().back() != 2) {
    std::ostringstream msg;
    msg << "[segmented_mm] The segments should have shape (..., 2) but "
        << segments.shape() << " was provided.";
    throw std::invalid_argument(msg.str());
  }

  // Type promotion
  auto out_type = result_type(a, b);
  if (!issubdtype(out_type, floating)) {
@@ -4673,6 +4680,7 @@ array segmented_mm(
  b = astype(b, out_type, s);

  Shape out_shape = segments.shape();
  out_shape.pop_back();
  out_shape.push_back(a.shape(0));
  out_shape.push_back(b.shape(1));
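The output shape is built from the segments shape with the trailing 2 dropped and (M, N) appended, so segments of shape (S, 2) with a of shape (M, K) and b of shape (K, N) give an output of shape (S, M, N). A standalone sketch of that shape arithmetic (plain std::vector stand-ins, not the MLX Shape type; the sizes are made up):

#include <cassert>
#include <vector>

int main() {
  std::vector<int> a_shape{4, 32};       // (M, K)
  std::vector<int> b_shape{32, 8};       // (K, N)
  std::vector<int> segments_shape{3, 2}; // three (start, end) pairs

  // Same construction as above: drop the trailing 2, append M and N.
  std::vector<int> out_shape = segments_shape;
  out_shape.pop_back();
  out_shape.push_back(a_shape[0]);
  out_shape.push_back(b_shape[1]);

  assert((out_shape == std::vector<int>{3, 4, 8})); // (S, M, N)
}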