mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Merge branch 'ml-explore:main' into adding-Muon-optimizer
This commit is contained in:
@@ -53,9 +53,10 @@ struct FusedKernelBuilder {
|
|||||||
|
|
||||||
// Build function signature.
|
// Build function signature.
|
||||||
if (contiguous) {
|
if (contiguous) {
|
||||||
os += "template <typename IdxT = uint32_t>\n";
|
os += "template <typename IdxT = uint32_t, int work_per_thread = 1>\n";
|
||||||
} else {
|
} else {
|
||||||
os += "template <int NDIM, typename IdxT = uint32_t>\n";
|
os +=
|
||||||
|
"template <int NDIM, typename IdxT = uint32_t, int work_per_thread = 1>\n";
|
||||||
}
|
}
|
||||||
os += fmt::format("__global__ void {}(\n", kernel_name + name);
|
os += fmt::format("__global__ void {}(\n", kernel_name + name);
|
||||||
for (size_t i = 0; i < params.size(); ++i) {
|
for (size_t i = 0; i < params.size(); ++i) {
|
||||||
@@ -67,12 +68,46 @@ struct FusedKernelBuilder {
|
|||||||
}
|
}
|
||||||
os += ") {\n";
|
os += ") {\n";
|
||||||
|
|
||||||
// Index.
|
// Index. For non contiguous kernels we create a separate index
|
||||||
|
// variable per variable otherwise everyone uses `index`.
|
||||||
os +=
|
os +=
|
||||||
" IdxT index = cg::this_grid().thread_rank();\n"
|
" IdxT index = cg::this_grid().thread_rank() * work_per_thread;\n"
|
||||||
" if (index >= size) {\n"
|
" if (index >= size) {\n"
|
||||||
" return;\n"
|
" return;\n"
|
||||||
" }\n";
|
" }\n";
|
||||||
|
if (!contiguous) {
|
||||||
|
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||||
|
const auto& x = inputs[i];
|
||||||
|
const std::string& xname = namer.get_name(x);
|
||||||
|
if (is_scalar(x) || is_constant(i)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
os += " IdxT " + xname + "_idx = 0;\n";
|
||||||
|
}
|
||||||
|
os += " {\n";
|
||||||
|
os += " IdxT loc = index;\n";
|
||||||
|
os +=
|
||||||
|
" #pragma unroll\n"
|
||||||
|
" for (int i = NDIM - 1; i >= 0; i--) {\n";
|
||||||
|
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||||
|
const auto& x = inputs[i];
|
||||||
|
const std::string& xname = namer.get_name(x);
|
||||||
|
if (is_scalar(x) || is_constant(i)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
os += " " + xname + "_idx += (loc \% shape[i]) * IdxT(" + xname +
|
||||||
|
"_strides[i]);\n";
|
||||||
|
}
|
||||||
|
os +=
|
||||||
|
" loc /= shape[i];\n"
|
||||||
|
" }\n"
|
||||||
|
" }\n";
|
||||||
|
}
|
||||||
|
|
||||||
|
// Work loop
|
||||||
|
os +=
|
||||||
|
"\n"
|
||||||
|
" for (int i = 0; i < work_per_thread && index < size; i++) {\n";
|
||||||
|
|
||||||
// Read inputs.
|
// Read inputs.
|
||||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||||
@@ -89,10 +124,7 @@ struct FusedKernelBuilder {
|
|||||||
} else if (contiguous) {
|
} else if (contiguous) {
|
||||||
value = fmt::format("{}[index]", xname);
|
value = fmt::format("{}[index]", xname);
|
||||||
} else {
|
} else {
|
||||||
std::string index = fmt::format(
|
value = fmt::format("{}[{}_idx]", xname, xname);
|
||||||
"elem_to_loc_nd<NDIM>(index, shape.data(), {}_strides.data())",
|
|
||||||
xname);
|
|
||||||
value = fmt::format("{}[{}]", xname, index);
|
|
||||||
}
|
}
|
||||||
os += fmt::format(" {} tmp_{} = {};\n", type, xname, value);
|
os += fmt::format(" {} tmp_{} = {};\n", type, xname, value);
|
||||||
}
|
}
|
||||||
@@ -121,6 +153,22 @@ struct FusedKernelBuilder {
|
|||||||
os += fmt::format(" {0}[index] = tmp_{0};\n", namer.get_name(x));
|
os += fmt::format(" {0}[index] = tmp_{0};\n", namer.get_name(x));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// End of work loop
|
||||||
|
os +=
|
||||||
|
"\n"
|
||||||
|
" index++;\n";
|
||||||
|
if (!contiguous) {
|
||||||
|
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||||
|
const auto& x = inputs[i];
|
||||||
|
const std::string& xname = namer.get_name(x);
|
||||||
|
if (is_scalar(x) || is_constant(i)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
os += " " + xname + "_idx += " + xname + "_strides[NDIM - 1];\n";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
os += " }\n";
|
||||||
|
|
||||||
os += "}\n";
|
os += "}\n";
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@@ -156,15 +204,28 @@ void Compiled::eval_gpu(
|
|||||||
builder.build("_strided", false);
|
builder.build("_strided", false);
|
||||||
builder.os += "\n} // namespace mlx::core::cu\n";
|
builder.os += "\n} // namespace mlx::core::cu\n";
|
||||||
// Build kernel names.
|
// Build kernel names.
|
||||||
std::vector<std::string> kernel_names = {
|
std::vector<std::string> kernel_names;
|
||||||
fmt::format("mlx::core::cu::{}_contiguous<uint32_t>", lib_name()),
|
for (auto work_per_thread : std::array<int, 2>{1, 4}) {
|
||||||
fmt::format("mlx::core::cu::{}_contiguous<int64_t>", lib_name()),
|
kernel_names.push_back(fmt::format(
|
||||||
};
|
"mlx::core::cu::{}_contiguous<uint32_t, {}>",
|
||||||
|
lib_name(),
|
||||||
|
work_per_thread));
|
||||||
|
kernel_names.push_back(fmt::format(
|
||||||
|
"mlx::core::cu::{}_contiguous<int64_t, {}>",
|
||||||
|
lib_name(),
|
||||||
|
work_per_thread));
|
||||||
for (int i = 1; i <= MAX_NDIM; ++i) {
|
for (int i = 1; i <= MAX_NDIM; ++i) {
|
||||||
kernel_names.push_back(fmt::format(
|
kernel_names.push_back(fmt::format(
|
||||||
"mlx::core::cu::{}_strided<{}, uint32_t>", lib_name(), i));
|
"mlx::core::cu::{}_strided<{}, uint32_t, {}>",
|
||||||
kernel_names.push_back(
|
lib_name(),
|
||||||
fmt::format("mlx::core::cu::{}_strided<{}, int64_t>", lib_name(), i));
|
i,
|
||||||
|
work_per_thread));
|
||||||
|
kernel_names.push_back(fmt::format(
|
||||||
|
"mlx::core::cu::{}_strided<{}, int64_t, {}>",
|
||||||
|
lib_name(),
|
||||||
|
i,
|
||||||
|
work_per_thread));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return std::make_pair(std::move(builder.os), std::move(kernel_names));
|
return std::make_pair(std::move(builder.os), std::move(kernel_names));
|
||||||
});
|
});
|
||||||
@@ -207,13 +268,21 @@ void Compiled::eval_gpu(
|
|||||||
args.append<uint32_t>(outputs[0].data_size());
|
args.append<uint32_t>(outputs[0].data_size());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Choose work per thread
|
||||||
|
int work_per_thread = 4;
|
||||||
|
if (!contiguous && shape.back() % work_per_thread != 0) {
|
||||||
|
work_per_thread = 1;
|
||||||
|
}
|
||||||
|
|
||||||
// Launch kernel.
|
// Launch kernel.
|
||||||
const char* index_type = large ? "int64_t" : "uint32_t";
|
const char* index_type = large ? "int64_t" : "uint32_t";
|
||||||
std::string kernel_name = fmt::format("mlx::core::cu::{}", lib_name());
|
std::string kernel_name = fmt::format("mlx::core::cu::{}", lib_name());
|
||||||
if (contiguous) {
|
if (contiguous) {
|
||||||
kernel_name += fmt::format("_contiguous<{}>", index_type);
|
kernel_name +=
|
||||||
|
fmt::format("_contiguous<{}, {}>", index_type, work_per_thread);
|
||||||
} else {
|
} else {
|
||||||
kernel_name += fmt::format("_strided<{}, {}>", shape.size(), index_type);
|
kernel_name += fmt::format(
|
||||||
|
"_strided<{}, {}, {}>", shape.size(), index_type, work_per_thread);
|
||||||
}
|
}
|
||||||
auto& encoder = cu::get_command_encoder(s);
|
auto& encoder = cu::get_command_encoder(s);
|
||||||
for (const auto& in : inputs) {
|
for (const auto& in : inputs) {
|
||||||
@@ -224,7 +293,8 @@ void Compiled::eval_gpu(
|
|||||||
}
|
}
|
||||||
|
|
||||||
auto kernel = mod.get_kernel(kernel_name);
|
auto kernel = mod.get_kernel(kernel_name);
|
||||||
auto [num_blocks, block_dims] = get_launch_args(kernel, outputs[0], large);
|
auto [num_blocks, block_dims] =
|
||||||
|
get_launch_args(kernel, outputs[0], large, work_per_thread);
|
||||||
encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
|
encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -66,7 +66,6 @@ CommandEncoder& Device::get_command_encoder(Stream s) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) {
|
CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) {
|
||||||
CHECK_CUDA_ERROR(cudaGraphCreate(&graph, 0));
|
|
||||||
CHECK_CUDA_ERROR(
|
CHECK_CUDA_ERROR(
|
||||||
cudaStreamBeginCapture(enc.stream(), cudaStreamCaptureModeGlobal));
|
cudaStreamBeginCapture(enc.stream(), cudaStreamCaptureModeGlobal));
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -121,7 +121,8 @@ void write_cached_ptx(
|
|||||||
const std::filesystem::path& cache_dir,
|
const std::filesystem::path& cache_dir,
|
||||||
const std::string& module_name,
|
const std::string& module_name,
|
||||||
const std::vector<char>& ptx,
|
const std::vector<char>& ptx,
|
||||||
const std::vector<std::pair<std::string, std::string>>& ptx_kernels) {
|
const std::vector<std::pair<std::string, std::string>>& ptx_kernels,
|
||||||
|
const std::string& source_code) {
|
||||||
if (cache_dir.empty()) {
|
if (cache_dir.empty()) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@@ -134,6 +135,9 @@ void write_cached_ptx(
|
|||||||
for (const auto& [name, mangled] : ptx_kernels) {
|
for (const auto& [name, mangled] : ptx_kernels) {
|
||||||
txt_file << name << "\t" << mangled << std::endl;
|
txt_file << name << "\t" << mangled << std::endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::ofstream source_file(cache_dir / (module_name + ".cu"));
|
||||||
|
source_file << source_code;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return if |device|'s version is not newer than |major|.|minor| version.
|
// Return if |device|'s version is not newer than |major|.|minor| version.
|
||||||
@@ -272,7 +276,8 @@ JitModule::JitModule(
|
|||||||
} else {
|
} else {
|
||||||
CHECK_NVRTC_ERROR(nvrtcGetPTX(prog, ptx.data()));
|
CHECK_NVRTC_ERROR(nvrtcGetPTX(prog, ptx.data()));
|
||||||
}
|
}
|
||||||
write_cached_ptx(ptx_cache_dir(), module_name, ptx, ptx_kernels);
|
write_cached_ptx(
|
||||||
|
ptx_cache_dir(), module_name, ptx, ptx_kernels, source_code);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Load module.
|
// Load module.
|
||||||
|
|||||||
@@ -27,6 +27,35 @@ void check_cublas_error(const char* name, cublasStatus_t err) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct CublasPreference {
|
||||||
|
CublasPreference(Device& device) {
|
||||||
|
// The recommended cublas workspace size is 4 MiB for pre-Hopper and 32 MiB
|
||||||
|
// for Hopper+:
|
||||||
|
// https://docs.nvidia.com/cuda/cublas/#cublassetworkspace
|
||||||
|
uint64_t MiB = 1024 * 1024;
|
||||||
|
uint64_t workspace_size =
|
||||||
|
device.compute_capability_major() >= 9 ? 32 * MiB : 4 * MiB;
|
||||||
|
|
||||||
|
CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceCreate(&pref_));
|
||||||
|
CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceSetAttribute(
|
||||||
|
pref_,
|
||||||
|
CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
|
||||||
|
&workspace_size,
|
||||||
|
sizeof(uint64_t)));
|
||||||
|
}
|
||||||
|
|
||||||
|
~CublasPreference() {
|
||||||
|
CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceDestroy(pref_));
|
||||||
|
}
|
||||||
|
|
||||||
|
cublasLtMatmulPreference_t pref_{nullptr};
|
||||||
|
};
|
||||||
|
|
||||||
|
cublasLtMatmulPreference_t cublas_preference(Device& device) {
|
||||||
|
static CublasPreference pref(device);
|
||||||
|
return pref.pref_;
|
||||||
|
}
|
||||||
|
|
||||||
class MatMul {
|
class MatMul {
|
||||||
public:
|
public:
|
||||||
MatMul(
|
MatMul(
|
||||||
@@ -43,7 +72,7 @@ class MatMul {
|
|||||||
int32_t batch_count,
|
int32_t batch_count,
|
||||||
int64_t a_batch_stride,
|
int64_t a_batch_stride,
|
||||||
int64_t b_batch_stride)
|
int64_t b_batch_stride)
|
||||||
: handle_(device.lt_handle()) {
|
: handle_(device.lt_handle()), pref_(cublas_preference(device)) {
|
||||||
heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED;
|
heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED;
|
||||||
|
|
||||||
auto scale_type = dtype_to_cuda_type(dtype);
|
auto scale_type = dtype_to_cuda_type(dtype);
|
||||||
@@ -77,20 +106,6 @@ class MatMul {
|
|||||||
type, b_rows, b_cols, b_transposed, ldb, batch_count, b_batch_stride);
|
type, b_rows, b_cols, b_transposed, ldb, batch_count, b_batch_stride);
|
||||||
out_desc_ = create_matrix_layout(
|
out_desc_ = create_matrix_layout(
|
||||||
type, a_rows, b_cols, false, b_cols, batch_count, a_rows * b_cols);
|
type, a_rows, b_cols, false, b_cols, batch_count, a_rows * b_cols);
|
||||||
|
|
||||||
// The recommended cublas workspace size is 4 MiB for pre-Hopper and 32 MiB
|
|
||||||
// for Hopper+:
|
|
||||||
// https://docs.nvidia.com/cuda/cublas/#cublassetworkspace
|
|
||||||
uint64_t MiB = 1024 * 1024;
|
|
||||||
uint64_t workspace_size =
|
|
||||||
device.compute_capability_major() >= 9 ? 32 * MiB : 4 * MiB;
|
|
||||||
|
|
||||||
CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceCreate(&pref_));
|
|
||||||
CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceSetAttribute(
|
|
||||||
pref_,
|
|
||||||
CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
|
|
||||||
&workspace_size,
|
|
||||||
sizeof(uint64_t)));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
MatMul(
|
MatMul(
|
||||||
@@ -130,11 +145,11 @@ class MatMul {
|
|||||||
}
|
}
|
||||||
|
|
||||||
~MatMul() {
|
~MatMul() {
|
||||||
cublasLtMatrixLayoutDestroy(a_desc_);
|
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(a_desc_));
|
||||||
cublasLtMatrixLayoutDestroy(b_desc_);
|
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(b_desc_));
|
||||||
cublasLtMatrixLayoutDestroy(c_desc_);
|
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(c_desc_));
|
||||||
cublasLtMatrixLayoutDestroy(out_desc_);
|
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(out_desc_));
|
||||||
cublasLtMatmulDescDestroy(matmul_desc_);
|
CHECK_CUBLAS_ERROR(cublasLtMatmulDescDestroy(matmul_desc_));
|
||||||
}
|
}
|
||||||
|
|
||||||
void run(
|
void run(
|
||||||
@@ -259,9 +274,9 @@ class MatMul {
|
|||||||
return desc;
|
return desc;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
cublasLtMatmulPreference_t pref_{nullptr};
|
||||||
cublasLtHandle_t handle_{nullptr};
|
cublasLtHandle_t handle_{nullptr};
|
||||||
cublasLtMatmulDesc_t matmul_desc_{nullptr};
|
cublasLtMatmulDesc_t matmul_desc_{nullptr};
|
||||||
cublasLtMatmulPreference_t pref_{nullptr};
|
|
||||||
cublasLtMatrixLayout_t a_desc_{nullptr};
|
cublasLtMatrixLayout_t a_desc_{nullptr};
|
||||||
cublasLtMatrixLayout_t b_desc_{nullptr};
|
cublasLtMatrixLayout_t b_desc_{nullptr};
|
||||||
cublasLtMatrixLayout_t c_desc_{nullptr};
|
cublasLtMatrixLayout_t c_desc_{nullptr};
|
||||||
|
|||||||
Reference in New Issue
Block a user