Compare commits

..

14 Commits

Author SHA1 Message Date
Awni Hannun
7f39e9c299 nits 2025-07-17 06:26:43 -07:00
Gökdeniz Gülmez
baad6e392b Merge branch 'ml-explore:main' into adding-Muon-optimizer 2025-07-17 13:07:54 +02:00
Gökdeniz Gülmez
784e0716fe Merge branch 'ml-explore:main' into adding-Muon-optimizer 2025-07-16 21:58:17 +02:00
Goekdeniz-Guelmez
df6d9e972f nits and adding it to test 2025-07-16 19:13:40 +02:00
Gökdeniz Gülmez
650c956fe6 Merge branch 'ml-explore:main' into adding-Muon-optimizer 2025-07-16 16:29:10 +02:00
Gökdeniz Gülmez
d3d575cce7 Merge branch 'ml-explore:main' into adding-Muon-optimizer 2025-04-21 20:27:33 +02:00
Gökdeniz Gülmez
8f2744dcf3 Merge branch 'ml-explore:main' into adding-Muon-optimizer 2025-03-21 08:50:43 +01:00
Gökdeniz Gülmez
b12be4b7e0 Merge branch 'ml-explore:main' into adding-Muon-optimizer 2025-03-12 16:52:21 +01:00
Gökdeniz Gülmez
ebfcb4a14f Merge branch 'ml-explore:main' into adding-Muon-optimizer 2025-03-10 17:10:50 +01:00
Gökdeniz Gülmez
79175a1f35 Merge branch 'ml-explore:main' into adding-Muon-optimizer 2025-03-07 11:41:19 +01:00
Gökdeniz Gülmez
59d4e4f61d Merge branch 'ml-explore:main' into adding-Muon-optimizer 2025-03-05 23:09:44 +01:00
Gökdeniz Gülmez
44f776921c Merge branch 'ml-explore:main' into adding-Muon-optimizer 2025-03-05 10:05:10 +01:00
Goekdeniz-Guelmez
871ee2b9b0 update ACKNOWLEDGMENTS.md 2025-02-28 23:24:39 +01:00
Goekdeniz-Guelmez
6c048ab4da initial commit with workong optmimizer 2025-02-28 23:16:51 +01:00
28 changed files with 173 additions and 267 deletions

View File

@@ -272,7 +272,6 @@ jobs:
name: Build Python package
command: |
source env/bin/activate
python setup.py clean --all
<< parameters.build_env >> MLX_BUILD_STAGE=1 python -m build -w
- when:
condition:
@@ -334,7 +333,6 @@ jobs:
<< parameters.build_env >> pip install ".[dev]" -v
pip install typing_extensions
python setup.py generate_stubs
python setup.py clean --all
MLX_BUILD_STAGE=1 << parameters.build_env >> python -m build -w
bash python/scripts/repair_linux.sh
- when:

View File

@@ -53,10 +53,9 @@ struct FusedKernelBuilder {
// Build function signature.
if (contiguous) {
os += "template <typename IdxT = uint32_t, int work_per_thread = 1>\n";
os += "template <typename IdxT = uint32_t>\n";
} else {
os +=
"template <int NDIM, typename IdxT = uint32_t, int work_per_thread = 1>\n";
os += "template <int NDIM, typename IdxT = uint32_t>\n";
}
os += fmt::format("__global__ void {}(\n", kernel_name + name);
for (size_t i = 0; i < params.size(); ++i) {
@@ -68,46 +67,12 @@ struct FusedKernelBuilder {
}
os += ") {\n";
// Index. For non contiguous kernels we create a separate index
// variable per variable otherwise everyone uses `index`.
// Index.
os +=
" IdxT index = cg::this_grid().thread_rank() * work_per_thread;\n"
" IdxT index = cg::this_grid().thread_rank();\n"
" if (index >= size) {\n"
" return;\n"
" }\n";
if (!contiguous) {
for (size_t i = 0; i < inputs.size(); ++i) {
const auto& x = inputs[i];
const std::string& xname = namer.get_name(x);
if (is_scalar(x) || is_constant(i)) {
continue;
}
os += " IdxT " + xname + "_idx = 0;\n";
}
os += " {\n";
os += " IdxT loc = index;\n";
os +=
" #pragma unroll\n"
" for (int i = NDIM - 1; i >= 0; i--) {\n";
for (size_t i = 0; i < inputs.size(); ++i) {
const auto& x = inputs[i];
const std::string& xname = namer.get_name(x);
if (is_scalar(x) || is_constant(i)) {
continue;
}
os += " " + xname + "_idx += (loc \% shape[i]) * IdxT(" + xname +
"_strides[i]);\n";
}
os +=
" loc /= shape[i];\n"
" }\n"
" }\n";
}
// Work loop
os +=
"\n"
" for (int i = 0; i < work_per_thread && index < size; i++) {\n";
// Read inputs.
for (size_t i = 0; i < inputs.size(); ++i) {
@@ -124,9 +89,12 @@ struct FusedKernelBuilder {
} else if (contiguous) {
value = fmt::format("{}[index]", xname);
} else {
value = fmt::format("{}[{}_idx]", xname, xname);
std::string index = fmt::format(
"elem_to_loc_nd<NDIM>(index, shape.data(), {}_strides.data())",
xname);
value = fmt::format("{}[{}]", xname, index);
}
os += fmt::format(" {} tmp_{} = {};\n", type, xname, value);
os += fmt::format(" {} tmp_{} = {};\n", type, xname, value);
}
// Write tape.
@@ -145,30 +113,14 @@ struct FusedKernelBuilder {
}
value += fmt::format("tmp_{})", namer.get_name(x.inputs().back()));
}
os += fmt::format(" {} tmp_{} = {};\n", type, xname, value);
os += fmt::format(" {} tmp_{} = {};\n", type, xname, value);
}
// Write output.
for (const auto& x : outputs) {
os += fmt::format(" {0}[index] = tmp_{0};\n", namer.get_name(x));
os += fmt::format(" {0}[index] = tmp_{0};\n", namer.get_name(x));
}
// End of work loop
os +=
"\n"
" index++;\n";
if (!contiguous) {
for (size_t i = 0; i < inputs.size(); ++i) {
const auto& x = inputs[i];
const std::string& xname = namer.get_name(x);
if (is_scalar(x) || is_constant(i)) {
continue;
}
os += " " + xname + "_idx += " + xname + "_strides[NDIM - 1];\n";
}
}
os += " }\n";
os += "}\n";
}
};
@@ -204,28 +156,15 @@ void Compiled::eval_gpu(
builder.build("_strided", false);
builder.os += "\n} // namespace mlx::core::cu\n";
// Build kernel names.
std::vector<std::string> kernel_names;
for (auto work_per_thread : std::array<int, 2>{1, 4}) {
std::vector<std::string> kernel_names = {
fmt::format("mlx::core::cu::{}_contiguous<uint32_t>", lib_name()),
fmt::format("mlx::core::cu::{}_contiguous<int64_t>", lib_name()),
};
for (int i = 1; i <= MAX_NDIM; ++i) {
kernel_names.push_back(fmt::format(
"mlx::core::cu::{}_contiguous<uint32_t, {}>",
lib_name(),
work_per_thread));
kernel_names.push_back(fmt::format(
"mlx::core::cu::{}_contiguous<int64_t, {}>",
lib_name(),
work_per_thread));
for (int i = 1; i <= MAX_NDIM; ++i) {
kernel_names.push_back(fmt::format(
"mlx::core::cu::{}_strided<{}, uint32_t, {}>",
lib_name(),
i,
work_per_thread));
kernel_names.push_back(fmt::format(
"mlx::core::cu::{}_strided<{}, int64_t, {}>",
lib_name(),
i,
work_per_thread));
}
"mlx::core::cu::{}_strided<{}, uint32_t>", lib_name(), i));
kernel_names.push_back(
fmt::format("mlx::core::cu::{}_strided<{}, int64_t>", lib_name(), i));
}
return std::make_pair(std::move(builder.os), std::move(kernel_names));
});
@@ -268,21 +207,13 @@ void Compiled::eval_gpu(
args.append<uint32_t>(outputs[0].data_size());
}
// Choose work per thread
int work_per_thread = 4;
if (!contiguous && shape.back() % work_per_thread != 0) {
work_per_thread = 1;
}
// Launch kernel.
const char* index_type = large ? "int64_t" : "uint32_t";
std::string kernel_name = fmt::format("mlx::core::cu::{}", lib_name());
if (contiguous) {
kernel_name +=
fmt::format("_contiguous<{}, {}>", index_type, work_per_thread);
kernel_name += fmt::format("_contiguous<{}>", index_type);
} else {
kernel_name += fmt::format(
"_strided<{}, {}, {}>", shape.size(), index_type, work_per_thread);
kernel_name += fmt::format("_strided<{}, {}>", shape.size(), index_type);
}
auto& encoder = cu::get_command_encoder(s);
for (const auto& in : inputs) {
@@ -293,8 +224,7 @@ void Compiled::eval_gpu(
}
auto kernel = mod.get_kernel(kernel_name);
auto [num_blocks, block_dims] =
get_launch_args(kernel, outputs[0], large, work_per_thread);
auto [num_blocks, block_dims] = get_launch_args(kernel, outputs[0], large);
encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
}

View File

@@ -66,6 +66,7 @@ CommandEncoder& Device::get_command_encoder(Stream s) {
}
CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) {
CHECK_CUDA_ERROR(cudaGraphCreate(&graph, 0));
CHECK_CUDA_ERROR(
cudaStreamBeginCapture(enc.stream(), cudaStreamCaptureModeGlobal));
}

View File

@@ -52,29 +52,13 @@ const std::string& cuda_home() {
}
// Return the location of CCCL headers shipped with the distribution.
const std::string& cccl_dir() {
static std::string dir = []() {
std::filesystem::path path;
#if defined(MLX_CCCL_DIR)
// First search the install dir if defined.
path = MLX_CCCL_DIR;
if (std::filesystem::exists(path)) {
return path.string();
}
#endif
// Then search dynamically from the dir of libmlx.so file.
path = current_binary_dir().parent_path() / "include" / "cccl";
if (std::filesystem::exists(path)) {
return path.string();
}
// Finally check the environment variable.
path = std::getenv("MLX_CCCL_DIR");
if (!path.empty() && std::filesystem::exists(path)) {
return path.string();
}
return std::string();
}();
return dir;
bool get_cccl_include(std::string* out) {
auto cccl_headers = current_binary_dir().parent_path() / "include" / "cccl";
if (!std::filesystem::exists(cccl_headers)) {
return false;
}
*out = fmt::format("--include-path={}", cccl_headers.string());
return true;
}
// Get the cache directory for storing compiled results.
@@ -137,8 +121,7 @@ void write_cached_ptx(
const std::filesystem::path& cache_dir,
const std::string& module_name,
const std::vector<char>& ptx,
const std::vector<std::pair<std::string, std::string>>& ptx_kernels,
const std::string& source_code) {
const std::vector<std::pair<std::string, std::string>>& ptx_kernels) {
if (cache_dir.empty()) {
return;
}
@@ -151,9 +134,6 @@ void write_cached_ptx(
for (const auto& [name, mangled] : ptx_kernels) {
txt_file << name << "\t" << mangled << std::endl;
}
std::ofstream source_file(cache_dir / (module_name + ".cu"));
source_file << source_code;
}
// Return if |device|'s version is not newer than |major|.|minor| version.
@@ -254,9 +234,8 @@ JitModule::JitModule(
device.compute_capability_major(),
device.compute_capability_minor());
args.push_back(compute.c_str());
std::string cccl_include = cccl_dir();
if (!cccl_include.empty()) {
cccl_include = fmt::format("--include-path={}", cccl_include);
std::string cccl_include;
if (get_cccl_include(&cccl_include)) {
args.push_back(cccl_include.c_str());
}
std::string cuda_include =
@@ -293,8 +272,7 @@ JitModule::JitModule(
} else {
CHECK_NVRTC_ERROR(nvrtcGetPTX(prog, ptx.data()));
}
write_cached_ptx(
ptx_cache_dir(), module_name, ptx, ptx_kernels, source_code);
write_cached_ptx(ptx_cache_dir(), module_name, ptx, ptx_kernels);
}
// Load module.

View File

@@ -237,7 +237,8 @@ void LayerNorm::eval_gpu(
}
return x;
} else {
array x_copy = contiguous_copy_gpu(x, s);
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
out.copy_shared_buffer(x_copy);
return x_copy;
}
@@ -294,7 +295,9 @@ void LayerNormVJP::eval_gpu(
return x;
}
copied = true;
return contiguous_copy_gpu(x, s);
array x_copy(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
return x_copy;
};
bool donate_x = inputs[0].is_donatable();
bool donate_g = inputs[3].is_donatable();

View File

@@ -108,7 +108,8 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
return x;
} else {
array x_copy = contiguous_copy_gpu(x, s);
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
encoder.add_temporary(x_copy);
return x_copy;
}

View File

@@ -27,35 +27,6 @@ void check_cublas_error(const char* name, cublasStatus_t err) {
}
}
struct CublasPreference {
CublasPreference(Device& device) {
// The recommended cublas workspace size is 4 MiB for pre-Hopper and 32 MiB
// for Hopper+:
// https://docs.nvidia.com/cuda/cublas/#cublassetworkspace
uint64_t MiB = 1024 * 1024;
uint64_t workspace_size =
device.compute_capability_major() >= 9 ? 32 * MiB : 4 * MiB;
CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceCreate(&pref_));
CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceSetAttribute(
pref_,
CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
&workspace_size,
sizeof(uint64_t)));
}
~CublasPreference() {
CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceDestroy(pref_));
}
cublasLtMatmulPreference_t pref_{nullptr};
};
cublasLtMatmulPreference_t cublas_preference(Device& device) {
static CublasPreference pref(device);
return pref.pref_;
}
class MatMul {
public:
MatMul(
@@ -72,7 +43,7 @@ class MatMul {
int32_t batch_count,
int64_t a_batch_stride,
int64_t b_batch_stride)
: handle_(device.lt_handle()), pref_(cublas_preference(device)) {
: handle_(device.lt_handle()) {
heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED;
auto scale_type = dtype_to_cuda_type(dtype);
@@ -106,6 +77,20 @@ class MatMul {
type, b_rows, b_cols, b_transposed, ldb, batch_count, b_batch_stride);
out_desc_ = create_matrix_layout(
type, a_rows, b_cols, false, b_cols, batch_count, a_rows * b_cols);
// The recommended cublas workspace size is 4 MiB for pre-Hopper and 32 MiB
// for Hopper+:
// https://docs.nvidia.com/cuda/cublas/#cublassetworkspace
uint64_t MiB = 1024 * 1024;
uint64_t workspace_size =
device.compute_capability_major() >= 9 ? 32 * MiB : 4 * MiB;
CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceCreate(&pref_));
CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceSetAttribute(
pref_,
CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
&workspace_size,
sizeof(uint64_t)));
}
MatMul(
@@ -119,6 +104,7 @@ class MatMul {
uint64_t b_rows,
uint64_t b_cols,
int64_t ldb,
bool c_transposed,
int64_t ldc,
int32_t batch_count,
int64_t a_batch_stride,
@@ -140,15 +126,15 @@ class MatMul {
b_batch_stride) {
auto type = dtype_to_cuda_type(dtype);
c_desc_ = create_matrix_layout(
type, a_rows, b_cols, false, ldc, batch_count, c_batch_stride);
type, a_rows, b_cols, c_transposed, ldc, batch_count, c_batch_stride);
}
~MatMul() {
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(a_desc_));
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(b_desc_));
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(c_desc_));
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(out_desc_));
CHECK_CUBLAS_ERROR(cublasLtMatmulDescDestroy(matmul_desc_));
cublasLtMatrixLayoutDestroy(a_desc_);
cublasLtMatrixLayoutDestroy(b_desc_);
cublasLtMatrixLayoutDestroy(c_desc_);
cublasLtMatrixLayoutDestroy(out_desc_);
cublasLtMatmulDescDestroy(matmul_desc_);
}
void run(
@@ -273,9 +259,9 @@ class MatMul {
return desc;
}
cublasLtMatmulPreference_t pref_{nullptr};
cublasLtHandle_t handle_{nullptr};
cublasLtMatmulDesc_t matmul_desc_{nullptr};
cublasLtMatmulPreference_t pref_{nullptr};
cublasLtMatrixLayout_t a_desc_{nullptr};
cublasLtMatrixLayout_t b_desc_{nullptr};
cublasLtMatrixLayout_t c_desc_{nullptr};
@@ -296,7 +282,8 @@ check_transpose(cu::CommandEncoder& enc, const Stream& s, const array& arr) {
} else if (stx == 1 && sty == arr.shape(-2)) {
return std::make_tuple(true, sty, arr);
} else {
array arr_copy = contiguous_copy_gpu(arr, s);
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_gpu(arr, arr_copy, CopyType::General, s);
enc.add_temporary(arr_copy);
return std::make_tuple(false, arr.shape(-1), arr_copy);
}
@@ -402,7 +389,9 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 3);
auto& a_pre = inputs[0];
auto& b_pre = inputs[1];
auto c = inputs[2];
auto& c_pre = inputs[2];
out.set_data(allocator::malloc(out.nbytes()));
/////////////////////////////////////////////////////////////////////////////
// Init checks and prep
@@ -415,24 +404,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
// the arrays
auto [a_transposed, lda, a] = check_transpose(encoder, s, a_pre);
auto [b_transposed, ldb, b] = check_transpose(encoder, s, b_pre);
int64_t ldc;
{
auto stx = c.strides()[c.ndim() - 2];
auto sty = c.strides()[c.ndim() - 1];
if (sty == 1 && stx == c.shape(-1)) {
ldc = stx;
out.set_data(allocator::malloc(out.nbytes()));
} else if (sty == 1 && stx == 0) {
ldc = 0;
out.set_data(allocator::malloc(out.nbytes()));
} else {
// Copy C into out and set C to out
ldc = c.shape(-1);
copy_gpu(c, out, CopyType::General, s);
c = out;
}
}
auto [c_transposed, ldc, c] = check_transpose(encoder, s, c_pre);
/////////////////////////////////////////////////////////////////////////////
// Check and collapse batch dimensions
@@ -470,6 +442,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
K,
N,
ldb,
c_transposed,
ldc,
batch_shape.back(),
a_batch_strides.back(),

View File

@@ -247,7 +247,8 @@ inline array ensure_row_contiguous(
cu::CommandEncoder& enc,
const Stream& s) {
if (!x.flags().row_contiguous) {
array x_copy = contiguous_copy_gpu(x, s);
array x_copy(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
enc.add_temporary(x_copy);
return x_copy;
} else {

View File

@@ -47,7 +47,8 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
}
}
if (plan.type == GeneralReduce || broadcasted || !in.flags().contiguous) {
array in_copy = contiguous_copy_gpu(in, s);
array in_copy(in.shape(), in.dtype(), nullptr, {});
copy_gpu(in, in_copy, CopyType::General, s);
encoder.add_temporary(in_copy);
in = in_copy;
plan = get_reduction_plan(in, axes_);

View File

@@ -206,7 +206,8 @@ void RMSNorm::eval_gpu(
}
return x;
} else {
array x_copy = contiguous_copy_gpu(x, s);
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
out.copy_shared_buffer(x_copy);
return x_copy;
}
@@ -258,7 +259,9 @@ void RMSNormVJP::eval_gpu(
return x;
}
copied = true;
return contiguous_copy_gpu(x, s);
array x_copy(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
return x_copy;
};
bool donate_x = inputs[0].is_donatable();
bool donate_g = inputs[2].is_donatable();

View File

@@ -379,7 +379,9 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
in.flags());
}
} else {
in = contiguous_copy_gpu(in, s);
array arr_copy(in.shape(), in.dtype(), nullptr, {});
copy_gpu(in, arr_copy, CopyType::General, s);
in = std::move(arr_copy);
out.copy_shared_buffer(in);
}

View File

@@ -125,7 +125,8 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
}
return x;
} else {
array x_copy = contiguous_copy_gpu(x, s);
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
out.copy_shared_buffer(x_copy);
return x_copy;
}

View File

@@ -72,7 +72,8 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
bool is_segmented_sort = in.flags().contiguous && in.strides()[axis] == 1;
if (!is_segmented_sort) {
array trans = swapaxes_in_eval(in, axis, last_dim);
in = contiguous_copy_gpu(trans, s);
in = array(trans.shape(), trans.dtype(), nullptr, {});
copy_gpu(trans, in, CopyType::General, s);
encoder.add_temporary(in);
out = array(allocator::malloc(out.nbytes()), in.shape(), out.dtype());
encoder.add_temporary(out);

View File

@@ -46,10 +46,4 @@ void copy_gpu_inplace(
in, out, in.shape(), i_strides, out.strides(), i_offset, 0, ctype, s);
}
array contiguous_copy_gpu(const array& arr, const Stream& s) {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_gpu(arr, arr_copy, CopyType::General, s);
return arr_copy;
}
} // namespace mlx::core

View File

@@ -43,7 +43,4 @@ void copy_gpu_inplace(
// Fill the output with the scalar val
void fill_gpu(const array& val, array& out, const Stream& s);
// Return a contiguous array with same shape that copies the data of |arr|.
array contiguous_copy_gpu(const array& arr, const Stream& s);
} // namespace mlx::core

View File

@@ -149,7 +149,8 @@ void explicit_gemm_conv_group_ND_gpu(
wt, {wt.strides(0), 1, C_per_group}, wt.flags(), wt.size());
// Materialize
array wt_transpose = contiguous_copy_gpu(wt_view, s);
auto wt_transpose = array(wt_view.shape(), wt_view.dtype(), nullptr, {});
copy_gpu(wt_view, wt_transpose, CopyType::General, s);
// Perform gemm
std::vector<array> copies = {in_unfolded, wt_transpose};
@@ -960,12 +961,16 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
auto in = inputs[0];
auto wt = inputs[1];
if (!in.flags().row_contiguous) {
in = contiguous_copy_gpu(in, s);
copies.push_back(in);
array arr_copy(in.shape(), in.dtype(), nullptr, {});
copy_gpu(in, arr_copy, CopyType::General, s);
copies.push_back(arr_copy);
in = arr_copy;
}
if (!wt.flags().row_contiguous) {
wt = contiguous_copy_gpu(wt, s);
copies.push_back(wt);
array arr_copy(wt.shape(), wt.dtype(), nullptr, {});
copy_gpu(wt, arr_copy, CopyType::General, s);
copies.push_back(arr_copy);
wt = arr_copy;
}
// 3D conv

View File

@@ -25,7 +25,8 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
return x;
} else {
array x_copy = contiguous_copy_gpu(x, s);
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
d.add_temporary(x_copy, s.index);
return x_copy;
}

View File

@@ -33,7 +33,8 @@ std::tuple<bool, int64_t, array> check_transpose(
} else if (stx == 1 && (!is_vector || sty == arr.shape(-2))) {
return std::make_tuple(true, sty, arr);
} else {
array arr_copy = contiguous_copy_gpu(arr, s);
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_gpu(arr, arr_copy, CopyType::General, s);
copies.push_back(arr_copy);
return std::make_tuple(false, arr.shape(-1), arr_copy);
}
@@ -42,7 +43,8 @@ std::tuple<bool, int64_t, array> check_transpose(
inline array
ensure_row_contiguous(const array& x, metal::Device& d, const Stream& s) {
if (!x.flags().row_contiguous) {
array x_copy = contiguous_copy_gpu(x, s);
array x_copy(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
d.add_temporary(x_copy, s.index);
return x_copy;
} else {
@@ -73,7 +75,8 @@ ensure_batch_contiguous(const array& x, metal::Device& d, const Stream& s) {
}
}
array x_copy = contiguous_copy_gpu(x, s);
array x_copy(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
d.add_temporary(x_copy, s.index);
return std::make_tuple(false, x_copy.strides()[x_copy.ndim() - 2], x_copy);
}
@@ -1891,7 +1894,8 @@ void segmented_mm(
return std::make_tuple(false, x);
}
array x_copy = contiguous_copy_gpu(x, s);
array x_copy(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
d.add_temporary(x_copy, s.index);
return std::make_tuple(true, x_copy);
};

View File

@@ -40,7 +40,8 @@ void RMSNorm::eval_gpu(
}
return x;
} else {
array x_copy = contiguous_copy_gpu(x, s);
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
out.copy_shared_buffer(x_copy);
return x_copy;
}
@@ -106,7 +107,9 @@ void RMSNormVJP::eval_gpu(
if (x.flags().row_contiguous) {
return {x, false};
}
array x_copy = contiguous_copy_gpu(x, s);
array x_copy(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
return {x_copy, true};
};
bool donate_x = inputs[0].is_donatable();
@@ -238,7 +241,8 @@ void LayerNorm::eval_gpu(
}
return x;
} else {
array x_copy = contiguous_copy_gpu(x, s);
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
out.copy_shared_buffer(x_copy);
return x_copy;
}
@@ -315,7 +319,8 @@ void LayerNormVJP::eval_gpu(
if (x.flags().row_contiguous) {
return {x, false};
}
array x_copy = contiguous_copy_gpu(x, s);
array x_copy(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
return {x_copy, true};
};
bool donate_x = inputs[0].is_donatable();

View File

@@ -20,7 +20,8 @@ namespace {
inline array
ensure_row_contiguous(const array& x, metal::Device& d, const Stream& s) {
if (!x.flags().row_contiguous) {
array x_copy = contiguous_copy_gpu(x, s);
array x_copy(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
d.add_temporary(x_copy, s.index);
return x_copy;
} else {
@@ -37,7 +38,8 @@ inline array ensure_row_contiguous_matrix(
if (stride_0 == x.shape(-1) && stride_1 == 1) {
return x;
} else {
array x_copy = contiguous_copy_gpu(x, s);
array x_copy(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
d.add_temporary(x_copy, s.index);
return x_copy;
}

View File

@@ -989,7 +989,8 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
// input for the axes with stride smaller than the minimum reduction
// stride.
if (plan.type == GeneralReduce) {
array in_copy = contiguous_copy_gpu(in, s);
array in_copy(in.shape(), in.dtype(), nullptr, {});
copy_gpu(in, in_copy, CopyType::General, s);
d.add_temporary(in_copy, s.index);
in = in_copy;
plan = get_reduction_plan(in, axes_);

View File

@@ -398,7 +398,8 @@ void ScaledDotProductAttention::eval_gpu(
auto copy_unless = [&copies, &s](
auto predicate, const array& arr) -> const array& {
if (!predicate(arr)) {
array arr_copy = contiguous_copy_gpu(arr, s);
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_gpu(arr, arr_copy, CopyType::General, s);
copies.push_back(std::move(arr_copy));
return copies.back();
} else {

View File

@@ -30,7 +30,9 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
in.flags());
}
} else {
in = contiguous_copy_gpu(in, s);
array arr_copy(in.shape(), in.dtype(), nullptr, {});
copy_gpu(in, arr_copy, CopyType::General, s);
in = std::move(arr_copy);
out.copy_shared_buffer(in);
}

View File

@@ -35,7 +35,8 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
}
return x;
} else {
array x_copy = contiguous_copy_gpu(x, s);
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
out.copy_shared_buffer(x_copy);
return x_copy;
}

View File

@@ -4,7 +4,7 @@
#define MLX_VERSION_MAJOR 0
#define MLX_VERSION_MINOR 26
#define MLX_VERSION_PATCH 5
#define MLX_VERSION_PATCH 3
#define MLX_VERSION_NUMERIC \
(100000 * MLX_VERSION_MAJOR + 1000 * MLX_VERSION_MINOR + MLX_VERSION_PATCH)

View File

@@ -893,22 +893,24 @@ class Muon(Optimizer):
"""Initialize optimizer state"""
state["v"] = mx.zeros_like(parameter)
def _zeropower_via_newtonschulz5(self, X, steps: int):
assert (
X.ndim == 2
), f"Expected a 2D array for Newton-Schulz iteration, got shape {X.shape} instead."
def _zeropower_via_newtonschulz5(self, G, steps: int):
assert G.ndim >= 2
a, b, c = (3.4445, -4.7750, 2.0315)
transpose_needed = X.shape[-2] > X.shape[-1]
X = G.astype(mx.bfloat16)
transpose_needed = G.shape[-2] > G.shape[-1]
if transpose_needed:
X = X.T
X = X / (mx.linalg.norm(X, keepdims=True) + 1e-7)
# Ensure spectral norm is at most 1
norm = mx.sqrt(mx.sum(X * X, axis=(-2, -1), keepdims=True) + 1e-7)
X = X / norm
# Perform the NS iterations
for _ in range(steps):
A = X @ X.T
B = mx.addmm(b * A, A, A, beta=1.0, alpha=c)
X = mx.addmm(a * X, B, X, beta=1.0, alpha=1.0)
B = b * A + c * (A @ A)
X = a * X + B @ X
if transpose_needed:
X = X.T
@@ -917,35 +919,56 @@ class Muon(Optimizer):
def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
"""Performs the Muon parameter update"""
# Apply weight decay
if self.weight_decay != 0:
gradient = gradient + self.weight_decay * parameter
# Update momentum buffer
v = self.momentum * state["v"]
v = v + (1 - self.momentum) * gradient
state["v"] = v
# Get effective gradient
if self.nesterov:
update = gradient * (1 - self.momentum) + v * self.momentum
effective_grad = gradient * (1 - self.momentum) + v * self.momentum
else:
update = v
effective_grad = v
lr = self.learning_rate.astype(gradient.dtype)
if update.ndim >= 2:
original_shape = update.shape
reshape_needed = update.ndim > 2
# For tensors with fewer than 2 dimensions, skip Newton-Schulz
if effective_grad.ndim < 2:
orthogonalized_grad = effective_grad
scale_factor = 1.0
else:
# Save original shape for 4D conv filters
original_shape = effective_grad.shape
reshape_needed = effective_grad.ndim > 2
if reshape_needed:
update = mx.reshape(update, (update.shape[0], -1))
effective_grad = mx.reshape(
effective_grad, (effective_grad.shape[0], -1)
)
update = self._zeropower_via_newtonschulz5(update, steps=self.ns_steps)
# Apply Newton-Schulz orthogonalization
orthogonalized_grad = self._zeropower_via_newtonschulz5(
effective_grad, steps=self.ns_steps
)
# Reshape back if needed
if reshape_needed:
update = mx.reshape(update, original_shape)
orthogonalized_grad = mx.reshape(orthogonalized_grad, original_shape)
lr *= max(1, update.shape[-2] / update.shape[-1]) ** 0.5
# Calculate scaling factor
# scale_factor = max(1, parameter.shape[-2] / parameter.shape[-1]) ** 0.5
scale_factor = (
max(1, effective_grad.shape[-2] / effective_grad.shape[-1]) ** 0.5
)
return parameter - lr * update
return (
parameter
- self.learning_rate.astype(gradient.dtype)
* orthogonalized_grad
* scale_factor
)
def clip_grad_norm(grads, max_norm):

View File

@@ -691,21 +691,6 @@ class TestBlas(mlx_tests.MLXTestCase):
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
# Transposed c
a = mx.ones((10, 5)).T
b = mx.ones((5, 5))
out = mx.addmm(a, b, a, beta=1.5, alpha=0.5)
expected = 1.5 * a + 0.5 * (b @ a)
self.assertTrue(mx.allclose(expected, out))
# Broadcast c
a = mx.ones((5, 5))
b = mx.ones((5, 5))
c = mx.ones((1, 5))
out = mx.addmm(c, a, b, beta=1.5, alpha=0.5)
expected = 1.5 * c + 0.5 * (a @ b)
self.assertTrue(mx.allclose(expected, out))
def test_addmm_grad(self):
def make_ref_addmm(alpha, beta):
return lambda c, a, b: alpha * (a @ b) + beta * c

View File

@@ -39,14 +39,6 @@ target_sources(
linalg_tests.cpp
${METAL_TEST_SOURCES})
if(MLX_BUILD_CUDA)
# Find the CCCL headers in install dir.
target_compile_definitions(
mlx
PRIVATE
MLX_CCCL_DIR="${CMAKE_INSTALL_PREFIX}/${CMAKE_INSTALL_INCLUDEDIR}/cccl")
endif()
target_link_libraries(tests PRIVATE mlx doctest)
doctest_discover_tests(tests)
add_test(NAME tests COMMAND tests)