mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Compare commits
10 Commits
7f39e9c299
...
0a8bb904d7
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0a8bb904d7 | ||
|
|
c535d8c1b5 | ||
|
|
4b3d7634cd | ||
|
|
516d172ba5 | ||
|
|
698daee214 | ||
|
|
4c0f7c713b | ||
|
|
3889c805da | ||
|
|
060404d862 | ||
|
|
fbb3f65a1a | ||
|
|
6b1b8ea91b |
@@ -53,9 +53,10 @@ struct FusedKernelBuilder {
|
||||
|
||||
// Build function signature.
|
||||
if (contiguous) {
|
||||
os += "template <typename IdxT = uint32_t>\n";
|
||||
os += "template <typename IdxT = uint32_t, int work_per_thread = 1>\n";
|
||||
} else {
|
||||
os += "template <int NDIM, typename IdxT = uint32_t>\n";
|
||||
os +=
|
||||
"template <int NDIM, typename IdxT = uint32_t, int work_per_thread = 1>\n";
|
||||
}
|
||||
os += fmt::format("__global__ void {}(\n", kernel_name + name);
|
||||
for (size_t i = 0; i < params.size(); ++i) {
|
||||
@@ -67,12 +68,46 @@ struct FusedKernelBuilder {
|
||||
}
|
||||
os += ") {\n";
|
||||
|
||||
// Index.
|
||||
// Index. For non contiguous kernels we create a separate index
|
||||
// variable per variable otherwise everyone uses `index`.
|
||||
os +=
|
||||
" IdxT index = cg::this_grid().thread_rank();\n"
|
||||
" IdxT index = cg::this_grid().thread_rank() * work_per_thread;\n"
|
||||
" if (index >= size) {\n"
|
||||
" return;\n"
|
||||
" }\n";
|
||||
if (!contiguous) {
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
const auto& x = inputs[i];
|
||||
const std::string& xname = namer.get_name(x);
|
||||
if (is_scalar(x) || is_constant(i)) {
|
||||
continue;
|
||||
}
|
||||
os += " IdxT " + xname + "_idx = 0;\n";
|
||||
}
|
||||
os += " {\n";
|
||||
os += " IdxT loc = index;\n";
|
||||
os +=
|
||||
" #pragma unroll\n"
|
||||
" for (int i = NDIM - 1; i >= 0; i--) {\n";
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
const auto& x = inputs[i];
|
||||
const std::string& xname = namer.get_name(x);
|
||||
if (is_scalar(x) || is_constant(i)) {
|
||||
continue;
|
||||
}
|
||||
os += " " + xname + "_idx += (loc \% shape[i]) * IdxT(" + xname +
|
||||
"_strides[i]);\n";
|
||||
}
|
||||
os +=
|
||||
" loc /= shape[i];\n"
|
||||
" }\n"
|
||||
" }\n";
|
||||
}
|
||||
|
||||
// Work loop
|
||||
os +=
|
||||
"\n"
|
||||
" for (int i = 0; i < work_per_thread && index < size; i++) {\n";
|
||||
|
||||
// Read inputs.
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
@@ -89,12 +124,9 @@ struct FusedKernelBuilder {
|
||||
} else if (contiguous) {
|
||||
value = fmt::format("{}[index]", xname);
|
||||
} else {
|
||||
std::string index = fmt::format(
|
||||
"elem_to_loc_nd<NDIM>(index, shape.data(), {}_strides.data())",
|
||||
xname);
|
||||
value = fmt::format("{}[{}]", xname, index);
|
||||
value = fmt::format("{}[{}_idx]", xname, xname);
|
||||
}
|
||||
os += fmt::format(" {} tmp_{} = {};\n", type, xname, value);
|
||||
os += fmt::format(" {} tmp_{} = {};\n", type, xname, value);
|
||||
}
|
||||
|
||||
// Write tape.
|
||||
@@ -113,14 +145,30 @@ struct FusedKernelBuilder {
|
||||
}
|
||||
value += fmt::format("tmp_{})", namer.get_name(x.inputs().back()));
|
||||
}
|
||||
os += fmt::format(" {} tmp_{} = {};\n", type, xname, value);
|
||||
os += fmt::format(" {} tmp_{} = {};\n", type, xname, value);
|
||||
}
|
||||
|
||||
// Write output.
|
||||
for (const auto& x : outputs) {
|
||||
os += fmt::format(" {0}[index] = tmp_{0};\n", namer.get_name(x));
|
||||
os += fmt::format(" {0}[index] = tmp_{0};\n", namer.get_name(x));
|
||||
}
|
||||
|
||||
// End of work loop
|
||||
os +=
|
||||
"\n"
|
||||
" index++;\n";
|
||||
if (!contiguous) {
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
const auto& x = inputs[i];
|
||||
const std::string& xname = namer.get_name(x);
|
||||
if (is_scalar(x) || is_constant(i)) {
|
||||
continue;
|
||||
}
|
||||
os += " " + xname + "_idx += " + xname + "_strides[NDIM - 1];\n";
|
||||
}
|
||||
}
|
||||
os += " }\n";
|
||||
|
||||
os += "}\n";
|
||||
}
|
||||
};
|
||||
@@ -156,15 +204,28 @@ void Compiled::eval_gpu(
|
||||
builder.build("_strided", false);
|
||||
builder.os += "\n} // namespace mlx::core::cu\n";
|
||||
// Build kernel names.
|
||||
std::vector<std::string> kernel_names = {
|
||||
fmt::format("mlx::core::cu::{}_contiguous<uint32_t>", lib_name()),
|
||||
fmt::format("mlx::core::cu::{}_contiguous<int64_t>", lib_name()),
|
||||
};
|
||||
for (int i = 1; i <= MAX_NDIM; ++i) {
|
||||
std::vector<std::string> kernel_names;
|
||||
for (auto work_per_thread : std::array<int, 2>{1, 4}) {
|
||||
kernel_names.push_back(fmt::format(
|
||||
"mlx::core::cu::{}_strided<{}, uint32_t>", lib_name(), i));
|
||||
kernel_names.push_back(
|
||||
fmt::format("mlx::core::cu::{}_strided<{}, int64_t>", lib_name(), i));
|
||||
"mlx::core::cu::{}_contiguous<uint32_t, {}>",
|
||||
lib_name(),
|
||||
work_per_thread));
|
||||
kernel_names.push_back(fmt::format(
|
||||
"mlx::core::cu::{}_contiguous<int64_t, {}>",
|
||||
lib_name(),
|
||||
work_per_thread));
|
||||
for (int i = 1; i <= MAX_NDIM; ++i) {
|
||||
kernel_names.push_back(fmt::format(
|
||||
"mlx::core::cu::{}_strided<{}, uint32_t, {}>",
|
||||
lib_name(),
|
||||
i,
|
||||
work_per_thread));
|
||||
kernel_names.push_back(fmt::format(
|
||||
"mlx::core::cu::{}_strided<{}, int64_t, {}>",
|
||||
lib_name(),
|
||||
i,
|
||||
work_per_thread));
|
||||
}
|
||||
}
|
||||
return std::make_pair(std::move(builder.os), std::move(kernel_names));
|
||||
});
|
||||
@@ -207,13 +268,21 @@ void Compiled::eval_gpu(
|
||||
args.append<uint32_t>(outputs[0].data_size());
|
||||
}
|
||||
|
||||
// Choose work per thread
|
||||
int work_per_thread = 4;
|
||||
if (!contiguous && shape.back() % work_per_thread != 0) {
|
||||
work_per_thread = 1;
|
||||
}
|
||||
|
||||
// Launch kernel.
|
||||
const char* index_type = large ? "int64_t" : "uint32_t";
|
||||
std::string kernel_name = fmt::format("mlx::core::cu::{}", lib_name());
|
||||
if (contiguous) {
|
||||
kernel_name += fmt::format("_contiguous<{}>", index_type);
|
||||
kernel_name +=
|
||||
fmt::format("_contiguous<{}, {}>", index_type, work_per_thread);
|
||||
} else {
|
||||
kernel_name += fmt::format("_strided<{}, {}>", shape.size(), index_type);
|
||||
kernel_name += fmt::format(
|
||||
"_strided<{}, {}, {}>", shape.size(), index_type, work_per_thread);
|
||||
}
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
for (const auto& in : inputs) {
|
||||
@@ -224,7 +293,8 @@ void Compiled::eval_gpu(
|
||||
}
|
||||
|
||||
auto kernel = mod.get_kernel(kernel_name);
|
||||
auto [num_blocks, block_dims] = get_launch_args(kernel, outputs[0], large);
|
||||
auto [num_blocks, block_dims] =
|
||||
get_launch_args(kernel, outputs[0], large, work_per_thread);
|
||||
encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
|
||||
}
|
||||
|
||||
|
||||
@@ -66,7 +66,6 @@ CommandEncoder& Device::get_command_encoder(Stream s) {
|
||||
}
|
||||
|
||||
CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) {
|
||||
CHECK_CUDA_ERROR(cudaGraphCreate(&graph, 0));
|
||||
CHECK_CUDA_ERROR(
|
||||
cudaStreamBeginCapture(enc.stream(), cudaStreamCaptureModeGlobal));
|
||||
}
|
||||
|
||||
@@ -121,7 +121,8 @@ void write_cached_ptx(
|
||||
const std::filesystem::path& cache_dir,
|
||||
const std::string& module_name,
|
||||
const std::vector<char>& ptx,
|
||||
const std::vector<std::pair<std::string, std::string>>& ptx_kernels) {
|
||||
const std::vector<std::pair<std::string, std::string>>& ptx_kernels,
|
||||
const std::string& source_code) {
|
||||
if (cache_dir.empty()) {
|
||||
return;
|
||||
}
|
||||
@@ -134,6 +135,9 @@ void write_cached_ptx(
|
||||
for (const auto& [name, mangled] : ptx_kernels) {
|
||||
txt_file << name << "\t" << mangled << std::endl;
|
||||
}
|
||||
|
||||
std::ofstream source_file(cache_dir / (module_name + ".cu"));
|
||||
source_file << source_code;
|
||||
}
|
||||
|
||||
// Return if |device|'s version is not newer than |major|.|minor| version.
|
||||
@@ -272,7 +276,8 @@ JitModule::JitModule(
|
||||
} else {
|
||||
CHECK_NVRTC_ERROR(nvrtcGetPTX(prog, ptx.data()));
|
||||
}
|
||||
write_cached_ptx(ptx_cache_dir(), module_name, ptx, ptx_kernels);
|
||||
write_cached_ptx(
|
||||
ptx_cache_dir(), module_name, ptx, ptx_kernels, source_code);
|
||||
}
|
||||
|
||||
// Load module.
|
||||
|
||||
@@ -27,6 +27,35 @@ void check_cublas_error(const char* name, cublasStatus_t err) {
|
||||
}
|
||||
}
|
||||
|
||||
struct CublasPreference {
|
||||
CublasPreference(Device& device) {
|
||||
// The recommended cublas workspace size is 4 MiB for pre-Hopper and 32 MiB
|
||||
// for Hopper+:
|
||||
// https://docs.nvidia.com/cuda/cublas/#cublassetworkspace
|
||||
uint64_t MiB = 1024 * 1024;
|
||||
uint64_t workspace_size =
|
||||
device.compute_capability_major() >= 9 ? 32 * MiB : 4 * MiB;
|
||||
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceCreate(&pref_));
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceSetAttribute(
|
||||
pref_,
|
||||
CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
|
||||
&workspace_size,
|
||||
sizeof(uint64_t)));
|
||||
}
|
||||
|
||||
~CublasPreference() {
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceDestroy(pref_));
|
||||
}
|
||||
|
||||
cublasLtMatmulPreference_t pref_{nullptr};
|
||||
};
|
||||
|
||||
cublasLtMatmulPreference_t cublas_preference(Device& device) {
|
||||
static CublasPreference pref(device);
|
||||
return pref.pref_;
|
||||
}
|
||||
|
||||
class MatMul {
|
||||
public:
|
||||
MatMul(
|
||||
@@ -43,7 +72,7 @@ class MatMul {
|
||||
int32_t batch_count,
|
||||
int64_t a_batch_stride,
|
||||
int64_t b_batch_stride)
|
||||
: handle_(device.lt_handle()) {
|
||||
: handle_(device.lt_handle()), pref_(cublas_preference(device)) {
|
||||
heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED;
|
||||
|
||||
auto scale_type = dtype_to_cuda_type(dtype);
|
||||
@@ -77,20 +106,6 @@ class MatMul {
|
||||
type, b_rows, b_cols, b_transposed, ldb, batch_count, b_batch_stride);
|
||||
out_desc_ = create_matrix_layout(
|
||||
type, a_rows, b_cols, false, b_cols, batch_count, a_rows * b_cols);
|
||||
|
||||
// The recommended cublas workspace size is 4 MiB for pre-Hopper and 32 MiB
|
||||
// for Hopper+:
|
||||
// https://docs.nvidia.com/cuda/cublas/#cublassetworkspace
|
||||
uint64_t MiB = 1024 * 1024;
|
||||
uint64_t workspace_size =
|
||||
device.compute_capability_major() >= 9 ? 32 * MiB : 4 * MiB;
|
||||
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceCreate(&pref_));
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceSetAttribute(
|
||||
pref_,
|
||||
CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
|
||||
&workspace_size,
|
||||
sizeof(uint64_t)));
|
||||
}
|
||||
|
||||
MatMul(
|
||||
@@ -130,11 +145,11 @@ class MatMul {
|
||||
}
|
||||
|
||||
~MatMul() {
|
||||
cublasLtMatrixLayoutDestroy(a_desc_);
|
||||
cublasLtMatrixLayoutDestroy(b_desc_);
|
||||
cublasLtMatrixLayoutDestroy(c_desc_);
|
||||
cublasLtMatrixLayoutDestroy(out_desc_);
|
||||
cublasLtMatmulDescDestroy(matmul_desc_);
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(a_desc_));
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(b_desc_));
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(c_desc_));
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(out_desc_));
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatmulDescDestroy(matmul_desc_));
|
||||
}
|
||||
|
||||
void run(
|
||||
@@ -259,9 +274,9 @@ class MatMul {
|
||||
return desc;
|
||||
}
|
||||
|
||||
cublasLtMatmulPreference_t pref_{nullptr};
|
||||
cublasLtHandle_t handle_{nullptr};
|
||||
cublasLtMatmulDesc_t matmul_desc_{nullptr};
|
||||
cublasLtMatmulPreference_t pref_{nullptr};
|
||||
cublasLtMatrixLayout_t a_desc_{nullptr};
|
||||
cublasLtMatrixLayout_t b_desc_{nullptr};
|
||||
cublasLtMatrixLayout_t c_desc_{nullptr};
|
||||
|
||||
@@ -893,24 +893,23 @@ class Muon(Optimizer):
|
||||
"""Initialize optimizer state"""
|
||||
state["v"] = mx.zeros_like(parameter)
|
||||
|
||||
def _zeropower_via_newtonschulz5(self, G, steps: int):
|
||||
assert G.ndim >= 2
|
||||
def _zeropower_via_newtonschulz5(self, X, steps: int):
|
||||
assert (
|
||||
X.ndim == 2
|
||||
), f"Expected a 2D array for Newton-Schulz iteration, got shape {X.shape} instead."
|
||||
a, b, c = (3.4445, -4.7750, 2.0315)
|
||||
X = G.astype(mx.bfloat16)
|
||||
transpose_needed = G.shape[-2] > G.shape[-1]
|
||||
transpose_needed = X.shape[-2] > X.shape[-1]
|
||||
|
||||
if transpose_needed:
|
||||
X = X.T
|
||||
|
||||
# Ensure spectral norm is at most 1
|
||||
norm = mx.sqrt(mx.sum(X * X, axis=(-2, -1), keepdims=True) + 1e-7)
|
||||
norm = mx.sqrt(mx.sum(mx.square(X), keepdims=True) + 1e-7)
|
||||
X = X / norm
|
||||
|
||||
# Perform the NS iterations
|
||||
for _ in range(steps):
|
||||
A = X @ X.T
|
||||
B = b * A + c * (A @ A)
|
||||
X = a * X + B @ X
|
||||
B = mx.addmm(b * A, A, A, beta=1.0, alpha=c)
|
||||
X = mx.addmm(a * X, B, X, beta=1.0, alpha=1.0)
|
||||
|
||||
if transpose_needed:
|
||||
X = X.T
|
||||
@@ -919,56 +918,35 @@ class Muon(Optimizer):
|
||||
def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
|
||||
"""Performs the Muon parameter update"""
|
||||
|
||||
# Apply weight decay
|
||||
if self.weight_decay != 0:
|
||||
gradient = gradient + self.weight_decay * parameter
|
||||
|
||||
# Update momentum buffer
|
||||
v = self.momentum * state["v"]
|
||||
v = v + (1 - self.momentum) * gradient
|
||||
state["v"] = v
|
||||
|
||||
# Get effective gradient
|
||||
if self.nesterov:
|
||||
effective_grad = gradient * (1 - self.momentum) + v * self.momentum
|
||||
update = gradient * (1 - self.momentum) + v * self.momentum
|
||||
else:
|
||||
effective_grad = v
|
||||
update = v
|
||||
|
||||
# For tensors with fewer than 2 dimensions, skip Newton-Schulz
|
||||
if effective_grad.ndim < 2:
|
||||
orthogonalized_grad = effective_grad
|
||||
scale_factor = 1.0
|
||||
else:
|
||||
# Save original shape for 4D conv filters
|
||||
original_shape = effective_grad.shape
|
||||
reshape_needed = effective_grad.ndim > 2
|
||||
lr = self.learning_rate.astype(gradient.dtype)
|
||||
|
||||
if update.ndim >= 2:
|
||||
original_shape = update.shape
|
||||
reshape_needed = update.ndim > 2
|
||||
|
||||
if reshape_needed:
|
||||
effective_grad = mx.reshape(
|
||||
effective_grad, (effective_grad.shape[0], -1)
|
||||
)
|
||||
update = mx.reshape(update, (update.shape[0], -1))
|
||||
|
||||
# Apply Newton-Schulz orthogonalization
|
||||
orthogonalized_grad = self._zeropower_via_newtonschulz5(
|
||||
effective_grad, steps=self.ns_steps
|
||||
)
|
||||
update = self._zeropower_via_newtonschulz5(update, steps=self.ns_steps)
|
||||
|
||||
# Reshape back if needed
|
||||
if reshape_needed:
|
||||
orthogonalized_grad = mx.reshape(orthogonalized_grad, original_shape)
|
||||
update = mx.reshape(update, original_shape)
|
||||
|
||||
# Calculate scaling factor
|
||||
# scale_factor = max(1, parameter.shape[-2] / parameter.shape[-1]) ** 0.5
|
||||
scale_factor = (
|
||||
max(1, effective_grad.shape[-2] / effective_grad.shape[-1]) ** 0.5
|
||||
)
|
||||
lr *= max(1, update.shape[-2] / update.shape[-1]) ** 0.5
|
||||
|
||||
return (
|
||||
parameter
|
||||
- self.learning_rate.astype(gradient.dtype)
|
||||
* orthogonalized_grad
|
||||
* scale_factor
|
||||
)
|
||||
return parameter - lr * update
|
||||
|
||||
|
||||
def clip_grad_norm(grads, max_norm):
|
||||
|
||||
Reference in New Issue
Block a user