mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-03 06:14:43 +08:00
Compare commits
3 Commits
v0.25.1
...
gemm-tuner
Author | SHA1 | Date | |
---|---|---|---|
![]() |
e21143961c | ||
![]() |
2ed2e0e3da | ||
![]() |
84e7c49f08 |
@@ -16,6 +16,7 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/internal/tuner/ops.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h)
|
||||
|
||||
if(MLX_BUILD_CPU)
|
||||
|
@@ -15,6 +15,8 @@
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
#include "mlx/internal/tuner/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
@@ -1848,4 +1850,195 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
d.add_temporaries(std::move(copies), s.index);
|
||||
}
|
||||
|
||||
namespace internal {
|
||||
|
||||
void TunableMatmul::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
auto& a_pre = inputs[0];
|
||||
auto& b_pre = inputs[1];
|
||||
auto& out = outputs[0];
|
||||
// Return 0s if either input is empty
|
||||
if (a_pre.size() == 0 || b_pre.size() == 0) {
|
||||
array zero = array(0, a_pre.dtype());
|
||||
fill_gpu(zero, out, s);
|
||||
d.add_temporary(std::move(zero), s.index);
|
||||
return;
|
||||
}
|
||||
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Init checks and prep
|
||||
|
||||
int M = a_pre.shape(-2);
|
||||
int N = b_pre.shape(-1);
|
||||
int K = a_pre.shape(-1);
|
||||
|
||||
// Keep a vector with copies to be cleared in the completed buffer to release
|
||||
// the arrays
|
||||
std::vector<array> copies;
|
||||
auto check_transpose = [&copies, &s](const array& arr, bool is_vector) {
|
||||
auto stx = arr.strides()[arr.ndim() - 2];
|
||||
auto sty = arr.strides()[arr.ndim() - 1];
|
||||
if (sty == 1 && (!is_vector || stx == arr.shape(-1))) {
|
||||
return std::make_tuple(false, stx, arr);
|
||||
} else if (stx == 1 && (!is_vector || sty == arr.shape(-2))) {
|
||||
return std::make_tuple(true, sty, arr);
|
||||
} else {
|
||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||
copy_gpu(arr, arr_copy, CopyType::General, s);
|
||||
copies.push_back(arr_copy);
|
||||
size_t stx = arr.shape(-1);
|
||||
return std::make_tuple(false, stx, arr_copy);
|
||||
}
|
||||
};
|
||||
|
||||
auto [transpose_a, a_cols, a] = check_transpose(a_pre, M == 1);
|
||||
auto [transpose_b, b_cols, b] = check_transpose(b_pre, N == 1);
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Check and collapse batch dimensions
|
||||
|
||||
auto [batch_shape, A_batch_stride, B_batch_stride] = collapse_batches(a, b);
|
||||
|
||||
auto batch_size_out = out.size() / (size_t(M) * size_t(N));
|
||||
|
||||
// Collapse batches into M if needed
|
||||
if (batch_size_out > 1 && !transpose_a && batch_shape.size() == 1 &&
|
||||
a.strides()[a.ndim() - 2] == K && A_batch_stride.back() == M * K &&
|
||||
B_batch_stride.back() == 0) {
|
||||
M *= batch_shape.back();
|
||||
batch_size_out = 1;
|
||||
|
||||
A_batch_stride = {0};
|
||||
B_batch_stride = {0};
|
||||
batch_shape = {1};
|
||||
}
|
||||
|
||||
std::vector<size_t> batch_strides = A_batch_stride;
|
||||
batch_strides.insert(
|
||||
batch_strides.end(), B_batch_stride.begin(), B_batch_stride.end());
|
||||
|
||||
using namespace mlx::steel;
|
||||
|
||||
// Determine dispatch kernel
|
||||
int bm = tparams_["bm"];
|
||||
int bn = tparams_["bn"];
|
||||
int bk = tparams_["bk"];
|
||||
int wm = tparams_["wm"];
|
||||
int wn = tparams_["wn"];
|
||||
|
||||
// Prepare kernel name
|
||||
std::ostringstream kname;
|
||||
kname << "steel_gemm_fused_" << (transpose_a ? 't' : 'n')
|
||||
<< (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
|
||||
<< type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk
|
||||
<< "_wm" << wm << "_wn" << wn;
|
||||
|
||||
std::string base_name = kname.str();
|
||||
|
||||
const bool has_batch = (batch_shape.size() > 1);
|
||||
const bool use_out_source = false;
|
||||
const bool do_axpby = false;
|
||||
const bool align_M = (M % bm) == 0;
|
||||
const bool align_N = (N % bn) == 0;
|
||||
const bool align_K = (K % bk) == 0;
|
||||
const bool do_gather = false;
|
||||
|
||||
metal::MTLFCList func_consts = {
|
||||
{&has_batch, MTL::DataType::DataTypeBool, 10},
|
||||
{&use_out_source, MTL::DataType::DataTypeBool, 100},
|
||||
{&do_axpby, MTL::DataType::DataTypeBool, 110},
|
||||
{&align_M, MTL::DataType::DataTypeBool, 200},
|
||||
{&align_N, MTL::DataType::DataTypeBool, 201},
|
||||
{&align_K, MTL::DataType::DataTypeBool, 202},
|
||||
{&do_gather, MTL::DataType::DataTypeBool, 300},
|
||||
};
|
||||
|
||||
// clang-format off
|
||||
kname << "_has_batch_" << (has_batch ? 't' : 'n')
|
||||
<< "_use_out_source_" << (use_out_source ? 't' : 'n')
|
||||
<< "_do_axpby_" << (do_axpby ? 't' : 'n')
|
||||
<< "_align_M_" << (align_M ? 't' : 'n')
|
||||
<< "_align_N_" << (align_N ? 't' : 'n')
|
||||
<< "_align_K_" << (align_K ? 't' : 'n')
|
||||
<< "_do_gather_" << (do_gather ? 't' : 'n'); // clang-format on
|
||||
|
||||
std::string hash_name = kname.str();
|
||||
|
||||
// Encode and dispatch kernel
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = get_steel_gemm_fused_kernel(
|
||||
d,
|
||||
base_name,
|
||||
hash_name,
|
||||
func_consts,
|
||||
out,
|
||||
transpose_a,
|
||||
transpose_b,
|
||||
bm,
|
||||
bn,
|
||||
bk,
|
||||
wm,
|
||||
wn);
|
||||
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
// Use problem size to determine threadblock swizzle
|
||||
int tn = (N + bn - 1) / bn;
|
||||
int tm = (M + bm - 1) / bm;
|
||||
|
||||
// TODO: Explore device-based tuning for swizzle
|
||||
int swizzle_log = 0; // tm >= 6 ? 3 : (tm <= 3 ? 0 : 2);
|
||||
|
||||
size_t matrix_stride_out = size_t(M) * N;
|
||||
|
||||
// Prepare steel matmul params
|
||||
GEMMParams params{
|
||||
/* const int M = */ M,
|
||||
/* const int N = */ N,
|
||||
/* const int K = */ K,
|
||||
/* const int lda = */ int(a_cols),
|
||||
/* const int ldb = */ int(b_cols),
|
||||
/* const int ldd = */ N,
|
||||
/* const int tiles_n = */ tn,
|
||||
/* const int tiles_m = */ tm,
|
||||
/* const size_t batch_stride_a = */ A_batch_stride.back(),
|
||||
/* const size_t batch_stride_b = */ B_batch_stride.back(),
|
||||
/* const size_t batch_stride_d = */ matrix_stride_out,
|
||||
/* const int swizzle_log = */ swizzle_log,
|
||||
/* const int gemm_k_iterations_aligned = */ (K / bk),
|
||||
/* const int batch_ndim = */ int(batch_shape.size())};
|
||||
|
||||
// Prepare launch grid params
|
||||
int tile = 1 << swizzle_log;
|
||||
tm = (tm + tile - 1) / tile;
|
||||
tn = tn * tile;
|
||||
|
||||
MTL::Size group_dims = MTL::Size(32, wn, wm);
|
||||
MTL::Size grid_dims = MTL::Size(tn, tm, batch_size_out);
|
||||
|
||||
// Launch kernel
|
||||
compute_encoder.set_input_array(a, 0);
|
||||
compute_encoder.set_input_array(b, 1);
|
||||
compute_encoder.set_output_array(out, 3);
|
||||
|
||||
compute_encoder->setBytes(¶ms, sizeof(GEMMParams), 4);
|
||||
|
||||
set_vector_bytes(compute_encoder, batch_shape, 6);
|
||||
set_vector_bytes(compute_encoder, batch_strides, 7);
|
||||
|
||||
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
|
||||
|
||||
// Record copies
|
||||
d.add_temporaries(std::move(copies), s.index);
|
||||
}
|
||||
|
||||
} // namespace internal
|
||||
|
||||
} // namespace mlx::core
|
||||
|
88
mlx/internal/tuner/ops.cpp
Normal file
88
mlx/internal/tuner/ops.cpp
Normal file
@@ -0,0 +1,88 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include "mlx/internal/tuner/ops.h"
|
||||
#include "mlx/internal/tuner/primitives.h"
|
||||
#include "mlx/ops.h"
|
||||
|
||||
namespace mlx::core::internal {
|
||||
|
||||
array tunable_matmul(
|
||||
const array& in_a,
|
||||
const array& in_b,
|
||||
const std::unordered_map<std::string, int>& tparams,
|
||||
StreamOrDevice s_ /*= {} */) {
|
||||
auto s = to_stream(s_);
|
||||
auto fallback = [s](const std::vector<array>& inputs) {
|
||||
return std::vector<array>{matmul(inputs[0], inputs[1], s)};
|
||||
};
|
||||
|
||||
if (s.device == Device::cpu || in_a.ndim() < 2 || in_b.ndim() < 2) {
|
||||
return matmul(in_a, in_b, s);
|
||||
}
|
||||
|
||||
auto a = in_a;
|
||||
auto b = in_b;
|
||||
|
||||
if (a.shape(-1) != b.shape(-2)) {
|
||||
std::ostringstream msg;
|
||||
msg << "[matmul] Last dimension of first input with shape " << a.shape()
|
||||
<< " must match second to last dimension of"
|
||||
<< " second input with shape " << b.shape() << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
auto out_type = promote_types(a.dtype(), b.dtype());
|
||||
if (!issubdtype(out_type, floating)) {
|
||||
std::ostringstream msg;
|
||||
msg << "[matmul] Only real floating point types are supported but "
|
||||
<< a.dtype() << " and " << b.dtype() << " were provided which results"
|
||||
<< " in " << out_type << ", which is not a real floating point type.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
a = astype(a, out_type, s);
|
||||
b = astype(b, out_type, s);
|
||||
|
||||
// We can batch the multiplication by reshaping a
|
||||
if (a.ndim() > 2 && b.ndim() == 2) {
|
||||
std::vector<int> out_shape = a.shape();
|
||||
a = reshape(a, {-1, out_shape.back()}, s);
|
||||
out_shape.back() = b.shape(-1);
|
||||
if (in_b.ndim() == 1) {
|
||||
out_shape.pop_back();
|
||||
}
|
||||
auto out = array(
|
||||
{a.shape(0), b.shape(1)},
|
||||
out_type,
|
||||
std::make_shared<TunableMatmul>(to_stream(s), fallback, tparams),
|
||||
{a, b});
|
||||
return reshape(out, out_shape, s);
|
||||
}
|
||||
|
||||
if (a.ndim() > 2 || b.ndim() > 2) {
|
||||
std::vector<int> bsx_a(a.shape().begin(), a.shape().end() - 2);
|
||||
std::vector<int> bsx_b(b.shape().begin(), b.shape().end() - 2);
|
||||
auto inner_shape = broadcast_shapes(bsx_a, bsx_b);
|
||||
|
||||
// Broadcast a
|
||||
inner_shape.push_back(a.shape(-2));
|
||||
inner_shape.push_back(a.shape(-1));
|
||||
a = broadcast_to(a, inner_shape, s);
|
||||
|
||||
// Broadcast b
|
||||
*(inner_shape.end() - 2) = b.shape(-2);
|
||||
*(inner_shape.end() - 1) = b.shape(-1);
|
||||
b = broadcast_to(b, inner_shape, s);
|
||||
}
|
||||
|
||||
auto out_shape = a.shape();
|
||||
out_shape.back() = b.shape(-1);
|
||||
|
||||
return array(
|
||||
std::move(out_shape),
|
||||
out_type,
|
||||
std::make_shared<TunableMatmul>(to_stream(s), fallback, tparams),
|
||||
{a, b});
|
||||
}
|
||||
|
||||
} // namespace mlx::core::internal
|
17
mlx/internal/tuner/ops.h
Normal file
17
mlx/internal/tuner/ops.h
Normal file
@@ -0,0 +1,17 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <unordered_map>
|
||||
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace mlx::core::internal {
|
||||
|
||||
array tunable_matmul(
|
||||
const array& a,
|
||||
const array& b,
|
||||
const std::unordered_map<std::string, int>& tparams,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
} // namespace mlx::core::internal
|
34
mlx/internal/tuner/primitives.h
Normal file
34
mlx/internal/tuner/primitives.h
Normal file
@@ -0,0 +1,34 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <unordered_map>
|
||||
|
||||
#include "mlx/fast_primitives.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core::internal {
|
||||
|
||||
class TunableMatmul : public mlx::core::fast::Custom {
|
||||
public:
|
||||
TunableMatmul(
|
||||
Stream stream,
|
||||
std::function<std::vector<array>(std::vector<array>)> fallback,
|
||||
std::unordered_map<std::string, int> tparams)
|
||||
: mlx::core::fast::Custom(stream, fallback), tparams_(tparams) {}
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override {
|
||||
throw std::runtime_error("NYI");
|
||||
}
|
||||
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override;
|
||||
|
||||
DEFINE_PRINT(TunableMatmul)
|
||||
|
||||
private:
|
||||
std::function<std::vector<array>(std::vector<array>)> fallback_;
|
||||
std::unordered_map<std::string, int> tparams_;
|
||||
};
|
||||
|
||||
} // namespace mlx::core::internal
|
@@ -23,6 +23,7 @@ nanobind_add_module(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/constants.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/trees.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/internal.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp)
|
||||
|
||||
if(NOT MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY)
|
||||
|
59
python/src/internal.cpp
Normal file
59
python/src/internal.cpp
Normal file
@@ -0,0 +1,59 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <nanobind/nanobind.h>
|
||||
#include <nanobind/stl/optional.h>
|
||||
#include <nanobind/stl/pair.h>
|
||||
#include <nanobind/stl/string.h>
|
||||
#include <nanobind/stl/tuple.h>
|
||||
#include <nanobind/stl/unordered_map.h>
|
||||
#include <nanobind/stl/variant.h>
|
||||
#include <nanobind/stl/vector.h>
|
||||
|
||||
#include "python/src/utils.h"
|
||||
|
||||
#include "mlx/internal/tuner/ops.h"
|
||||
#include "mlx/ops.h"
|
||||
|
||||
namespace nb = nanobind;
|
||||
using namespace nb::literals;
|
||||
using namespace mlx::core;
|
||||
|
||||
void init_internal(nb::module_& parent_module) {
|
||||
auto m = parent_module.def_submodule(
|
||||
"internal", "mlx.core.internal: internal operations");
|
||||
|
||||
m.def(
|
||||
"tunable_matmul",
|
||||
&internal::tunable_matmul,
|
||||
nb::arg(),
|
||||
nb::arg(),
|
||||
nb::arg(),
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def tunable_matmul(a: array, b: array, tparams: dict[str, int], /, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Matrix multiplication.
|
||||
|
||||
Perform the (possibly batched) matrix multiplication of two arrays. This function supports
|
||||
broadcasting for arrays with more than two dimensions.
|
||||
|
||||
- If the first array is 1-D then a 1 is prepended to its shape to make it
|
||||
a matrix. Similarly if the second array is 1-D then a 1 is appended to its
|
||||
shape to make it a matrix. In either case the singleton dimension is removed
|
||||
from the result.
|
||||
- A batched matrix multiplication is performed if the arrays have more than
|
||||
2 dimensions. The matrix dimensions for the matrix product are the last
|
||||
two dimensions of each input.
|
||||
- All but the last two dimensions of each input are broadcast with one another using
|
||||
standard numpy-style broadcasting semantics.
|
||||
|
||||
Args:
|
||||
a (array): Input array or scalar.
|
||||
b (array): Input array or scalar.
|
||||
tparams (dict[str, int]): Matmul tunable parameters
|
||||
|
||||
Returns:
|
||||
array: The matrix product of ``a`` and ``b``.
|
||||
)pbdoc");
|
||||
}
|
@@ -19,6 +19,7 @@ void init_linalg(nb::module_&);
|
||||
void init_constants(nb::module_&);
|
||||
void init_fast(nb::module_&);
|
||||
void init_distributed(nb::module_&);
|
||||
void init_internal(nb::module_&);
|
||||
|
||||
NB_MODULE(core, m) {
|
||||
m.doc() = "mlx: A framework for machine learning on Apple silicon.";
|
||||
@@ -39,6 +40,7 @@ NB_MODULE(core, m) {
|
||||
init_constants(m);
|
||||
init_fast(m);
|
||||
init_distributed(m);
|
||||
init_internal(m);
|
||||
|
||||
m.attr("__version__") = TOSTRING(_VERSION_);
|
||||
}
|
||||
|
Reference in New Issue
Block a user