Merge branch 'ml-explore:main' into adding-Muon-optimizer

This commit is contained in:
Gökdeniz Gülmez
2025-07-17 20:10:02 +02:00
committed by GitHub
4 changed files with 135 additions and 46 deletions

View File

@@ -53,9 +53,10 @@ struct FusedKernelBuilder {
// Build function signature. // Build function signature.
if (contiguous) { if (contiguous) {
os += "template <typename IdxT = uint32_t>\n"; os += "template <typename IdxT = uint32_t, int work_per_thread = 1>\n";
} else { } else {
os += "template <int NDIM, typename IdxT = uint32_t>\n"; os +=
"template <int NDIM, typename IdxT = uint32_t, int work_per_thread = 1>\n";
} }
os += fmt::format("__global__ void {}(\n", kernel_name + name); os += fmt::format("__global__ void {}(\n", kernel_name + name);
for (size_t i = 0; i < params.size(); ++i) { for (size_t i = 0; i < params.size(); ++i) {
@@ -67,12 +68,46 @@ struct FusedKernelBuilder {
} }
os += ") {\n"; os += ") {\n";
// Index. // Index. For non contiguous kernels we create a separate index
// variable per variable otherwise everyone uses `index`.
os += os +=
" IdxT index = cg::this_grid().thread_rank();\n" " IdxT index = cg::this_grid().thread_rank() * work_per_thread;\n"
" if (index >= size) {\n" " if (index >= size) {\n"
" return;\n" " return;\n"
" }\n"; " }\n";
if (!contiguous) {
for (size_t i = 0; i < inputs.size(); ++i) {
const auto& x = inputs[i];
const std::string& xname = namer.get_name(x);
if (is_scalar(x) || is_constant(i)) {
continue;
}
os += " IdxT " + xname + "_idx = 0;\n";
}
os += " {\n";
os += " IdxT loc = index;\n";
os +=
" #pragma unroll\n"
" for (int i = NDIM - 1; i >= 0; i--) {\n";
for (size_t i = 0; i < inputs.size(); ++i) {
const auto& x = inputs[i];
const std::string& xname = namer.get_name(x);
if (is_scalar(x) || is_constant(i)) {
continue;
}
os += " " + xname + "_idx += (loc \% shape[i]) * IdxT(" + xname +
"_strides[i]);\n";
}
os +=
" loc /= shape[i];\n"
" }\n"
" }\n";
}
// Work loop
os +=
"\n"
" for (int i = 0; i < work_per_thread && index < size; i++) {\n";
// Read inputs. // Read inputs.
for (size_t i = 0; i < inputs.size(); ++i) { for (size_t i = 0; i < inputs.size(); ++i) {
@@ -89,10 +124,7 @@ struct FusedKernelBuilder {
} else if (contiguous) { } else if (contiguous) {
value = fmt::format("{}[index]", xname); value = fmt::format("{}[index]", xname);
} else { } else {
std::string index = fmt::format( value = fmt::format("{}[{}_idx]", xname, xname);
"elem_to_loc_nd<NDIM>(index, shape.data(), {}_strides.data())",
xname);
value = fmt::format("{}[{}]", xname, index);
} }
os += fmt::format(" {} tmp_{} = {};\n", type, xname, value); os += fmt::format(" {} tmp_{} = {};\n", type, xname, value);
} }
@@ -121,6 +153,22 @@ struct FusedKernelBuilder {
os += fmt::format(" {0}[index] = tmp_{0};\n", namer.get_name(x)); os += fmt::format(" {0}[index] = tmp_{0};\n", namer.get_name(x));
} }
// End of work loop
os +=
"\n"
" index++;\n";
if (!contiguous) {
for (size_t i = 0; i < inputs.size(); ++i) {
const auto& x = inputs[i];
const std::string& xname = namer.get_name(x);
if (is_scalar(x) || is_constant(i)) {
continue;
}
os += " " + xname + "_idx += " + xname + "_strides[NDIM - 1];\n";
}
}
os += " }\n";
os += "}\n"; os += "}\n";
} }
}; };
@@ -156,15 +204,28 @@ void Compiled::eval_gpu(
builder.build("_strided", false); builder.build("_strided", false);
builder.os += "\n} // namespace mlx::core::cu\n"; builder.os += "\n} // namespace mlx::core::cu\n";
// Build kernel names. // Build kernel names.
std::vector<std::string> kernel_names = { std::vector<std::string> kernel_names;
fmt::format("mlx::core::cu::{}_contiguous<uint32_t>", lib_name()), for (auto work_per_thread : std::array<int, 2>{1, 4}) {
fmt::format("mlx::core::cu::{}_contiguous<int64_t>", lib_name()), kernel_names.push_back(fmt::format(
}; "mlx::core::cu::{}_contiguous<uint32_t, {}>",
lib_name(),
work_per_thread));
kernel_names.push_back(fmt::format(
"mlx::core::cu::{}_contiguous<int64_t, {}>",
lib_name(),
work_per_thread));
for (int i = 1; i <= MAX_NDIM; ++i) { for (int i = 1; i <= MAX_NDIM; ++i) {
kernel_names.push_back(fmt::format( kernel_names.push_back(fmt::format(
"mlx::core::cu::{}_strided<{}, uint32_t>", lib_name(), i)); "mlx::core::cu::{}_strided<{}, uint32_t, {}>",
kernel_names.push_back( lib_name(),
fmt::format("mlx::core::cu::{}_strided<{}, int64_t>", lib_name(), i)); i,
work_per_thread));
kernel_names.push_back(fmt::format(
"mlx::core::cu::{}_strided<{}, int64_t, {}>",
lib_name(),
i,
work_per_thread));
}
} }
return std::make_pair(std::move(builder.os), std::move(kernel_names)); return std::make_pair(std::move(builder.os), std::move(kernel_names));
}); });
@@ -207,13 +268,21 @@ void Compiled::eval_gpu(
args.append<uint32_t>(outputs[0].data_size()); args.append<uint32_t>(outputs[0].data_size());
} }
// Choose work per thread
int work_per_thread = 4;
if (!contiguous && shape.back() % work_per_thread != 0) {
work_per_thread = 1;
}
// Launch kernel. // Launch kernel.
const char* index_type = large ? "int64_t" : "uint32_t"; const char* index_type = large ? "int64_t" : "uint32_t";
std::string kernel_name = fmt::format("mlx::core::cu::{}", lib_name()); std::string kernel_name = fmt::format("mlx::core::cu::{}", lib_name());
if (contiguous) { if (contiguous) {
kernel_name += fmt::format("_contiguous<{}>", index_type); kernel_name +=
fmt::format("_contiguous<{}, {}>", index_type, work_per_thread);
} else { } else {
kernel_name += fmt::format("_strided<{}, {}>", shape.size(), index_type); kernel_name += fmt::format(
"_strided<{}, {}, {}>", shape.size(), index_type, work_per_thread);
} }
auto& encoder = cu::get_command_encoder(s); auto& encoder = cu::get_command_encoder(s);
for (const auto& in : inputs) { for (const auto& in : inputs) {
@@ -224,7 +293,8 @@ void Compiled::eval_gpu(
} }
auto kernel = mod.get_kernel(kernel_name); auto kernel = mod.get_kernel(kernel_name);
auto [num_blocks, block_dims] = get_launch_args(kernel, outputs[0], large); auto [num_blocks, block_dims] =
get_launch_args(kernel, outputs[0], large, work_per_thread);
encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args()); encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
} }

View File

@@ -66,7 +66,6 @@ CommandEncoder& Device::get_command_encoder(Stream s) {
} }
CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) { CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) {
CHECK_CUDA_ERROR(cudaGraphCreate(&graph, 0));
CHECK_CUDA_ERROR( CHECK_CUDA_ERROR(
cudaStreamBeginCapture(enc.stream(), cudaStreamCaptureModeGlobal)); cudaStreamBeginCapture(enc.stream(), cudaStreamCaptureModeGlobal));
} }

View File

@@ -121,7 +121,8 @@ void write_cached_ptx(
const std::filesystem::path& cache_dir, const std::filesystem::path& cache_dir,
const std::string& module_name, const std::string& module_name,
const std::vector<char>& ptx, const std::vector<char>& ptx,
const std::vector<std::pair<std::string, std::string>>& ptx_kernels) { const std::vector<std::pair<std::string, std::string>>& ptx_kernels,
const std::string& source_code) {
if (cache_dir.empty()) { if (cache_dir.empty()) {
return; return;
} }
@@ -134,6 +135,9 @@ void write_cached_ptx(
for (const auto& [name, mangled] : ptx_kernels) { for (const auto& [name, mangled] : ptx_kernels) {
txt_file << name << "\t" << mangled << std::endl; txt_file << name << "\t" << mangled << std::endl;
} }
std::ofstream source_file(cache_dir / (module_name + ".cu"));
source_file << source_code;
} }
// Return if |device|'s version is not newer than |major|.|minor| version. // Return if |device|'s version is not newer than |major|.|minor| version.
@@ -272,7 +276,8 @@ JitModule::JitModule(
} else { } else {
CHECK_NVRTC_ERROR(nvrtcGetPTX(prog, ptx.data())); CHECK_NVRTC_ERROR(nvrtcGetPTX(prog, ptx.data()));
} }
write_cached_ptx(ptx_cache_dir(), module_name, ptx, ptx_kernels); write_cached_ptx(
ptx_cache_dir(), module_name, ptx, ptx_kernels, source_code);
} }
// Load module. // Load module.

View File

@@ -27,6 +27,35 @@ void check_cublas_error(const char* name, cublasStatus_t err) {
} }
} }
struct CublasPreference {
CublasPreference(Device& device) {
// The recommended cublas workspace size is 4 MiB for pre-Hopper and 32 MiB
// for Hopper+:
// https://docs.nvidia.com/cuda/cublas/#cublassetworkspace
uint64_t MiB = 1024 * 1024;
uint64_t workspace_size =
device.compute_capability_major() >= 9 ? 32 * MiB : 4 * MiB;
CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceCreate(&pref_));
CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceSetAttribute(
pref_,
CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
&workspace_size,
sizeof(uint64_t)));
}
~CublasPreference() {
CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceDestroy(pref_));
}
cublasLtMatmulPreference_t pref_{nullptr};
};
cublasLtMatmulPreference_t cublas_preference(Device& device) {
static CublasPreference pref(device);
return pref.pref_;
}
class MatMul { class MatMul {
public: public:
MatMul( MatMul(
@@ -43,7 +72,7 @@ class MatMul {
int32_t batch_count, int32_t batch_count,
int64_t a_batch_stride, int64_t a_batch_stride,
int64_t b_batch_stride) int64_t b_batch_stride)
: handle_(device.lt_handle()) { : handle_(device.lt_handle()), pref_(cublas_preference(device)) {
heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED; heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED;
auto scale_type = dtype_to_cuda_type(dtype); auto scale_type = dtype_to_cuda_type(dtype);
@@ -77,20 +106,6 @@ class MatMul {
type, b_rows, b_cols, b_transposed, ldb, batch_count, b_batch_stride); type, b_rows, b_cols, b_transposed, ldb, batch_count, b_batch_stride);
out_desc_ = create_matrix_layout( out_desc_ = create_matrix_layout(
type, a_rows, b_cols, false, b_cols, batch_count, a_rows * b_cols); type, a_rows, b_cols, false, b_cols, batch_count, a_rows * b_cols);
// The recommended cublas workspace size is 4 MiB for pre-Hopper and 32 MiB
// for Hopper+:
// https://docs.nvidia.com/cuda/cublas/#cublassetworkspace
uint64_t MiB = 1024 * 1024;
uint64_t workspace_size =
device.compute_capability_major() >= 9 ? 32 * MiB : 4 * MiB;
CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceCreate(&pref_));
CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceSetAttribute(
pref_,
CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
&workspace_size,
sizeof(uint64_t)));
} }
MatMul( MatMul(
@@ -130,11 +145,11 @@ class MatMul {
} }
~MatMul() { ~MatMul() {
cublasLtMatrixLayoutDestroy(a_desc_); CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(a_desc_));
cublasLtMatrixLayoutDestroy(b_desc_); CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(b_desc_));
cublasLtMatrixLayoutDestroy(c_desc_); CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(c_desc_));
cublasLtMatrixLayoutDestroy(out_desc_); CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(out_desc_));
cublasLtMatmulDescDestroy(matmul_desc_); CHECK_CUBLAS_ERROR(cublasLtMatmulDescDestroy(matmul_desc_));
} }
void run( void run(
@@ -259,9 +274,9 @@ class MatMul {
return desc; return desc;
} }
cublasLtMatmulPreference_t pref_{nullptr};
cublasLtHandle_t handle_{nullptr}; cublasLtHandle_t handle_{nullptr};
cublasLtMatmulDesc_t matmul_desc_{nullptr}; cublasLtMatmulDesc_t matmul_desc_{nullptr};
cublasLtMatmulPreference_t pref_{nullptr};
cublasLtMatrixLayout_t a_desc_{nullptr}; cublasLtMatrixLayout_t a_desc_{nullptr};
cublasLtMatrixLayout_t b_desc_{nullptr}; cublasLtMatrixLayout_t b_desc_{nullptr};
cublasLtMatrixLayout_t c_desc_{nullptr}; cublasLtMatrixLayout_t c_desc_{nullptr};