mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 12:49:44 +08:00
Merge branch 'ml-explore:main' into adding-Muon-optimizer
This commit is contained in:
@@ -53,9 +53,10 @@ struct FusedKernelBuilder {
|
||||
|
||||
// Build function signature.
|
||||
if (contiguous) {
|
||||
os += "template <typename IdxT = uint32_t>\n";
|
||||
os += "template <typename IdxT = uint32_t, int work_per_thread = 1>\n";
|
||||
} else {
|
||||
os += "template <int NDIM, typename IdxT = uint32_t>\n";
|
||||
os +=
|
||||
"template <int NDIM, typename IdxT = uint32_t, int work_per_thread = 1>\n";
|
||||
}
|
||||
os += fmt::format("__global__ void {}(\n", kernel_name + name);
|
||||
for (size_t i = 0; i < params.size(); ++i) {
|
||||
@@ -67,12 +68,46 @@ struct FusedKernelBuilder {
|
||||
}
|
||||
os += ") {\n";
|
||||
|
||||
// Index.
|
||||
// Index. For non contiguous kernels we create a separate index
|
||||
// variable per variable otherwise everyone uses `index`.
|
||||
os +=
|
||||
" IdxT index = cg::this_grid().thread_rank();\n"
|
||||
" IdxT index = cg::this_grid().thread_rank() * work_per_thread;\n"
|
||||
" if (index >= size) {\n"
|
||||
" return;\n"
|
||||
" }\n";
|
||||
if (!contiguous) {
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
const auto& x = inputs[i];
|
||||
const std::string& xname = namer.get_name(x);
|
||||
if (is_scalar(x) || is_constant(i)) {
|
||||
continue;
|
||||
}
|
||||
os += " IdxT " + xname + "_idx = 0;\n";
|
||||
}
|
||||
os += " {\n";
|
||||
os += " IdxT loc = index;\n";
|
||||
os +=
|
||||
" #pragma unroll\n"
|
||||
" for (int i = NDIM - 1; i >= 0; i--) {\n";
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
const auto& x = inputs[i];
|
||||
const std::string& xname = namer.get_name(x);
|
||||
if (is_scalar(x) || is_constant(i)) {
|
||||
continue;
|
||||
}
|
||||
os += " " + xname + "_idx += (loc \% shape[i]) * IdxT(" + xname +
|
||||
"_strides[i]);\n";
|
||||
}
|
||||
os +=
|
||||
" loc /= shape[i];\n"
|
||||
" }\n"
|
||||
" }\n";
|
||||
}
|
||||
|
||||
// Work loop
|
||||
os +=
|
||||
"\n"
|
||||
" for (int i = 0; i < work_per_thread && index < size; i++) {\n";
|
||||
|
||||
// Read inputs.
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
@@ -89,12 +124,9 @@ struct FusedKernelBuilder {
|
||||
} else if (contiguous) {
|
||||
value = fmt::format("{}[index]", xname);
|
||||
} else {
|
||||
std::string index = fmt::format(
|
||||
"elem_to_loc_nd<NDIM>(index, shape.data(), {}_strides.data())",
|
||||
xname);
|
||||
value = fmt::format("{}[{}]", xname, index);
|
||||
value = fmt::format("{}[{}_idx]", xname, xname);
|
||||
}
|
||||
os += fmt::format(" {} tmp_{} = {};\n", type, xname, value);
|
||||
os += fmt::format(" {} tmp_{} = {};\n", type, xname, value);
|
||||
}
|
||||
|
||||
// Write tape.
|
||||
@@ -113,14 +145,30 @@ struct FusedKernelBuilder {
|
||||
}
|
||||
value += fmt::format("tmp_{})", namer.get_name(x.inputs().back()));
|
||||
}
|
||||
os += fmt::format(" {} tmp_{} = {};\n", type, xname, value);
|
||||
os += fmt::format(" {} tmp_{} = {};\n", type, xname, value);
|
||||
}
|
||||
|
||||
// Write output.
|
||||
for (const auto& x : outputs) {
|
||||
os += fmt::format(" {0}[index] = tmp_{0};\n", namer.get_name(x));
|
||||
os += fmt::format(" {0}[index] = tmp_{0};\n", namer.get_name(x));
|
||||
}
|
||||
|
||||
// End of work loop
|
||||
os +=
|
||||
"\n"
|
||||
" index++;\n";
|
||||
if (!contiguous) {
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
const auto& x = inputs[i];
|
||||
const std::string& xname = namer.get_name(x);
|
||||
if (is_scalar(x) || is_constant(i)) {
|
||||
continue;
|
||||
}
|
||||
os += " " + xname + "_idx += " + xname + "_strides[NDIM - 1];\n";
|
||||
}
|
||||
}
|
||||
os += " }\n";
|
||||
|
||||
os += "}\n";
|
||||
}
|
||||
};
|
||||
@@ -156,15 +204,28 @@ void Compiled::eval_gpu(
|
||||
builder.build("_strided", false);
|
||||
builder.os += "\n} // namespace mlx::core::cu\n";
|
||||
// Build kernel names.
|
||||
std::vector<std::string> kernel_names = {
|
||||
fmt::format("mlx::core::cu::{}_contiguous<uint32_t>", lib_name()),
|
||||
fmt::format("mlx::core::cu::{}_contiguous<int64_t>", lib_name()),
|
||||
};
|
||||
for (int i = 1; i <= MAX_NDIM; ++i) {
|
||||
std::vector<std::string> kernel_names;
|
||||
for (auto work_per_thread : std::array<int, 2>{1, 4}) {
|
||||
kernel_names.push_back(fmt::format(
|
||||
"mlx::core::cu::{}_strided<{}, uint32_t>", lib_name(), i));
|
||||
kernel_names.push_back(
|
||||
fmt::format("mlx::core::cu::{}_strided<{}, int64_t>", lib_name(), i));
|
||||
"mlx::core::cu::{}_contiguous<uint32_t, {}>",
|
||||
lib_name(),
|
||||
work_per_thread));
|
||||
kernel_names.push_back(fmt::format(
|
||||
"mlx::core::cu::{}_contiguous<int64_t, {}>",
|
||||
lib_name(),
|
||||
work_per_thread));
|
||||
for (int i = 1; i <= MAX_NDIM; ++i) {
|
||||
kernel_names.push_back(fmt::format(
|
||||
"mlx::core::cu::{}_strided<{}, uint32_t, {}>",
|
||||
lib_name(),
|
||||
i,
|
||||
work_per_thread));
|
||||
kernel_names.push_back(fmt::format(
|
||||
"mlx::core::cu::{}_strided<{}, int64_t, {}>",
|
||||
lib_name(),
|
||||
i,
|
||||
work_per_thread));
|
||||
}
|
||||
}
|
||||
return std::make_pair(std::move(builder.os), std::move(kernel_names));
|
||||
});
|
||||
@@ -207,13 +268,21 @@ void Compiled::eval_gpu(
|
||||
args.append<uint32_t>(outputs[0].data_size());
|
||||
}
|
||||
|
||||
// Choose work per thread
|
||||
int work_per_thread = 4;
|
||||
if (!contiguous && shape.back() % work_per_thread != 0) {
|
||||
work_per_thread = 1;
|
||||
}
|
||||
|
||||
// Launch kernel.
|
||||
const char* index_type = large ? "int64_t" : "uint32_t";
|
||||
std::string kernel_name = fmt::format("mlx::core::cu::{}", lib_name());
|
||||
if (contiguous) {
|
||||
kernel_name += fmt::format("_contiguous<{}>", index_type);
|
||||
kernel_name +=
|
||||
fmt::format("_contiguous<{}, {}>", index_type, work_per_thread);
|
||||
} else {
|
||||
kernel_name += fmt::format("_strided<{}, {}>", shape.size(), index_type);
|
||||
kernel_name += fmt::format(
|
||||
"_strided<{}, {}, {}>", shape.size(), index_type, work_per_thread);
|
||||
}
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
for (const auto& in : inputs) {
|
||||
@@ -224,7 +293,8 @@ void Compiled::eval_gpu(
|
||||
}
|
||||
|
||||
auto kernel = mod.get_kernel(kernel_name);
|
||||
auto [num_blocks, block_dims] = get_launch_args(kernel, outputs[0], large);
|
||||
auto [num_blocks, block_dims] =
|
||||
get_launch_args(kernel, outputs[0], large, work_per_thread);
|
||||
encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
|
||||
}
|
||||
|
||||
|
@@ -66,7 +66,6 @@ CommandEncoder& Device::get_command_encoder(Stream s) {
|
||||
}
|
||||
|
||||
CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) {
|
||||
CHECK_CUDA_ERROR(cudaGraphCreate(&graph, 0));
|
||||
CHECK_CUDA_ERROR(
|
||||
cudaStreamBeginCapture(enc.stream(), cudaStreamCaptureModeGlobal));
|
||||
}
|
||||
|
@@ -121,7 +121,8 @@ void write_cached_ptx(
|
||||
const std::filesystem::path& cache_dir,
|
||||
const std::string& module_name,
|
||||
const std::vector<char>& ptx,
|
||||
const std::vector<std::pair<std::string, std::string>>& ptx_kernels) {
|
||||
const std::vector<std::pair<std::string, std::string>>& ptx_kernels,
|
||||
const std::string& source_code) {
|
||||
if (cache_dir.empty()) {
|
||||
return;
|
||||
}
|
||||
@@ -134,6 +135,9 @@ void write_cached_ptx(
|
||||
for (const auto& [name, mangled] : ptx_kernels) {
|
||||
txt_file << name << "\t" << mangled << std::endl;
|
||||
}
|
||||
|
||||
std::ofstream source_file(cache_dir / (module_name + ".cu"));
|
||||
source_file << source_code;
|
||||
}
|
||||
|
||||
// Return if |device|'s version is not newer than |major|.|minor| version.
|
||||
@@ -272,7 +276,8 @@ JitModule::JitModule(
|
||||
} else {
|
||||
CHECK_NVRTC_ERROR(nvrtcGetPTX(prog, ptx.data()));
|
||||
}
|
||||
write_cached_ptx(ptx_cache_dir(), module_name, ptx, ptx_kernels);
|
||||
write_cached_ptx(
|
||||
ptx_cache_dir(), module_name, ptx, ptx_kernels, source_code);
|
||||
}
|
||||
|
||||
// Load module.
|
||||
|
@@ -27,6 +27,35 @@ void check_cublas_error(const char* name, cublasStatus_t err) {
|
||||
}
|
||||
}
|
||||
|
||||
struct CublasPreference {
|
||||
CublasPreference(Device& device) {
|
||||
// The recommended cublas workspace size is 4 MiB for pre-Hopper and 32 MiB
|
||||
// for Hopper+:
|
||||
// https://docs.nvidia.com/cuda/cublas/#cublassetworkspace
|
||||
uint64_t MiB = 1024 * 1024;
|
||||
uint64_t workspace_size =
|
||||
device.compute_capability_major() >= 9 ? 32 * MiB : 4 * MiB;
|
||||
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceCreate(&pref_));
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceSetAttribute(
|
||||
pref_,
|
||||
CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
|
||||
&workspace_size,
|
||||
sizeof(uint64_t)));
|
||||
}
|
||||
|
||||
~CublasPreference() {
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceDestroy(pref_));
|
||||
}
|
||||
|
||||
cublasLtMatmulPreference_t pref_{nullptr};
|
||||
};
|
||||
|
||||
cublasLtMatmulPreference_t cublas_preference(Device& device) {
|
||||
static CublasPreference pref(device);
|
||||
return pref.pref_;
|
||||
}
|
||||
|
||||
class MatMul {
|
||||
public:
|
||||
MatMul(
|
||||
@@ -43,7 +72,7 @@ class MatMul {
|
||||
int32_t batch_count,
|
||||
int64_t a_batch_stride,
|
||||
int64_t b_batch_stride)
|
||||
: handle_(device.lt_handle()) {
|
||||
: handle_(device.lt_handle()), pref_(cublas_preference(device)) {
|
||||
heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED;
|
||||
|
||||
auto scale_type = dtype_to_cuda_type(dtype);
|
||||
@@ -77,20 +106,6 @@ class MatMul {
|
||||
type, b_rows, b_cols, b_transposed, ldb, batch_count, b_batch_stride);
|
||||
out_desc_ = create_matrix_layout(
|
||||
type, a_rows, b_cols, false, b_cols, batch_count, a_rows * b_cols);
|
||||
|
||||
// The recommended cublas workspace size is 4 MiB for pre-Hopper and 32 MiB
|
||||
// for Hopper+:
|
||||
// https://docs.nvidia.com/cuda/cublas/#cublassetworkspace
|
||||
uint64_t MiB = 1024 * 1024;
|
||||
uint64_t workspace_size =
|
||||
device.compute_capability_major() >= 9 ? 32 * MiB : 4 * MiB;
|
||||
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceCreate(&pref_));
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceSetAttribute(
|
||||
pref_,
|
||||
CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
|
||||
&workspace_size,
|
||||
sizeof(uint64_t)));
|
||||
}
|
||||
|
||||
MatMul(
|
||||
@@ -130,11 +145,11 @@ class MatMul {
|
||||
}
|
||||
|
||||
~MatMul() {
|
||||
cublasLtMatrixLayoutDestroy(a_desc_);
|
||||
cublasLtMatrixLayoutDestroy(b_desc_);
|
||||
cublasLtMatrixLayoutDestroy(c_desc_);
|
||||
cublasLtMatrixLayoutDestroy(out_desc_);
|
||||
cublasLtMatmulDescDestroy(matmul_desc_);
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(a_desc_));
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(b_desc_));
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(c_desc_));
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(out_desc_));
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatmulDescDestroy(matmul_desc_));
|
||||
}
|
||||
|
||||
void run(
|
||||
@@ -259,9 +274,9 @@ class MatMul {
|
||||
return desc;
|
||||
}
|
||||
|
||||
cublasLtMatmulPreference_t pref_{nullptr};
|
||||
cublasLtHandle_t handle_{nullptr};
|
||||
cublasLtMatmulDesc_t matmul_desc_{nullptr};
|
||||
cublasLtMatmulPreference_t pref_{nullptr};
|
||||
cublasLtMatrixLayout_t a_desc_{nullptr};
|
||||
cublasLtMatrixLayout_t b_desc_{nullptr};
|
||||
cublasLtMatrixLayout_t c_desc_{nullptr};
|
||||
|
Reference in New Issue
Block a user