mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Compare commits
5 Commits
0a8bb904d7
...
v0.26.5
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
84b4d96efa | ||
|
|
aec67f2fa6 | ||
|
|
deee214a95 | ||
|
|
45adec102c | ||
|
|
31fc530c76 |
@@ -272,6 +272,7 @@ jobs:
|
||||
name: Build Python package
|
||||
command: |
|
||||
source env/bin/activate
|
||||
python setup.py clean --all
|
||||
<< parameters.build_env >> MLX_BUILD_STAGE=1 python -m build -w
|
||||
- when:
|
||||
condition:
|
||||
@@ -333,6 +334,7 @@ jobs:
|
||||
<< parameters.build_env >> pip install ".[dev]" -v
|
||||
pip install typing_extensions
|
||||
python setup.py generate_stubs
|
||||
python setup.py clean --all
|
||||
MLX_BUILD_STAGE=1 << parameters.build_env >> python -m build -w
|
||||
bash python/scripts/repair_linux.sh
|
||||
- when:
|
||||
|
||||
@@ -19,6 +19,7 @@ MLX was developed with contributions from the following individuals:
|
||||
- Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.
|
||||
- Paul Paczuski: Improved stability of BCE loss calculation
|
||||
- Max-Heinrich Laves: Added `conv_transpose1d`, `conv_transpose2d`, and `conv_transpose3d` ops.
|
||||
- Gökdeniz Gülmez: Added the `Muon (MomentUm Orthogonalized by Newton-schulz)` optimizer.
|
||||
|
||||
<a href="https://github.com/ml-explore/mlx/graphs/contributors">
|
||||
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
|
||||
|
||||
@@ -19,3 +19,4 @@ Common Optimizers
|
||||
Adamax
|
||||
Lion
|
||||
MultiOptimizer
|
||||
Muon
|
||||
|
||||
@@ -52,13 +52,29 @@ const std::string& cuda_home() {
|
||||
}
|
||||
|
||||
// Return the location of CCCL headers shipped with the distribution.
|
||||
bool get_cccl_include(std::string* out) {
|
||||
auto cccl_headers = current_binary_dir().parent_path() / "include" / "cccl";
|
||||
if (!std::filesystem::exists(cccl_headers)) {
|
||||
return false;
|
||||
}
|
||||
*out = fmt::format("--include-path={}", cccl_headers.string());
|
||||
return true;
|
||||
const std::string& cccl_dir() {
|
||||
static std::string dir = []() {
|
||||
std::filesystem::path path;
|
||||
#if defined(MLX_CCCL_DIR)
|
||||
// First search the install dir if defined.
|
||||
path = MLX_CCCL_DIR;
|
||||
if (std::filesystem::exists(path)) {
|
||||
return path.string();
|
||||
}
|
||||
#endif
|
||||
// Then search dynamically from the dir of libmlx.so file.
|
||||
path = current_binary_dir().parent_path() / "include" / "cccl";
|
||||
if (std::filesystem::exists(path)) {
|
||||
return path.string();
|
||||
}
|
||||
// Finally check the environment variable.
|
||||
path = std::getenv("MLX_CCCL_DIR");
|
||||
if (!path.empty() && std::filesystem::exists(path)) {
|
||||
return path.string();
|
||||
}
|
||||
return std::string();
|
||||
}();
|
||||
return dir;
|
||||
}
|
||||
|
||||
// Get the cache directory for storing compiled results.
|
||||
@@ -238,8 +254,9 @@ JitModule::JitModule(
|
||||
device.compute_capability_major(),
|
||||
device.compute_capability_minor());
|
||||
args.push_back(compute.c_str());
|
||||
std::string cccl_include;
|
||||
if (get_cccl_include(&cccl_include)) {
|
||||
std::string cccl_include = cccl_dir();
|
||||
if (!cccl_include.empty()) {
|
||||
cccl_include = fmt::format("--include-path={}", cccl_include);
|
||||
args.push_back(cccl_include.c_str());
|
||||
}
|
||||
std::string cuda_include =
|
||||
|
||||
@@ -237,8 +237,7 @@ void LayerNorm::eval_gpu(
|
||||
}
|
||||
return x;
|
||||
} else {
|
||||
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
|
||||
copy_gpu(x, x_copy, CopyType::General, s);
|
||||
array x_copy = contiguous_copy_gpu(x, s);
|
||||
out.copy_shared_buffer(x_copy);
|
||||
return x_copy;
|
||||
}
|
||||
@@ -295,9 +294,7 @@ void LayerNormVJP::eval_gpu(
|
||||
return x;
|
||||
}
|
||||
copied = true;
|
||||
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
||||
copy_gpu(x, x_copy, CopyType::General, s);
|
||||
return x_copy;
|
||||
return contiguous_copy_gpu(x, s);
|
||||
};
|
||||
bool donate_x = inputs[0].is_donatable();
|
||||
bool donate_g = inputs[3].is_donatable();
|
||||
|
||||
@@ -108,8 +108,7 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
|
||||
return x;
|
||||
} else {
|
||||
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
|
||||
copy_gpu(x, x_copy, CopyType::General, s);
|
||||
array x_copy = contiguous_copy_gpu(x, s);
|
||||
encoder.add_temporary(x_copy);
|
||||
return x_copy;
|
||||
}
|
||||
|
||||
@@ -119,7 +119,6 @@ class MatMul {
|
||||
uint64_t b_rows,
|
||||
uint64_t b_cols,
|
||||
int64_t ldb,
|
||||
bool c_transposed,
|
||||
int64_t ldc,
|
||||
int32_t batch_count,
|
||||
int64_t a_batch_stride,
|
||||
@@ -141,7 +140,7 @@ class MatMul {
|
||||
b_batch_stride) {
|
||||
auto type = dtype_to_cuda_type(dtype);
|
||||
c_desc_ = create_matrix_layout(
|
||||
type, a_rows, b_cols, c_transposed, ldc, batch_count, c_batch_stride);
|
||||
type, a_rows, b_cols, false, ldc, batch_count, c_batch_stride);
|
||||
}
|
||||
|
||||
~MatMul() {
|
||||
@@ -297,8 +296,7 @@ check_transpose(cu::CommandEncoder& enc, const Stream& s, const array& arr) {
|
||||
} else if (stx == 1 && sty == arr.shape(-2)) {
|
||||
return std::make_tuple(true, sty, arr);
|
||||
} else {
|
||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||
copy_gpu(arr, arr_copy, CopyType::General, s);
|
||||
array arr_copy = contiguous_copy_gpu(arr, s);
|
||||
enc.add_temporary(arr_copy);
|
||||
return std::make_tuple(false, arr.shape(-1), arr_copy);
|
||||
}
|
||||
@@ -404,9 +402,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 3);
|
||||
auto& a_pre = inputs[0];
|
||||
auto& b_pre = inputs[1];
|
||||
auto& c_pre = inputs[2];
|
||||
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
auto c = inputs[2];
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Init checks and prep
|
||||
@@ -419,7 +415,24 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
// the arrays
|
||||
auto [a_transposed, lda, a] = check_transpose(encoder, s, a_pre);
|
||||
auto [b_transposed, ldb, b] = check_transpose(encoder, s, b_pre);
|
||||
auto [c_transposed, ldc, c] = check_transpose(encoder, s, c_pre);
|
||||
|
||||
int64_t ldc;
|
||||
{
|
||||
auto stx = c.strides()[c.ndim() - 2];
|
||||
auto sty = c.strides()[c.ndim() - 1];
|
||||
if (sty == 1 && stx == c.shape(-1)) {
|
||||
ldc = stx;
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
} else if (sty == 1 && stx == 0) {
|
||||
ldc = 0;
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
} else {
|
||||
// Copy C into out and set C to out
|
||||
ldc = c.shape(-1);
|
||||
copy_gpu(c, out, CopyType::General, s);
|
||||
c = out;
|
||||
}
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Check and collapse batch dimensions
|
||||
@@ -457,7 +470,6 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
K,
|
||||
N,
|
||||
ldb,
|
||||
c_transposed,
|
||||
ldc,
|
||||
batch_shape.back(),
|
||||
a_batch_strides.back(),
|
||||
|
||||
@@ -247,8 +247,7 @@ inline array ensure_row_contiguous(
|
||||
cu::CommandEncoder& enc,
|
||||
const Stream& s) {
|
||||
if (!x.flags().row_contiguous) {
|
||||
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
||||
copy_gpu(x, x_copy, CopyType::General, s);
|
||||
array x_copy = contiguous_copy_gpu(x, s);
|
||||
enc.add_temporary(x_copy);
|
||||
return x_copy;
|
||||
} else {
|
||||
|
||||
@@ -47,8 +47,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
if (plan.type == GeneralReduce || broadcasted || !in.flags().contiguous) {
|
||||
array in_copy(in.shape(), in.dtype(), nullptr, {});
|
||||
copy_gpu(in, in_copy, CopyType::General, s);
|
||||
array in_copy = contiguous_copy_gpu(in, s);
|
||||
encoder.add_temporary(in_copy);
|
||||
in = in_copy;
|
||||
plan = get_reduction_plan(in, axes_);
|
||||
|
||||
@@ -206,8 +206,7 @@ void RMSNorm::eval_gpu(
|
||||
}
|
||||
return x;
|
||||
} else {
|
||||
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
|
||||
copy_gpu(x, x_copy, CopyType::General, s);
|
||||
array x_copy = contiguous_copy_gpu(x, s);
|
||||
out.copy_shared_buffer(x_copy);
|
||||
return x_copy;
|
||||
}
|
||||
@@ -259,9 +258,7 @@ void RMSNormVJP::eval_gpu(
|
||||
return x;
|
||||
}
|
||||
copied = true;
|
||||
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
||||
copy_gpu(x, x_copy, CopyType::General, s);
|
||||
return x_copy;
|
||||
return contiguous_copy_gpu(x, s);
|
||||
};
|
||||
bool donate_x = inputs[0].is_donatable();
|
||||
bool donate_g = inputs[2].is_donatable();
|
||||
|
||||
@@ -379,9 +379,7 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
in.flags());
|
||||
}
|
||||
} else {
|
||||
array arr_copy(in.shape(), in.dtype(), nullptr, {});
|
||||
copy_gpu(in, arr_copy, CopyType::General, s);
|
||||
in = std::move(arr_copy);
|
||||
in = contiguous_copy_gpu(in, s);
|
||||
out.copy_shared_buffer(in);
|
||||
}
|
||||
|
||||
|
||||
@@ -125,8 +125,7 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
return x;
|
||||
} else {
|
||||
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
|
||||
copy_gpu(x, x_copy, CopyType::General, s);
|
||||
array x_copy = contiguous_copy_gpu(x, s);
|
||||
out.copy_shared_buffer(x_copy);
|
||||
return x_copy;
|
||||
}
|
||||
|
||||
@@ -72,8 +72,7 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
|
||||
bool is_segmented_sort = in.flags().contiguous && in.strides()[axis] == 1;
|
||||
if (!is_segmented_sort) {
|
||||
array trans = swapaxes_in_eval(in, axis, last_dim);
|
||||
in = array(trans.shape(), trans.dtype(), nullptr, {});
|
||||
copy_gpu(trans, in, CopyType::General, s);
|
||||
in = contiguous_copy_gpu(trans, s);
|
||||
encoder.add_temporary(in);
|
||||
out = array(allocator::malloc(out.nbytes()), in.shape(), out.dtype());
|
||||
encoder.add_temporary(out);
|
||||
|
||||
@@ -46,4 +46,10 @@ void copy_gpu_inplace(
|
||||
in, out, in.shape(), i_strides, out.strides(), i_offset, 0, ctype, s);
|
||||
}
|
||||
|
||||
array contiguous_copy_gpu(const array& arr, const Stream& s) {
|
||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||
copy_gpu(arr, arr_copy, CopyType::General, s);
|
||||
return arr_copy;
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
@@ -43,4 +43,7 @@ void copy_gpu_inplace(
|
||||
// Fill the output with the scalar val
|
||||
void fill_gpu(const array& val, array& out, const Stream& s);
|
||||
|
||||
// Return a contiguous array with same shape that copies the data of |arr|.
|
||||
array contiguous_copy_gpu(const array& arr, const Stream& s);
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
@@ -149,8 +149,7 @@ void explicit_gemm_conv_group_ND_gpu(
|
||||
wt, {wt.strides(0), 1, C_per_group}, wt.flags(), wt.size());
|
||||
|
||||
// Materialize
|
||||
auto wt_transpose = array(wt_view.shape(), wt_view.dtype(), nullptr, {});
|
||||
copy_gpu(wt_view, wt_transpose, CopyType::General, s);
|
||||
array wt_transpose = contiguous_copy_gpu(wt_view, s);
|
||||
|
||||
// Perform gemm
|
||||
std::vector<array> copies = {in_unfolded, wt_transpose};
|
||||
@@ -961,16 +960,12 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto in = inputs[0];
|
||||
auto wt = inputs[1];
|
||||
if (!in.flags().row_contiguous) {
|
||||
array arr_copy(in.shape(), in.dtype(), nullptr, {});
|
||||
copy_gpu(in, arr_copy, CopyType::General, s);
|
||||
copies.push_back(arr_copy);
|
||||
in = arr_copy;
|
||||
in = contiguous_copy_gpu(in, s);
|
||||
copies.push_back(in);
|
||||
}
|
||||
if (!wt.flags().row_contiguous) {
|
||||
array arr_copy(wt.shape(), wt.dtype(), nullptr, {});
|
||||
copy_gpu(wt, arr_copy, CopyType::General, s);
|
||||
copies.push_back(arr_copy);
|
||||
wt = arr_copy;
|
||||
wt = contiguous_copy_gpu(wt, s);
|
||||
copies.push_back(wt);
|
||||
}
|
||||
|
||||
// 3D conv
|
||||
|
||||
@@ -25,8 +25,7 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
|
||||
return x;
|
||||
} else {
|
||||
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
|
||||
copy_gpu(x, x_copy, CopyType::General, s);
|
||||
array x_copy = contiguous_copy_gpu(x, s);
|
||||
d.add_temporary(x_copy, s.index);
|
||||
return x_copy;
|
||||
}
|
||||
|
||||
@@ -33,8 +33,7 @@ std::tuple<bool, int64_t, array> check_transpose(
|
||||
} else if (stx == 1 && (!is_vector || sty == arr.shape(-2))) {
|
||||
return std::make_tuple(true, sty, arr);
|
||||
} else {
|
||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||
copy_gpu(arr, arr_copy, CopyType::General, s);
|
||||
array arr_copy = contiguous_copy_gpu(arr, s);
|
||||
copies.push_back(arr_copy);
|
||||
return std::make_tuple(false, arr.shape(-1), arr_copy);
|
||||
}
|
||||
@@ -43,8 +42,7 @@ std::tuple<bool, int64_t, array> check_transpose(
|
||||
inline array
|
||||
ensure_row_contiguous(const array& x, metal::Device& d, const Stream& s) {
|
||||
if (!x.flags().row_contiguous) {
|
||||
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
||||
copy_gpu(x, x_copy, CopyType::General, s);
|
||||
array x_copy = contiguous_copy_gpu(x, s);
|
||||
d.add_temporary(x_copy, s.index);
|
||||
return x_copy;
|
||||
} else {
|
||||
@@ -75,8 +73,7 @@ ensure_batch_contiguous(const array& x, metal::Device& d, const Stream& s) {
|
||||
}
|
||||
}
|
||||
|
||||
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
||||
copy_gpu(x, x_copy, CopyType::General, s);
|
||||
array x_copy = contiguous_copy_gpu(x, s);
|
||||
d.add_temporary(x_copy, s.index);
|
||||
return std::make_tuple(false, x_copy.strides()[x_copy.ndim() - 2], x_copy);
|
||||
}
|
||||
@@ -1894,8 +1891,7 @@ void segmented_mm(
|
||||
return std::make_tuple(false, x);
|
||||
}
|
||||
|
||||
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
||||
copy_gpu(x, x_copy, CopyType::General, s);
|
||||
array x_copy = contiguous_copy_gpu(x, s);
|
||||
d.add_temporary(x_copy, s.index);
|
||||
return std::make_tuple(true, x_copy);
|
||||
};
|
||||
|
||||
@@ -40,8 +40,7 @@ void RMSNorm::eval_gpu(
|
||||
}
|
||||
return x;
|
||||
} else {
|
||||
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
|
||||
copy_gpu(x, x_copy, CopyType::General, s);
|
||||
array x_copy = contiguous_copy_gpu(x, s);
|
||||
out.copy_shared_buffer(x_copy);
|
||||
return x_copy;
|
||||
}
|
||||
@@ -107,9 +106,7 @@ void RMSNormVJP::eval_gpu(
|
||||
if (x.flags().row_contiguous) {
|
||||
return {x, false};
|
||||
}
|
||||
|
||||
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
||||
copy_gpu(x, x_copy, CopyType::General, s);
|
||||
array x_copy = contiguous_copy_gpu(x, s);
|
||||
return {x_copy, true};
|
||||
};
|
||||
bool donate_x = inputs[0].is_donatable();
|
||||
@@ -241,8 +238,7 @@ void LayerNorm::eval_gpu(
|
||||
}
|
||||
return x;
|
||||
} else {
|
||||
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
|
||||
copy_gpu(x, x_copy, CopyType::General, s);
|
||||
array x_copy = contiguous_copy_gpu(x, s);
|
||||
out.copy_shared_buffer(x_copy);
|
||||
return x_copy;
|
||||
}
|
||||
@@ -319,8 +315,7 @@ void LayerNormVJP::eval_gpu(
|
||||
if (x.flags().row_contiguous) {
|
||||
return {x, false};
|
||||
}
|
||||
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
||||
copy_gpu(x, x_copy, CopyType::General, s);
|
||||
array x_copy = contiguous_copy_gpu(x, s);
|
||||
return {x_copy, true};
|
||||
};
|
||||
bool donate_x = inputs[0].is_donatable();
|
||||
|
||||
@@ -20,8 +20,7 @@ namespace {
|
||||
inline array
|
||||
ensure_row_contiguous(const array& x, metal::Device& d, const Stream& s) {
|
||||
if (!x.flags().row_contiguous) {
|
||||
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
||||
copy_gpu(x, x_copy, CopyType::General, s);
|
||||
array x_copy = contiguous_copy_gpu(x, s);
|
||||
d.add_temporary(x_copy, s.index);
|
||||
return x_copy;
|
||||
} else {
|
||||
@@ -38,8 +37,7 @@ inline array ensure_row_contiguous_matrix(
|
||||
if (stride_0 == x.shape(-1) && stride_1 == 1) {
|
||||
return x;
|
||||
} else {
|
||||
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
||||
copy_gpu(x, x_copy, CopyType::General, s);
|
||||
array x_copy = contiguous_copy_gpu(x, s);
|
||||
d.add_temporary(x_copy, s.index);
|
||||
return x_copy;
|
||||
}
|
||||
|
||||
@@ -989,8 +989,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
// input for the axes with stride smaller than the minimum reduction
|
||||
// stride.
|
||||
if (plan.type == GeneralReduce) {
|
||||
array in_copy(in.shape(), in.dtype(), nullptr, {});
|
||||
copy_gpu(in, in_copy, CopyType::General, s);
|
||||
array in_copy = contiguous_copy_gpu(in, s);
|
||||
d.add_temporary(in_copy, s.index);
|
||||
in = in_copy;
|
||||
plan = get_reduction_plan(in, axes_);
|
||||
|
||||
@@ -398,8 +398,7 @@ void ScaledDotProductAttention::eval_gpu(
|
||||
auto copy_unless = [&copies, &s](
|
||||
auto predicate, const array& arr) -> const array& {
|
||||
if (!predicate(arr)) {
|
||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||
copy_gpu(arr, arr_copy, CopyType::General, s);
|
||||
array arr_copy = contiguous_copy_gpu(arr, s);
|
||||
copies.push_back(std::move(arr_copy));
|
||||
return copies.back();
|
||||
} else {
|
||||
|
||||
@@ -30,9 +30,7 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
in.flags());
|
||||
}
|
||||
} else {
|
||||
array arr_copy(in.shape(), in.dtype(), nullptr, {});
|
||||
copy_gpu(in, arr_copy, CopyType::General, s);
|
||||
in = std::move(arr_copy);
|
||||
in = contiguous_copy_gpu(in, s);
|
||||
out.copy_shared_buffer(in);
|
||||
}
|
||||
|
||||
|
||||
@@ -35,8 +35,7 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
return x;
|
||||
} else {
|
||||
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
|
||||
copy_gpu(x, x_copy, CopyType::General, s);
|
||||
array x_copy = contiguous_copy_gpu(x, s);
|
||||
out.copy_shared_buffer(x_copy);
|
||||
return x_copy;
|
||||
}
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
|
||||
#define MLX_VERSION_MAJOR 0
|
||||
#define MLX_VERSION_MINOR 26
|
||||
#define MLX_VERSION_PATCH 3
|
||||
#define MLX_VERSION_PATCH 5
|
||||
#define MLX_VERSION_NUMERIC \
|
||||
(100000 * MLX_VERSION_MAJOR + 1000 * MLX_VERSION_MINOR + MLX_VERSION_PATCH)
|
||||
|
||||
|
||||
@@ -848,6 +848,106 @@ class Adafactor(Optimizer):
|
||||
return parameter - update
|
||||
|
||||
|
||||
class Muon(Optimizer):
|
||||
r"""The Muon optimizer.
|
||||
|
||||
Our Muon (MomentUm Orthogonalized by Newton-schulz) optimizer follows the
|
||||
original implementation: `Muon: An optimizer for hidden layers in neural
|
||||
networks <https://kellerjordan.github.io/posts/muon/>`_
|
||||
|
||||
Note:
|
||||
- Muon may be sub-optimal for the embedding layer, the final fully
|
||||
connected layer, or any 0D/1D parameters. Those should be optimized
|
||||
by a different method (e.g., :class:`AdamW`).
|
||||
- For 4D convolutional filters, it works by flattening their last
|
||||
dimensions.
|
||||
|
||||
Args:
|
||||
learning_rate (float or callable): The learning rate.
|
||||
momentum (float, optional): The momentum strength. Default: ``0.95``
|
||||
weight_decay (float, optional): The weight decay (L2 penalty).
|
||||
Default: ``0.01``
|
||||
nesterov (bool, optional): Enables Nesterov momentum. Recommended for
|
||||
better performance. Default: ``True``
|
||||
ns_steps (int, optional): Number of Newton-Schulz iteration steps for
|
||||
orthogonalization. Default: ``5``
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
learning_rate: Union[float, Callable[[mx.array], mx.array]],
|
||||
momentum: float = 0.95,
|
||||
weight_decay: float = 0.01,
|
||||
nesterov: bool = True,
|
||||
ns_steps: int = 5,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self._maybe_schedule("learning_rate", learning_rate)
|
||||
self.momentum = momentum
|
||||
self.weight_decay = weight_decay
|
||||
self.nesterov = nesterov
|
||||
self.ns_steps = ns_steps
|
||||
|
||||
def init_single(self, parameter: mx.array, state: dict):
|
||||
"""Initialize optimizer state"""
|
||||
state["v"] = mx.zeros_like(parameter)
|
||||
|
||||
def _zeropower_via_newtonschulz5(self, X, steps: int):
|
||||
assert (
|
||||
X.ndim == 2
|
||||
), f"Expected a 2D array for Newton-Schulz iteration, got shape {X.shape} instead."
|
||||
a, b, c = (3.4445, -4.7750, 2.0315)
|
||||
transpose_needed = X.shape[-2] > X.shape[-1]
|
||||
|
||||
if transpose_needed:
|
||||
X = X.T
|
||||
|
||||
X = X / (mx.linalg.norm(X, keepdims=True) + 1e-7)
|
||||
|
||||
for _ in range(steps):
|
||||
A = X @ X.T
|
||||
B = mx.addmm(b * A, A, A, beta=1.0, alpha=c)
|
||||
X = mx.addmm(a * X, B, X, beta=1.0, alpha=1.0)
|
||||
|
||||
if transpose_needed:
|
||||
X = X.T
|
||||
return X
|
||||
|
||||
def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
|
||||
"""Performs the Muon parameter update"""
|
||||
|
||||
if self.weight_decay != 0:
|
||||
gradient = gradient + self.weight_decay * parameter
|
||||
|
||||
v = self.momentum * state["v"]
|
||||
v = v + (1 - self.momentum) * gradient
|
||||
state["v"] = v
|
||||
|
||||
if self.nesterov:
|
||||
update = gradient * (1 - self.momentum) + v * self.momentum
|
||||
else:
|
||||
update = v
|
||||
|
||||
lr = self.learning_rate.astype(gradient.dtype)
|
||||
|
||||
if update.ndim >= 2:
|
||||
original_shape = update.shape
|
||||
reshape_needed = update.ndim > 2
|
||||
|
||||
if reshape_needed:
|
||||
update = mx.reshape(update, (update.shape[0], -1))
|
||||
|
||||
update = self._zeropower_via_newtonschulz5(update, steps=self.ns_steps)
|
||||
|
||||
if reshape_needed:
|
||||
update = mx.reshape(update, original_shape)
|
||||
|
||||
lr *= max(1, update.shape[-2] / update.shape[-1]) ** 0.5
|
||||
|
||||
return parameter - lr * update
|
||||
|
||||
|
||||
def clip_grad_norm(grads, max_norm):
|
||||
"""Clips the global norm of the gradients.
|
||||
|
||||
|
||||
@@ -691,6 +691,21 @@ class TestBlas(mlx_tests.MLXTestCase):
|
||||
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
|
||||
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
|
||||
|
||||
# Transposed c
|
||||
a = mx.ones((10, 5)).T
|
||||
b = mx.ones((5, 5))
|
||||
out = mx.addmm(a, b, a, beta=1.5, alpha=0.5)
|
||||
expected = 1.5 * a + 0.5 * (b @ a)
|
||||
self.assertTrue(mx.allclose(expected, out))
|
||||
|
||||
# Broadcast c
|
||||
a = mx.ones((5, 5))
|
||||
b = mx.ones((5, 5))
|
||||
c = mx.ones((1, 5))
|
||||
out = mx.addmm(c, a, b, beta=1.5, alpha=0.5)
|
||||
expected = 1.5 * c + 0.5 * (a @ b)
|
||||
self.assertTrue(mx.allclose(expected, out))
|
||||
|
||||
def test_addmm_grad(self):
|
||||
def make_ref_addmm(alpha, beta):
|
||||
return lambda c, a, b: alpha * (a @ b) + beta * c
|
||||
|
||||
@@ -286,6 +286,53 @@ class TestOptimizers(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(xp["x"].shape, x.shape)
|
||||
self.assertEqual(optimizer.state["step"], 2)
|
||||
|
||||
def test_muon(self):
|
||||
params = {
|
||||
"first": [mx.zeros((10, 5)), mx.zeros((1,))],
|
||||
"second": mx.zeros((3, 3)),
|
||||
"conv": mx.zeros((16, 8, 3, 3)),
|
||||
}
|
||||
grads = tree_map(lambda x: mx.ones_like(x), params)
|
||||
|
||||
# Explicit init
|
||||
optim = opt.Muon(learning_rate=1e-2, momentum=0.95, nesterov=True)
|
||||
optim.init(params)
|
||||
self.assertTrue(
|
||||
tree_equal(
|
||||
lambda p, s: mx.array_equal(s["v"], mx.zeros_like(p)),
|
||||
params,
|
||||
optim.state,
|
||||
)
|
||||
)
|
||||
|
||||
# Test update
|
||||
updated_params = optim.apply_gradients(grads, params)
|
||||
|
||||
# Check that shapes are preserved
|
||||
self.assertTrue(
|
||||
tree_equal(
|
||||
lambda p, u: p.shape == u.shape,
|
||||
params,
|
||||
updated_params,
|
||||
)
|
||||
)
|
||||
|
||||
# Check that parameters actually changed
|
||||
self.assertFalse(
|
||||
tree_equal(
|
||||
lambda p, u: mx.array_equal(p, u),
|
||||
params,
|
||||
updated_params,
|
||||
)
|
||||
)
|
||||
|
||||
# Test with different configurations
|
||||
optim_no_nesterov = opt.Muon(learning_rate=1e-2, momentum=0.95, nesterov=False)
|
||||
optim_no_nesterov.apply_gradients(grads, params)
|
||||
|
||||
optim_no_momentum = opt.Muon(learning_rate=1e-2, momentum=0.0)
|
||||
optim_no_momentum.apply_gradients(grads, params)
|
||||
|
||||
def test_compiled_optimizer(self):
|
||||
model = nn.Linear(10, 10)
|
||||
x = mx.random.uniform(shape=(2, 10))
|
||||
|
||||
@@ -39,6 +39,14 @@ target_sources(
|
||||
linalg_tests.cpp
|
||||
${METAL_TEST_SOURCES})
|
||||
|
||||
if(MLX_BUILD_CUDA)
|
||||
# Find the CCCL headers in install dir.
|
||||
target_compile_definitions(
|
||||
mlx
|
||||
PRIVATE
|
||||
MLX_CCCL_DIR="${CMAKE_INSTALL_PREFIX}/${CMAKE_INSTALL_INCLUDEDIR}/cccl")
|
||||
endif()
|
||||
|
||||
target_link_libraries(tests PRIVATE mlx doctest)
|
||||
doctest_discover_tests(tests)
|
||||
add_test(NAME tests COMMAND tests)
|
||||
|
||||
Reference in New Issue
Block a user