Compare commits

...

10 Commits

Author SHA1 Message Date
Awni Hannun
84b4d96efa fix release build + patch bump (#2387) 2025-07-18 14:47:37 -07:00
Awni Hannun
aec67f2fa6 patch bump (#2386) 2025-07-18 12:25:48 -07:00
Gökdeniz Gülmez
deee214a95 Adding support for the Muon Optimizer (#1914)
* initial commit with workong optmimizer

* update ACKNOWLEDGMENTS.md

* nits and adding it to test

* nits

* G.astype(mx.bfloat16) to G.astype(G.dtype)

* G.ndim >= 2 to assert G.ndim == 2

* remove coments

* replace with  mx.addmm

* remove comments

* format

* nits

* match muon

* fix addmm

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2025-07-18 12:25:28 -07:00
Cheng
45adec102c Add contiguous_copy_gpu util for copying array (#2379) 2025-07-18 06:44:25 -07:00
Cheng
31fc530c76 [CUDA] Add more ways finding CCCL headers in JIT (#2382) 2025-07-17 15:25:34 -07:00
Awni Hannun
fbb3f65a1a fix resource leaks in matmul and graph (#2383) 2025-07-17 06:50:15 -07:00
Angelos Katharopoulos
6b1b8ea91b [CUDA] Add work per thread to compile (#2368) 2025-07-17 06:47:52 -07:00
Awni Hannun
b2273733ea Test with CUDA 12.2 (#2375)
* Test with CUDA 12.0

* try older image

* fix cpu sort
2025-07-16 13:00:37 -07:00
Awni Hannun
f409b229a4 fix ring distributed test (#2380) 2025-07-16 11:25:24 -07:00
Cheng
30571e2326 Rename the copy util in cpu/copy.h to copy_cpu (#2378) 2025-07-16 07:34:24 -07:00
53 changed files with 492 additions and 207 deletions

View File

@@ -97,7 +97,8 @@ jobs:
command: |
python -m unittest discover python/tests -v
mpirun --bind-to none -host localhost:8 -np 8 python python/tests/mpi_test_distributed.py
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
if $(grep "\[WARN\]" stderr.log); then echo "Distributed ring test failed"; exit 1; fi
- run:
name: Build CPP only
command: |
@@ -156,7 +157,8 @@ jobs:
LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu
mpirun --bind-to none -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
if $(grep "\[WARN\]" stderr.log); then echo "Distributed ring test failed"; exit 1; fi
- run:
name: Build example extension
command: |
@@ -199,7 +201,7 @@ jobs:
cuda_build_and_test:
machine:
image: linux-cuda-12:default
image: linux-cuda-12:2023.11.1
resource_class: gpu.nvidia.small.gen2
steps:
- checkout
@@ -208,7 +210,7 @@ jobs:
command: |
sudo apt-get update
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
python -m venv env
python3 -m venv env
source env/bin/activate
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
pip install -e ".[dev]"
@@ -270,6 +272,7 @@ jobs:
name: Build Python package
command: |
source env/bin/activate
python setup.py clean --all
<< parameters.build_env >> MLX_BUILD_STAGE=1 python -m build -w
- when:
condition:
@@ -331,6 +334,7 @@ jobs:
<< parameters.build_env >> pip install ".[dev]" -v
pip install typing_extensions
python setup.py generate_stubs
python setup.py clean --all
MLX_BUILD_STAGE=1 << parameters.build_env >> python -m build -w
bash python/scripts/repair_linux.sh
- when:

View File

@@ -19,6 +19,7 @@ MLX was developed with contributions from the following individuals:
- Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.
- Paul Paczuski: Improved stability of BCE loss calculation
- Max-Heinrich Laves: Added `conv_transpose1d`, `conv_transpose2d`, and `conv_transpose3d` ops.
- Gökdeniz Gülmez: Added the `Muon (MomentUm Orthogonalized by Newton-schulz)` optimizer.
<a href="https://github.com/ml-explore/mlx/graphs/contributors">
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />

View File

@@ -19,3 +19,4 @@ Common Optimizers
Adamax
Lion
MultiOptimizer
Muon

View File

@@ -20,7 +20,7 @@ void cholesky_impl(const array& a, array& factor, bool upper, Stream stream) {
// The decomposition is computed in place, so just copy the input to the
// output.
copy(
copy_cpu(
a,
factor,
a.flags().row_contiguous ? CopyType::Vector : CopyType::General,

View File

@@ -883,7 +883,7 @@ void explicit_gemm_conv_1D_cpu(
// Fill with zeros
std::vector<array> temps;
temps.push_back(array(0, conv_dtype));
copy(temps.back(), in_padded, CopyType::Scalar, stream);
copy_cpu(temps.back(), in_padded, CopyType::Scalar, stream);
// Pick input slice from padded
size_t data_offset = padding_lo[0] * in_padded.strides()[1];
@@ -895,7 +895,7 @@ void explicit_gemm_conv_1D_cpu(
in_padded_slice.size(),
data_offset);
// Copy input values into the slice
copy_inplace(in, in_padded_slice, CopyType::GeneralGeneral, stream);
copy_cpu_inplace(in, in_padded_slice, CopyType::GeneralGeneral, stream);
temps.push_back(in_padded_slice);
// Make strided view
@@ -920,7 +920,7 @@ void explicit_gemm_conv_1D_cpu(
// Materialize strided view
Shape strided_reshape = {N * oH, wH * C};
array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
copy(in_strided_view, in_strided, CopyType::General, stream);
copy_cpu(in_strided_view, in_strided, CopyType::General, stream);
temps.push_back(in_strided);
// Check wt dtype and prepare
@@ -938,13 +938,13 @@ void explicit_gemm_conv_1D_cpu(
wt.size(),
0);
gemm_wt = array(wt_transpose.shape(), float32, nullptr, {});
copy(wt_transpose, gemm_wt, CopyType::General, stream);
copy_cpu(wt_transpose, gemm_wt, CopyType::General, stream);
temps.push_back(gemm_wt);
} else if (wt.dtype() != float32 || !wt.flags().row_contiguous) {
auto ctype =
wt.flags().row_contiguous ? CopyType::Vector : CopyType::General;
gemm_wt = array(wt.shape(), float32, nullptr, {});
copy(wt, gemm_wt, ctype, stream);
copy_cpu(wt, gemm_wt, ctype, stream);
temps.push_back(gemm_wt);
}
@@ -991,7 +991,7 @@ void explicit_gemm_conv_1D_cpu(
// Copy results if needed
if (out.dtype() != float32) {
copy_inplace(gemm_out, out, CopyType::Vector, stream);
copy_cpu_inplace(gemm_out, out, CopyType::Vector, stream);
}
encoder.add_temporaries(std::move(temps));
}
@@ -1029,7 +1029,7 @@ void explicit_gemm_conv_2D_cpu(
// Fill with zeros
std::vector<array> temps;
temps.push_back(array(0, conv_dtype));
copy(temps.back(), in_padded, CopyType::Scalar, stream);
copy_cpu(temps.back(), in_padded, CopyType::Scalar, stream);
// Pick input slice from padded
size_t data_offset = padding_lo[0] * in_padded.strides()[1] +
@@ -1044,7 +1044,7 @@ void explicit_gemm_conv_2D_cpu(
temps.push_back(in_padded_slice);
// Copy input values into the slice
copy_inplace(in, in_padded_slice, CopyType::GeneralGeneral, stream);
copy_cpu_inplace(in, in_padded_slice, CopyType::GeneralGeneral, stream);
// Make strided view
Shape strided_shape = {N, oH, oW, wH, wW, C};
@@ -1065,7 +1065,7 @@ void explicit_gemm_conv_2D_cpu(
// Materialize strided view
Shape strided_reshape = {N * oH * oW, wH * wW * C};
array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
copy(in_strided_view, in_strided, CopyType::General, stream);
copy_cpu(in_strided_view, in_strided, CopyType::General, stream);
temps.push_back(in_strided);
// Check wt dtype and prepare
@@ -1076,7 +1076,7 @@ void explicit_gemm_conv_2D_cpu(
auto ctype =
wt.flags().row_contiguous ? CopyType::Vector : CopyType::General;
gemm_wt = array(wt.shape(), float32, nullptr, {});
copy(wt, gemm_wt, ctype, stream);
copy_cpu(wt, gemm_wt, ctype, stream);
temps.push_back(gemm_wt);
}
@@ -1116,7 +1116,7 @@ void explicit_gemm_conv_2D_cpu(
// Copy results if needed
if (out.dtype() != float32) {
copy_inplace(gemm_out, out, CopyType::Vector, stream);
copy_cpu_inplace(gemm_out, out, CopyType::Vector, stream);
}
encoder.add_temporaries(std::move(temps));
}
@@ -1156,7 +1156,7 @@ void explicit_gemm_conv_ND_cpu(
// Fill with zeros
std::vector<array> temps = {array(0, conv_dtype)};
copy(temps.back(), in_padded, CopyType::Scalar, stream);
copy_cpu(temps.back(), in_padded, CopyType::Scalar, stream);
// Pick input slice from padded
size_t data_offset = 0;
@@ -1173,7 +1173,7 @@ void explicit_gemm_conv_ND_cpu(
data_offset);
// Copy input values into the slice
copy_inplace(in, in_padded_slice, CopyType::GeneralGeneral, stream);
copy_cpu_inplace(in, in_padded_slice, CopyType::GeneralGeneral, stream);
temps.push_back(in_padded_slice);
// Make strided view
@@ -1212,7 +1212,7 @@ void explicit_gemm_conv_ND_cpu(
}
array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
copy(in_strided_view, in_strided, CopyType::General, stream);
copy_cpu(in_strided_view, in_strided, CopyType::General, stream);
temps.push_back(in_strided);
// Check wt dtype and prepare
@@ -1223,13 +1223,13 @@ void explicit_gemm_conv_ND_cpu(
auto ctype =
wt.flags().row_contiguous ? CopyType::Vector : CopyType::General;
gemm_wt = array(wt.shape(), float32, nullptr, {});
copy(wt, gemm_wt, ctype, stream);
copy_cpu(wt, gemm_wt, ctype, stream);
temps.push_back(gemm_wt);
}
if (flip) {
auto gemm_wt_ = array(gemm_wt.shape(), float32, nullptr, {});
copy(gemm_wt, gemm_wt_, CopyType::Vector, stream);
copy_cpu(gemm_wt, gemm_wt_, CopyType::Vector, stream);
temps.push_back(gemm_wt_);
// Calculate the total size of the spatial dimensions
@@ -1284,7 +1284,7 @@ void explicit_gemm_conv_ND_cpu(
// Copy results if needed
if (out.dtype() != float32) {
copy_inplace(gemm_out, out, CopyType::Vector, stream);
copy_cpu_inplace(gemm_out, out, CopyType::Vector, stream);
}
encoder.add_temporaries(std::move(temps));
}

View File

@@ -295,7 +295,11 @@ inline void copy_inplace_dispatch(
} // namespace
void copy_inplace(const array& src, array& dst, CopyType ctype, Stream stream) {
void copy_cpu_inplace(
const array& src,
array& dst,
CopyType ctype,
Stream stream) {
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(src);
encoder.set_output_array(dst);
@@ -305,7 +309,7 @@ void copy_inplace(const array& src, array& dst, CopyType ctype, Stream stream) {
ctype]() mutable { copy_inplace_dispatch(src, dst, ctype); });
}
void copy(const array& src, array& dst, CopyType ctype, Stream stream) {
void copy_cpu(const array& src, array& dst, CopyType ctype, Stream stream) {
bool donated = set_copy_output_data(src, dst, ctype);
if (donated && src.dtype() == dst.dtype()) {
// If the output has the same type as the input then there is nothing to
@@ -315,10 +319,10 @@ void copy(const array& src, array& dst, CopyType ctype, Stream stream) {
if (ctype == CopyType::GeneralGeneral) {
ctype = CopyType::General;
}
copy_inplace(src, dst, ctype, stream);
copy_cpu_inplace(src, dst, ctype, stream);
}
void copy_inplace(
void copy_cpu_inplace(
const array& src,
array& dst,
const Shape& data_shape,

View File

@@ -10,10 +10,14 @@
namespace mlx::core {
void copy(const array& src, array& dst, CopyType ctype, Stream stream);
void copy_inplace(const array& src, array& dst, CopyType ctype, Stream stream);
void copy_cpu(const array& src, array& dst, CopyType ctype, Stream stream);
void copy_cpu_inplace(
const array& src,
array& dst,
CopyType ctype,
Stream stream);
void copy_inplace(
void copy_cpu_inplace(
const array& src,
array& dst,
const Shape& data_shape,

View File

@@ -14,7 +14,7 @@ std::pair<array, bool> ensure_row_contiguous(const array& arr, Stream stream) {
return {arr, false};
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General, stream);
copy_cpu(arr, arr_copy, CopyType::General, stream);
return {arr_copy, true};
}
};
@@ -35,7 +35,7 @@ void AllReduce::eval_cpu(
return in;
} else {
array arr_copy(in.shape(), in.dtype(), nullptr, {});
copy(in, arr_copy, CopyType::General, s);
copy_cpu(in, arr_copy, CopyType::General, s);
out.copy_shared_buffer(arr_copy);
return arr_copy;
}

View File

@@ -135,7 +135,7 @@ void Eig::eval_cpu(
: array(a.shape(), complex64, nullptr, {});
auto a_copy = array(a.shape(), a.dtype(), nullptr, {});
copy(
copy_cpu(
a,
a_copy,
a.flags().row_contiguous ? CopyType::Vector : CopyType::General,

View File

@@ -196,7 +196,7 @@ void Eigh::eval_cpu(
values.set_data(allocator::malloc(values.nbytes()));
copy(
copy_cpu(
a,
vectors,
a.flags().row_contiguous ? CopyType::Vector : CopyType::General,

View File

@@ -96,7 +96,7 @@ void Hadamard::eval_cpu(const std::vector<array>& inputs, array& out) {
if (in.flags().row_contiguous && in.is_donatable()) {
out.copy_shared_buffer(in);
} else {
copy(
copy_cpu(
in,
out,
in.flags().row_contiguous ? CopyType::Vector : CopyType::General,

View File

@@ -517,7 +517,7 @@ void Scatter::eval_cpu(const std::vector<array>& inputs, array& out) {
// Copy src into out (copy allocates memory for out)
auto ctype =
src.flags().row_contiguous ? CopyType::Vector : CopyType::General;
copy(src, out, ctype, stream());
copy_cpu(src, out, ctype, stream());
auto& encoder = cpu::get_command_encoder(stream());
std::vector<array> inds;
@@ -686,7 +686,7 @@ void ScatterAxis::eval_cpu(const std::vector<array>& inputs, array& out) {
// Copy src into out (copy allocates memory for out)
auto ctype =
src.flags().row_contiguous ? CopyType::Vector : CopyType::General;
copy(src, out, ctype, stream());
copy_cpu(src, out, ctype, stream());
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_input_array(idx);

View File

@@ -115,7 +115,7 @@ void inverse_impl(
// (A⁻¹)ᵀ = (Aᵀ)⁻¹
// The inverse is computed in place, so just copy the input to the output.
copy(
copy_cpu(
a,
inv,
a.flags().row_contiguous ? CopyType::Vector : CopyType::General,

View File

@@ -88,7 +88,7 @@ void LogSumExp::eval_cpu(const std::vector<array>& inputs, array& out) {
return x;
} else {
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
copy(x, x_copy, CopyType::General, s);
copy_cpu(x, x_copy, CopyType::General, s);
encoder.add_temporary(x_copy);
return x_copy;
}

View File

@@ -31,7 +31,7 @@ void luf_impl(
strides[ndim - 1] = M;
strides[ndim - 2] = 1;
lu.set_data(allocator::malloc(lu.nbytes()), lu.nbytes(), strides, flags);
copy_inplace(
copy_cpu_inplace(
a,
lu,
a.shape(),

View File

@@ -124,20 +124,20 @@ void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
if (!expand_all && stx == arr.shape(-1) && sty == 1) {
if (do_copy) {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::Vector, s);
copy_cpu(arr, arr_copy, CopyType::Vector, s);
return std::make_tuple(false, stx, arr_copy, true);
}
return std::make_tuple(false, stx, arr, false);
} else if (!expand_all && stx == 1 && sty == arr.shape(-2)) {
if (do_copy) {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::Vector, s);
copy_cpu(arr, arr_copy, CopyType::Vector, s);
return std::make_tuple(true, sty, arr_copy, true);
}
return std::make_tuple(true, sty, arr, false);
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General, s);
copy_cpu(arr, arr_copy, CopyType::General, s);
int64_t stx = arr.shape(-1);
return std::make_tuple(false, stx, arr_copy, true);
}
@@ -386,7 +386,7 @@ void GatherMM::eval_cpu(const std::vector<array>& inputs, array& out) {
return std::make_tuple(true, sty, arr);
} else {
temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {}));
copy(arr, temps.back(), CopyType::General, s);
copy_cpu(arr, temps.back(), CopyType::General, s);
int64_t stx = arr.shape(-1);
return std::make_tuple(false, stx, temps.back());
}
@@ -504,7 +504,7 @@ void SegmentedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
return std::make_tuple(true, sty, x);
} else {
array xc(x.shape(), x.dtype(), nullptr, {});
copy(x, xc, CopyType::General, s);
copy_cpu(x, xc, CopyType::General, s);
encoder.add_temporary(xc);
int64_t stx = x.shape(-1);
return std::make_tuple(false, stx, xc);

View File

@@ -81,7 +81,7 @@ void matmul_general(
return std::make_tuple(true, sty, arr);
} else {
temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {}));
copy(arr, temps.back(), CopyType::General, stream);
copy_cpu(arr, temps.back(), CopyType::General, stream);
stx = arr.shape(-1);
return std::make_tuple(false, stx, temps.back());
}
@@ -142,7 +142,7 @@ void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
CopyType ctype = c.data_size() == 1
? CopyType::Scalar
: (c.flags().row_contiguous ? CopyType::Vector : CopyType::General);
copy(c, out, ctype, stream());
copy_cpu(c, out, ctype, stream());
if (inputs[0].shape(-1) == 0) {
return;
}

View File

@@ -22,7 +22,7 @@ void reshape(const array& in, array& out) {
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
if (copy_necessary) {
out.set_data(allocator::malloc(out.nbytes()));
copy_inplace(in, out, CopyType::General, out.primitive().stream());
copy_cpu_inplace(in, out, CopyType::General, out.primitive().stream());
} else {
shared_buffer_reshape(in, out_strides, out);
}
@@ -175,7 +175,7 @@ void AsType::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
copy(in, out, ctype, stream());
copy_cpu(in, out, ctype, stream());
}
void Concatenate::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -198,7 +198,7 @@ void Concatenate::eval_cpu(const std::vector<array>& inputs, array& out) {
size_t data_offset = strides[axis_] * sizes[i];
out_slice.copy_shared_buffer(
out, strides, flags, out_slice.size(), data_offset);
copy_inplace(inputs[i], out_slice, CopyType::GeneralGeneral, stream());
copy_cpu_inplace(inputs[i], out_slice, CopyType::GeneralGeneral, stream());
}
}
@@ -211,7 +211,7 @@ void Contiguous::eval_cpu(const std::vector<array>& inputs, array& out) {
(allow_col_major_ && in.flags().col_contiguous))) {
out.copy_shared_buffer(in);
} else {
copy(in, out, CopyType::General, stream());
copy_cpu(in, out, CopyType::General, stream());
}
}
@@ -235,7 +235,7 @@ void Full::eval_cpu(const std::vector<array>& inputs, array& out) {
} else {
ctype = CopyType::General;
}
copy(in, out, ctype, stream());
copy_cpu(in, out, ctype, stream());
}
void Pad::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -251,7 +251,7 @@ void Pad::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(val.dtype() == in.dtype() && in.dtype() == out.dtype());
// Fill output with val
copy(val, out, CopyType::Scalar, stream());
copy_cpu(val, out, CopyType::Scalar, stream());
// Find offset for start of input values
size_t data_offset = 0;
@@ -266,7 +266,7 @@ void Pad::eval_cpu(const std::vector<array>& inputs, array& out) {
out, out.strides(), out.flags(), out_slice.size(), data_offset);
// Copy input values into the slice
copy_inplace(in, out_slice, CopyType::GeneralGeneral, stream());
copy_cpu_inplace(in, out_slice, CopyType::GeneralGeneral, stream());
}
void RandomBits::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -340,7 +340,7 @@ void DynamicSlice::eval_cpu(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc(out.nbytes()));
auto [in_offset, donated] =
compute_dynamic_offset(inputs[1], in.strides(), axes_, stream());
copy_inplace(
copy_cpu_inplace(
/* const array& src = */ in,
/* array& dst = */ out,
/* const Shape& data_shape = */ out.shape(),
@@ -372,11 +372,11 @@ void DynamicSliceUpdate::eval_cpu(
auto ctype = in.flags().contiguous && in.size() == in.data_size()
? CopyType::Vector
: CopyType::General;
copy(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream());
copy_cpu(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream());
auto [out_offset, donated] =
compute_dynamic_offset(inputs[2], out.strides(), axes_, stream());
copy_inplace(
copy_cpu_inplace(
/* const array& src = */ upd,
/* array& dst = */ out,
/* const std::vector<int>& data_shape = */ upd.shape(),
@@ -412,14 +412,14 @@ void SliceUpdate::eval_cpu(const std::vector<array>& inputs, array& out) {
auto ctype = in.flags().contiguous && in.size() == in.data_size()
? CopyType::Vector
: CopyType::General;
copy(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream());
copy_cpu(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream());
// Calculate out strides, initial offset and if copy needs to be made
auto [data_offset, out_strides] =
prepare_slice(out, start_indices_, strides_);
// Do copy
copy_inplace(
copy_cpu_inplace(
/* const array& src = */ upd,
/* array& dst = */ out,
/* const std::vector<int>& data_shape = */ upd.shape(),
@@ -456,9 +456,9 @@ void View::eval_cpu(const std::vector<array>& inputs, array& out) {
if (in.dtype() == bool_) {
auto in_tmp = array(in.shape(), uint8, nullptr, {});
in_tmp.copy_shared_buffer(in);
copy_inplace(in_tmp, tmp, CopyType::General, stream());
copy_cpu_inplace(in_tmp, tmp, CopyType::General, stream());
} else {
copy_inplace(in, tmp, CopyType::General, stream());
copy_cpu_inplace(in, tmp, CopyType::General, stream());
}
auto flags = out.flags();

View File

@@ -26,7 +26,7 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) {
strides[in.ndim() - 2] = 1;
strides[in.ndim() - 1] = M;
in.set_data(allocator::malloc(in.nbytes()), in.nbytes(), strides, flags);
copy_inplace(a, in, CopyType::GeneralGeneral, stream);
copy_cpu_inplace(a, in, CopyType::GeneralGeneral, stream);
auto& encoder = cpu::get_command_encoder(stream);
q.set_data(allocator::malloc(q.nbytes()));
r.set_data(allocator::malloc(r.nbytes()));

View File

@@ -529,7 +529,7 @@ void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
return arr;
} else {
temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {}));
copy(arr, temps.back(), CopyType::General, s);
copy_cpu(arr, temps.back(), CopyType::General, s);
return temps.back();
}
};
@@ -579,7 +579,7 @@ void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) {
return arr;
} else {
temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {}));
copy(arr, temps.back(), CopyType::General, s);
copy_cpu(arr, temps.back(), CopyType::General, s);
return temps.back();
}
};
@@ -713,7 +713,7 @@ void fast::AffineQuantize::eval_cpu(
return std::make_pair(arr, false);
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General, s);
copy_cpu(arr, arr_copy, CopyType::General, s);
return std::make_pair(arr_copy, true);
}
};

View File

@@ -251,7 +251,7 @@ void Scan::eval_cpu(const std::vector<array>& inputs, array& out) {
auto in = inputs[0];
if (!in.flags().row_contiguous) {
array arr_copy(in.shape(), in.dtype(), nullptr, {});
copy(in, arr_copy, CopyType::General, stream());
copy_cpu(in, arr_copy, CopyType::General, stream());
in = arr_copy;
encoder.add_temporary(arr_copy);
}

View File

@@ -132,7 +132,7 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
return x;
} else {
array x_copy(x.shape(), x.dtype(), nullptr, {});
copy(x, x_copy, CopyType::General, s);
copy_cpu(x, x_copy, CopyType::General, s);
out.copy_shared_buffer(x_copy);
return x_copy;
}

View File

@@ -334,8 +334,10 @@ void Sort::eval_cpu(const std::vector<array>& inputs, array& out) {
auto& in = inputs[0];
// Copy input to output
CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
copy(in, out, ctype, stream());
CopyType ctype = (in.flags().contiguous && in.strides()[axis_] != 0)
? CopyType::Vector
: CopyType::General;
copy_cpu(in, out, ctype, stream());
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_output_array(out);
@@ -426,8 +428,10 @@ void Partition::eval_cpu(const std::vector<array>& inputs, array& out) {
auto& in = inputs[0];
// Copy input to output
CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
copy(in, out, ctype, stream());
CopyType ctype = (in.flags().contiguous && in.strides()[axis_] != 0)
? CopyType::Vector
: CopyType::General;
copy_cpu(in, out, ctype, stream());
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_output_array(out);

View File

@@ -31,7 +31,7 @@ void svd_impl(
// lapack clobbers the input, so we have to make a copy.
array in(a.shape(), a.dtype(), nullptr, {});
copy(
copy_cpu(
a,
in,
a.flags().row_contiguous ? CopyType::Vector : CopyType::General,

View File

@@ -53,9 +53,10 @@ struct FusedKernelBuilder {
// Build function signature.
if (contiguous) {
os += "template <typename IdxT = uint32_t>\n";
os += "template <typename IdxT = uint32_t, int work_per_thread = 1>\n";
} else {
os += "template <int NDIM, typename IdxT = uint32_t>\n";
os +=
"template <int NDIM, typename IdxT = uint32_t, int work_per_thread = 1>\n";
}
os += fmt::format("__global__ void {}(\n", kernel_name + name);
for (size_t i = 0; i < params.size(); ++i) {
@@ -67,12 +68,46 @@ struct FusedKernelBuilder {
}
os += ") {\n";
// Index.
// Index. For non contiguous kernels we create a separate index
// variable per variable otherwise everyone uses `index`.
os +=
" IdxT index = cg::this_grid().thread_rank();\n"
" IdxT index = cg::this_grid().thread_rank() * work_per_thread;\n"
" if (index >= size) {\n"
" return;\n"
" }\n";
if (!contiguous) {
for (size_t i = 0; i < inputs.size(); ++i) {
const auto& x = inputs[i];
const std::string& xname = namer.get_name(x);
if (is_scalar(x) || is_constant(i)) {
continue;
}
os += " IdxT " + xname + "_idx = 0;\n";
}
os += " {\n";
os += " IdxT loc = index;\n";
os +=
" #pragma unroll\n"
" for (int i = NDIM - 1; i >= 0; i--) {\n";
for (size_t i = 0; i < inputs.size(); ++i) {
const auto& x = inputs[i];
const std::string& xname = namer.get_name(x);
if (is_scalar(x) || is_constant(i)) {
continue;
}
os += " " + xname + "_idx += (loc \% shape[i]) * IdxT(" + xname +
"_strides[i]);\n";
}
os +=
" loc /= shape[i];\n"
" }\n"
" }\n";
}
// Work loop
os +=
"\n"
" for (int i = 0; i < work_per_thread && index < size; i++) {\n";
// Read inputs.
for (size_t i = 0; i < inputs.size(); ++i) {
@@ -89,12 +124,9 @@ struct FusedKernelBuilder {
} else if (contiguous) {
value = fmt::format("{}[index]", xname);
} else {
std::string index = fmt::format(
"elem_to_loc_nd<NDIM>(index, shape.data(), {}_strides.data())",
xname);
value = fmt::format("{}[{}]", xname, index);
value = fmt::format("{}[{}_idx]", xname, xname);
}
os += fmt::format(" {} tmp_{} = {};\n", type, xname, value);
os += fmt::format(" {} tmp_{} = {};\n", type, xname, value);
}
// Write tape.
@@ -113,14 +145,30 @@ struct FusedKernelBuilder {
}
value += fmt::format("tmp_{})", namer.get_name(x.inputs().back()));
}
os += fmt::format(" {} tmp_{} = {};\n", type, xname, value);
os += fmt::format(" {} tmp_{} = {};\n", type, xname, value);
}
// Write output.
for (const auto& x : outputs) {
os += fmt::format(" {0}[index] = tmp_{0};\n", namer.get_name(x));
os += fmt::format(" {0}[index] = tmp_{0};\n", namer.get_name(x));
}
// End of work loop
os +=
"\n"
" index++;\n";
if (!contiguous) {
for (size_t i = 0; i < inputs.size(); ++i) {
const auto& x = inputs[i];
const std::string& xname = namer.get_name(x);
if (is_scalar(x) || is_constant(i)) {
continue;
}
os += " " + xname + "_idx += " + xname + "_strides[NDIM - 1];\n";
}
}
os += " }\n";
os += "}\n";
}
};
@@ -156,15 +204,28 @@ void Compiled::eval_gpu(
builder.build("_strided", false);
builder.os += "\n} // namespace mlx::core::cu\n";
// Build kernel names.
std::vector<std::string> kernel_names = {
fmt::format("mlx::core::cu::{}_contiguous<uint32_t>", lib_name()),
fmt::format("mlx::core::cu::{}_contiguous<int64_t>", lib_name()),
};
for (int i = 1; i <= MAX_NDIM; ++i) {
std::vector<std::string> kernel_names;
for (auto work_per_thread : std::array<int, 2>{1, 4}) {
kernel_names.push_back(fmt::format(
"mlx::core::cu::{}_strided<{}, uint32_t>", lib_name(), i));
kernel_names.push_back(
fmt::format("mlx::core::cu::{}_strided<{}, int64_t>", lib_name(), i));
"mlx::core::cu::{}_contiguous<uint32_t, {}>",
lib_name(),
work_per_thread));
kernel_names.push_back(fmt::format(
"mlx::core::cu::{}_contiguous<int64_t, {}>",
lib_name(),
work_per_thread));
for (int i = 1; i <= MAX_NDIM; ++i) {
kernel_names.push_back(fmt::format(
"mlx::core::cu::{}_strided<{}, uint32_t, {}>",
lib_name(),
i,
work_per_thread));
kernel_names.push_back(fmt::format(
"mlx::core::cu::{}_strided<{}, int64_t, {}>",
lib_name(),
i,
work_per_thread));
}
}
return std::make_pair(std::move(builder.os), std::move(kernel_names));
});
@@ -207,13 +268,21 @@ void Compiled::eval_gpu(
args.append<uint32_t>(outputs[0].data_size());
}
// Choose work per thread
int work_per_thread = 4;
if (!contiguous && shape.back() % work_per_thread != 0) {
work_per_thread = 1;
}
// Launch kernel.
const char* index_type = large ? "int64_t" : "uint32_t";
std::string kernel_name = fmt::format("mlx::core::cu::{}", lib_name());
if (contiguous) {
kernel_name += fmt::format("_contiguous<{}>", index_type);
kernel_name +=
fmt::format("_contiguous<{}, {}>", index_type, work_per_thread);
} else {
kernel_name += fmt::format("_strided<{}, {}>", shape.size(), index_type);
kernel_name += fmt::format(
"_strided<{}, {}, {}>", shape.size(), index_type, work_per_thread);
}
auto& encoder = cu::get_command_encoder(s);
for (const auto& in : inputs) {
@@ -224,7 +293,8 @@ void Compiled::eval_gpu(
}
auto kernel = mod.get_kernel(kernel_name);
auto [num_blocks, block_dims] = get_launch_args(kernel, outputs[0], large);
auto [num_blocks, block_dims] =
get_launch_args(kernel, outputs[0], large, work_per_thread);
encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
}

View File

@@ -66,7 +66,6 @@ CommandEncoder& Device::get_command_encoder(Stream s) {
}
CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) {
CHECK_CUDA_ERROR(cudaGraphCreate(&graph, 0));
CHECK_CUDA_ERROR(
cudaStreamBeginCapture(enc.stream(), cudaStreamCaptureModeGlobal));
}

View File

@@ -52,13 +52,29 @@ const std::string& cuda_home() {
}
// Return the location of CCCL headers shipped with the distribution.
bool get_cccl_include(std::string* out) {
auto cccl_headers = current_binary_dir().parent_path() / "include" / "cccl";
if (!std::filesystem::exists(cccl_headers)) {
return false;
}
*out = fmt::format("--include-path={}", cccl_headers.string());
return true;
const std::string& cccl_dir() {
static std::string dir = []() {
std::filesystem::path path;
#if defined(MLX_CCCL_DIR)
// First search the install dir if defined.
path = MLX_CCCL_DIR;
if (std::filesystem::exists(path)) {
return path.string();
}
#endif
// Then search dynamically from the dir of libmlx.so file.
path = current_binary_dir().parent_path() / "include" / "cccl";
if (std::filesystem::exists(path)) {
return path.string();
}
// Finally check the environment variable.
path = std::getenv("MLX_CCCL_DIR");
if (!path.empty() && std::filesystem::exists(path)) {
return path.string();
}
return std::string();
}();
return dir;
}
// Get the cache directory for storing compiled results.
@@ -121,7 +137,8 @@ void write_cached_ptx(
const std::filesystem::path& cache_dir,
const std::string& module_name,
const std::vector<char>& ptx,
const std::vector<std::pair<std::string, std::string>>& ptx_kernels) {
const std::vector<std::pair<std::string, std::string>>& ptx_kernels,
const std::string& source_code) {
if (cache_dir.empty()) {
return;
}
@@ -134,6 +151,9 @@ void write_cached_ptx(
for (const auto& [name, mangled] : ptx_kernels) {
txt_file << name << "\t" << mangled << std::endl;
}
std::ofstream source_file(cache_dir / (module_name + ".cu"));
source_file << source_code;
}
// Return if |device|'s version is not newer than |major|.|minor| version.
@@ -234,8 +254,9 @@ JitModule::JitModule(
device.compute_capability_major(),
device.compute_capability_minor());
args.push_back(compute.c_str());
std::string cccl_include;
if (get_cccl_include(&cccl_include)) {
std::string cccl_include = cccl_dir();
if (!cccl_include.empty()) {
cccl_include = fmt::format("--include-path={}", cccl_include);
args.push_back(cccl_include.c_str());
}
std::string cuda_include =
@@ -272,7 +293,8 @@ JitModule::JitModule(
} else {
CHECK_NVRTC_ERROR(nvrtcGetPTX(prog, ptx.data()));
}
write_cached_ptx(ptx_cache_dir(), module_name, ptx, ptx_kernels);
write_cached_ptx(
ptx_cache_dir(), module_name, ptx, ptx_kernels, source_code);
}
// Load module.

View File

@@ -237,8 +237,7 @@ void LayerNorm::eval_gpu(
}
return x;
} else {
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
array x_copy = contiguous_copy_gpu(x, s);
out.copy_shared_buffer(x_copy);
return x_copy;
}
@@ -295,9 +294,7 @@ void LayerNormVJP::eval_gpu(
return x;
}
copied = true;
array x_copy(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
return x_copy;
return contiguous_copy_gpu(x, s);
};
bool donate_x = inputs[0].is_donatable();
bool donate_g = inputs[3].is_donatable();

View File

@@ -108,8 +108,7 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
return x;
} else {
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
array x_copy = contiguous_copy_gpu(x, s);
encoder.add_temporary(x_copy);
return x_copy;
}

View File

@@ -27,6 +27,35 @@ void check_cublas_error(const char* name, cublasStatus_t err) {
}
}
struct CublasPreference {
CublasPreference(Device& device) {
// The recommended cublas workspace size is 4 MiB for pre-Hopper and 32 MiB
// for Hopper+:
// https://docs.nvidia.com/cuda/cublas/#cublassetworkspace
uint64_t MiB = 1024 * 1024;
uint64_t workspace_size =
device.compute_capability_major() >= 9 ? 32 * MiB : 4 * MiB;
CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceCreate(&pref_));
CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceSetAttribute(
pref_,
CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
&workspace_size,
sizeof(uint64_t)));
}
~CublasPreference() {
CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceDestroy(pref_));
}
cublasLtMatmulPreference_t pref_{nullptr};
};
cublasLtMatmulPreference_t cublas_preference(Device& device) {
static CublasPreference pref(device);
return pref.pref_;
}
class MatMul {
public:
MatMul(
@@ -43,7 +72,7 @@ class MatMul {
int32_t batch_count,
int64_t a_batch_stride,
int64_t b_batch_stride)
: handle_(device.lt_handle()) {
: handle_(device.lt_handle()), pref_(cublas_preference(device)) {
heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED;
auto scale_type = dtype_to_cuda_type(dtype);
@@ -77,20 +106,6 @@ class MatMul {
type, b_rows, b_cols, b_transposed, ldb, batch_count, b_batch_stride);
out_desc_ = create_matrix_layout(
type, a_rows, b_cols, false, b_cols, batch_count, a_rows * b_cols);
// The recommended cublas workspace size is 4 MiB for pre-Hopper and 32 MiB
// for Hopper+:
// https://docs.nvidia.com/cuda/cublas/#cublassetworkspace
uint64_t MiB = 1024 * 1024;
uint64_t workspace_size =
device.compute_capability_major() >= 9 ? 32 * MiB : 4 * MiB;
CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceCreate(&pref_));
CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceSetAttribute(
pref_,
CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
&workspace_size,
sizeof(uint64_t)));
}
MatMul(
@@ -104,7 +119,6 @@ class MatMul {
uint64_t b_rows,
uint64_t b_cols,
int64_t ldb,
bool c_transposed,
int64_t ldc,
int32_t batch_count,
int64_t a_batch_stride,
@@ -126,15 +140,15 @@ class MatMul {
b_batch_stride) {
auto type = dtype_to_cuda_type(dtype);
c_desc_ = create_matrix_layout(
type, a_rows, b_cols, c_transposed, ldc, batch_count, c_batch_stride);
type, a_rows, b_cols, false, ldc, batch_count, c_batch_stride);
}
~MatMul() {
cublasLtMatrixLayoutDestroy(a_desc_);
cublasLtMatrixLayoutDestroy(b_desc_);
cublasLtMatrixLayoutDestroy(c_desc_);
cublasLtMatrixLayoutDestroy(out_desc_);
cublasLtMatmulDescDestroy(matmul_desc_);
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(a_desc_));
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(b_desc_));
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(c_desc_));
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(out_desc_));
CHECK_CUBLAS_ERROR(cublasLtMatmulDescDestroy(matmul_desc_));
}
void run(
@@ -259,9 +273,9 @@ class MatMul {
return desc;
}
cublasLtMatmulPreference_t pref_{nullptr};
cublasLtHandle_t handle_{nullptr};
cublasLtMatmulDesc_t matmul_desc_{nullptr};
cublasLtMatmulPreference_t pref_{nullptr};
cublasLtMatrixLayout_t a_desc_{nullptr};
cublasLtMatrixLayout_t b_desc_{nullptr};
cublasLtMatrixLayout_t c_desc_{nullptr};
@@ -282,8 +296,7 @@ check_transpose(cu::CommandEncoder& enc, const Stream& s, const array& arr) {
} else if (stx == 1 && sty == arr.shape(-2)) {
return std::make_tuple(true, sty, arr);
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_gpu(arr, arr_copy, CopyType::General, s);
array arr_copy = contiguous_copy_gpu(arr, s);
enc.add_temporary(arr_copy);
return std::make_tuple(false, arr.shape(-1), arr_copy);
}
@@ -389,9 +402,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 3);
auto& a_pre = inputs[0];
auto& b_pre = inputs[1];
auto& c_pre = inputs[2];
out.set_data(allocator::malloc(out.nbytes()));
auto c = inputs[2];
/////////////////////////////////////////////////////////////////////////////
// Init checks and prep
@@ -404,7 +415,24 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
// the arrays
auto [a_transposed, lda, a] = check_transpose(encoder, s, a_pre);
auto [b_transposed, ldb, b] = check_transpose(encoder, s, b_pre);
auto [c_transposed, ldc, c] = check_transpose(encoder, s, c_pre);
int64_t ldc;
{
auto stx = c.strides()[c.ndim() - 2];
auto sty = c.strides()[c.ndim() - 1];
if (sty == 1 && stx == c.shape(-1)) {
ldc = stx;
out.set_data(allocator::malloc(out.nbytes()));
} else if (sty == 1 && stx == 0) {
ldc = 0;
out.set_data(allocator::malloc(out.nbytes()));
} else {
// Copy C into out and set C to out
ldc = c.shape(-1);
copy_gpu(c, out, CopyType::General, s);
c = out;
}
}
/////////////////////////////////////////////////////////////////////////////
// Check and collapse batch dimensions
@@ -442,7 +470,6 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
K,
N,
ldb,
c_transposed,
ldc,
batch_shape.back(),
a_batch_strides.back(),

View File

@@ -36,7 +36,8 @@ affine_quantize(const T* w, uint8_t* out, T* scales, T* biases, size_t size) {
auto tidx = block_idx.x * block_size.x + idx_in_block.x;
auto tidy = block_idx.y * block_size.y + idx_in_block.y;
auto grid_dim = cg::this_grid().dim_threads();
auto grid_dim_x =
cg::this_grid().dim_blocks().x * cg::this_grid().block_index().x;
constexpr float eps = 1e-7;
constexpr int simd_size = WARP_SIZE;
constexpr float n_bins = (1 << bits) - 1;
@@ -48,7 +49,7 @@ affine_quantize(const T* w, uint8_t* out, T* scales, T* biases, size_t size) {
writes_per_reduce > 1 ? 1 : values_per_reduce / pack_factor;
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
size_t offset = tidx + grid_dim.x * size_t(tidy);
size_t offset = tidx + grid_dim_x * size_t(tidy);
size_t in_index = offset * values_per_reduce;
if (in_index >= size) {
return;
@@ -153,12 +154,13 @@ __global__ void affine_dequantize(
auto tidx = block_idx.x * block_size.x + idx_in_block.x;
auto tidy = block_idx.y * block_size.y + idx_in_block.y;
auto grid_dim = cg::this_grid().dim_threads();
auto grid_dim_x =
cg::this_grid().dim_blocks().x * cg::this_grid().block_index().x;
constexpr int pack_factor = get_pack_factor<bits, 8>();
constexpr int bytes_per_pack = get_bytes_per_pack<bits>();
size_t offset = tidx + grid_dim.x * size_t(tidy);
size_t offset = tidx + grid_dim_x * size_t(tidy);
size_t oindex = offset * pack_factor;
if (oindex >= size) {
@@ -245,8 +247,7 @@ inline array ensure_row_contiguous(
cu::CommandEncoder& enc,
const Stream& s) {
if (!x.flags().row_contiguous) {
array x_copy(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
array x_copy = contiguous_copy_gpu(x, s);
enc.add_temporary(x_copy);
return x_copy;
} else {
@@ -349,7 +350,8 @@ void fast::AffineQuantize::eval_gpu(
dispatch_bits(bits_, [&](auto bits) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
if (dequantize_) {
auto kernel = cu::affine_dequantize<DataType, group_size(), bits()>;
auto kernel =
cu::affine_dequantize<DataType, group_size.value, bits.value>;
auto [num_blocks, block_dims] =
get_launch_args(kernel, size, grid_shape, w.strides(), large);
enc.add_kernel_node(
@@ -362,7 +364,8 @@ void fast::AffineQuantize::eval_gpu(
out.data<DataType>(),
out.size());
} else {
auto kernel = cu::affine_quantize<DataType, group_size(), bits()>;
auto kernel =
cu::affine_quantize<DataType, group_size.value, bits.value>;
auto [num_blocks, block_dims] =
get_launch_args(kernel, size, grid_shape, w.strides(), large);
enc.add_kernel_node(

View File

@@ -47,8 +47,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
}
}
if (plan.type == GeneralReduce || broadcasted || !in.flags().contiguous) {
array in_copy(in.shape(), in.dtype(), nullptr, {});
copy_gpu(in, in_copy, CopyType::General, s);
array in_copy = contiguous_copy_gpu(in, s);
encoder.add_temporary(in_copy);
in = in_copy;
plan = get_reduction_plan(in, axes_);

View File

@@ -206,8 +206,7 @@ void RMSNorm::eval_gpu(
}
return x;
} else {
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
array x_copy = contiguous_copy_gpu(x, s);
out.copy_shared_buffer(x_copy);
return x_copy;
}
@@ -259,9 +258,7 @@ void RMSNormVJP::eval_gpu(
return x;
}
copied = true;
array x_copy(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
return x_copy;
return contiguous_copy_gpu(x, s);
};
bool donate_x = inputs[0].is_donatable();
bool donate_g = inputs[2].is_donatable();

View File

@@ -379,9 +379,7 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
in.flags());
}
} else {
array arr_copy(in.shape(), in.dtype(), nullptr, {});
copy_gpu(in, arr_copy, CopyType::General, s);
in = std::move(arr_copy);
in = contiguous_copy_gpu(in, s);
out.copy_shared_buffer(in);
}

View File

@@ -125,8 +125,7 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
}
return x;
} else {
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
array x_copy = contiguous_copy_gpu(x, s);
out.copy_shared_buffer(x_copy);
return x_copy;
}

View File

@@ -72,8 +72,7 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
bool is_segmented_sort = in.flags().contiguous && in.strides()[axis] == 1;
if (!is_segmented_sort) {
array trans = swapaxes_in_eval(in, axis, last_dim);
in = array(trans.shape(), trans.dtype(), nullptr, {});
copy_gpu(trans, in, CopyType::General, s);
in = contiguous_copy_gpu(trans, s);
encoder.add_temporary(in);
out = array(allocator::malloc(out.nbytes()), in.shape(), out.dtype());
encoder.add_temporary(out);

View File

@@ -46,4 +46,10 @@ void copy_gpu_inplace(
in, out, in.shape(), i_strides, out.strides(), i_offset, 0, ctype, s);
}
array contiguous_copy_gpu(const array& arr, const Stream& s) {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_gpu(arr, arr_copy, CopyType::General, s);
return arr_copy;
}
} // namespace mlx::core

View File

@@ -43,4 +43,7 @@ void copy_gpu_inplace(
// Fill the output with the scalar val
void fill_gpu(const array& val, array& out, const Stream& s);
// Return a contiguous array with same shape that copies the data of |arr|.
array contiguous_copy_gpu(const array& arr, const Stream& s);
} // namespace mlx::core

View File

@@ -149,8 +149,7 @@ void explicit_gemm_conv_group_ND_gpu(
wt, {wt.strides(0), 1, C_per_group}, wt.flags(), wt.size());
// Materialize
auto wt_transpose = array(wt_view.shape(), wt_view.dtype(), nullptr, {});
copy_gpu(wt_view, wt_transpose, CopyType::General, s);
array wt_transpose = contiguous_copy_gpu(wt_view, s);
// Perform gemm
std::vector<array> copies = {in_unfolded, wt_transpose};
@@ -961,16 +960,12 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
auto in = inputs[0];
auto wt = inputs[1];
if (!in.flags().row_contiguous) {
array arr_copy(in.shape(), in.dtype(), nullptr, {});
copy_gpu(in, arr_copy, CopyType::General, s);
copies.push_back(arr_copy);
in = arr_copy;
in = contiguous_copy_gpu(in, s);
copies.push_back(in);
}
if (!wt.flags().row_contiguous) {
array arr_copy(wt.shape(), wt.dtype(), nullptr, {});
copy_gpu(wt, arr_copy, CopyType::General, s);
copies.push_back(arr_copy);
wt = arr_copy;
wt = contiguous_copy_gpu(wt, s);
copies.push_back(wt);
}
// 3D conv

View File

@@ -25,8 +25,7 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
return x;
} else {
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
array x_copy = contiguous_copy_gpu(x, s);
d.add_temporary(x_copy, s.index);
return x_copy;
}

View File

@@ -33,8 +33,7 @@ std::tuple<bool, int64_t, array> check_transpose(
} else if (stx == 1 && (!is_vector || sty == arr.shape(-2))) {
return std::make_tuple(true, sty, arr);
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_gpu(arr, arr_copy, CopyType::General, s);
array arr_copy = contiguous_copy_gpu(arr, s);
copies.push_back(arr_copy);
return std::make_tuple(false, arr.shape(-1), arr_copy);
}
@@ -43,8 +42,7 @@ std::tuple<bool, int64_t, array> check_transpose(
inline array
ensure_row_contiguous(const array& x, metal::Device& d, const Stream& s) {
if (!x.flags().row_contiguous) {
array x_copy(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
array x_copy = contiguous_copy_gpu(x, s);
d.add_temporary(x_copy, s.index);
return x_copy;
} else {
@@ -75,8 +73,7 @@ ensure_batch_contiguous(const array& x, metal::Device& d, const Stream& s) {
}
}
array x_copy(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
array x_copy = contiguous_copy_gpu(x, s);
d.add_temporary(x_copy, s.index);
return std::make_tuple(false, x_copy.strides()[x_copy.ndim() - 2], x_copy);
}
@@ -1894,8 +1891,7 @@ void segmented_mm(
return std::make_tuple(false, x);
}
array x_copy(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
array x_copy = contiguous_copy_gpu(x, s);
d.add_temporary(x_copy, s.index);
return std::make_tuple(true, x_copy);
};

View File

@@ -40,8 +40,7 @@ void RMSNorm::eval_gpu(
}
return x;
} else {
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
array x_copy = contiguous_copy_gpu(x, s);
out.copy_shared_buffer(x_copy);
return x_copy;
}
@@ -107,9 +106,7 @@ void RMSNormVJP::eval_gpu(
if (x.flags().row_contiguous) {
return {x, false};
}
array x_copy(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
array x_copy = contiguous_copy_gpu(x, s);
return {x_copy, true};
};
bool donate_x = inputs[0].is_donatable();
@@ -241,8 +238,7 @@ void LayerNorm::eval_gpu(
}
return x;
} else {
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
array x_copy = contiguous_copy_gpu(x, s);
out.copy_shared_buffer(x_copy);
return x_copy;
}
@@ -319,8 +315,7 @@ void LayerNormVJP::eval_gpu(
if (x.flags().row_contiguous) {
return {x, false};
}
array x_copy(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
array x_copy = contiguous_copy_gpu(x, s);
return {x_copy, true};
};
bool donate_x = inputs[0].is_donatable();

View File

@@ -20,8 +20,7 @@ namespace {
inline array
ensure_row_contiguous(const array& x, metal::Device& d, const Stream& s) {
if (!x.flags().row_contiguous) {
array x_copy(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
array x_copy = contiguous_copy_gpu(x, s);
d.add_temporary(x_copy, s.index);
return x_copy;
} else {
@@ -38,8 +37,7 @@ inline array ensure_row_contiguous_matrix(
if (stride_0 == x.shape(-1) && stride_1 == 1) {
return x;
} else {
array x_copy(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
array x_copy = contiguous_copy_gpu(x, s);
d.add_temporary(x_copy, s.index);
return x_copy;
}

View File

@@ -989,8 +989,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
// input for the axes with stride smaller than the minimum reduction
// stride.
if (plan.type == GeneralReduce) {
array in_copy(in.shape(), in.dtype(), nullptr, {});
copy_gpu(in, in_copy, CopyType::General, s);
array in_copy = contiguous_copy_gpu(in, s);
d.add_temporary(in_copy, s.index);
in = in_copy;
plan = get_reduction_plan(in, axes_);

View File

@@ -398,8 +398,7 @@ void ScaledDotProductAttention::eval_gpu(
auto copy_unless = [&copies, &s](
auto predicate, const array& arr) -> const array& {
if (!predicate(arr)) {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_gpu(arr, arr_copy, CopyType::General, s);
array arr_copy = contiguous_copy_gpu(arr, s);
copies.push_back(std::move(arr_copy));
return copies.back();
} else {

View File

@@ -30,9 +30,7 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
in.flags());
}
} else {
array arr_copy(in.shape(), in.dtype(), nullptr, {});
copy_gpu(in, arr_copy, CopyType::General, s);
in = std::move(arr_copy);
in = contiguous_copy_gpu(in, s);
out.copy_shared_buffer(in);
}

View File

@@ -35,8 +35,7 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
}
return x;
} else {
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
array x_copy = contiguous_copy_gpu(x, s);
out.copy_shared_buffer(x_copy);
return x_copy;
}

View File

@@ -4,7 +4,7 @@
#define MLX_VERSION_MAJOR 0
#define MLX_VERSION_MINOR 26
#define MLX_VERSION_PATCH 3
#define MLX_VERSION_PATCH 5
#define MLX_VERSION_NUMERIC \
(100000 * MLX_VERSION_MAJOR + 1000 * MLX_VERSION_MINOR + MLX_VERSION_PATCH)

View File

@@ -848,6 +848,106 @@ class Adafactor(Optimizer):
return parameter - update
class Muon(Optimizer):
r"""The Muon optimizer.
Our Muon (MomentUm Orthogonalized by Newton-schulz) optimizer follows the
original implementation: `Muon: An optimizer for hidden layers in neural
networks <https://kellerjordan.github.io/posts/muon/>`_
Note:
- Muon may be sub-optimal for the embedding layer, the final fully
connected layer, or any 0D/1D parameters. Those should be optimized
by a different method (e.g., :class:`AdamW`).
- For 4D convolutional filters, it works by flattening their last
dimensions.
Args:
learning_rate (float or callable): The learning rate.
momentum (float, optional): The momentum strength. Default: ``0.95``
weight_decay (float, optional): The weight decay (L2 penalty).
Default: ``0.01``
nesterov (bool, optional): Enables Nesterov momentum. Recommended for
better performance. Default: ``True``
ns_steps (int, optional): Number of Newton-Schulz iteration steps for
orthogonalization. Default: ``5``
"""
def __init__(
self,
learning_rate: Union[float, Callable[[mx.array], mx.array]],
momentum: float = 0.95,
weight_decay: float = 0.01,
nesterov: bool = True,
ns_steps: int = 5,
):
super().__init__()
self._maybe_schedule("learning_rate", learning_rate)
self.momentum = momentum
self.weight_decay = weight_decay
self.nesterov = nesterov
self.ns_steps = ns_steps
def init_single(self, parameter: mx.array, state: dict):
"""Initialize optimizer state"""
state["v"] = mx.zeros_like(parameter)
def _zeropower_via_newtonschulz5(self, X, steps: int):
assert (
X.ndim == 2
), f"Expected a 2D array for Newton-Schulz iteration, got shape {X.shape} instead."
a, b, c = (3.4445, -4.7750, 2.0315)
transpose_needed = X.shape[-2] > X.shape[-1]
if transpose_needed:
X = X.T
X = X / (mx.linalg.norm(X, keepdims=True) + 1e-7)
for _ in range(steps):
A = X @ X.T
B = mx.addmm(b * A, A, A, beta=1.0, alpha=c)
X = mx.addmm(a * X, B, X, beta=1.0, alpha=1.0)
if transpose_needed:
X = X.T
return X
def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
"""Performs the Muon parameter update"""
if self.weight_decay != 0:
gradient = gradient + self.weight_decay * parameter
v = self.momentum * state["v"]
v = v + (1 - self.momentum) * gradient
state["v"] = v
if self.nesterov:
update = gradient * (1 - self.momentum) + v * self.momentum
else:
update = v
lr = self.learning_rate.astype(gradient.dtype)
if update.ndim >= 2:
original_shape = update.shape
reshape_needed = update.ndim > 2
if reshape_needed:
update = mx.reshape(update, (update.shape[0], -1))
update = self._zeropower_via_newtonschulz5(update, steps=self.ns_steps)
if reshape_needed:
update = mx.reshape(update, original_shape)
lr *= max(1, update.shape[-2] / update.shape[-1]) ** 0.5
return parameter - lr * update
def clip_grad_norm(grads, max_norm):
"""Clips the global norm of the gradients.

View File

@@ -4,6 +4,7 @@ import unittest
import mlx.core as mx
import mlx_distributed_tests
import mlx_tests
class TestRingDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):

View File

@@ -691,6 +691,21 @@ class TestBlas(mlx_tests.MLXTestCase):
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
# Transposed c
a = mx.ones((10, 5)).T
b = mx.ones((5, 5))
out = mx.addmm(a, b, a, beta=1.5, alpha=0.5)
expected = 1.5 * a + 0.5 * (b @ a)
self.assertTrue(mx.allclose(expected, out))
# Broadcast c
a = mx.ones((5, 5))
b = mx.ones((5, 5))
c = mx.ones((1, 5))
out = mx.addmm(c, a, b, beta=1.5, alpha=0.5)
expected = 1.5 * c + 0.5 * (a @ b)
self.assertTrue(mx.allclose(expected, out))
def test_addmm_grad(self):
def make_ref_addmm(alpha, beta):
return lambda c, a, b: alpha * (a @ b) + beta * c

View File

@@ -286,6 +286,53 @@ class TestOptimizers(mlx_tests.MLXTestCase):
self.assertEqual(xp["x"].shape, x.shape)
self.assertEqual(optimizer.state["step"], 2)
def test_muon(self):
params = {
"first": [mx.zeros((10, 5)), mx.zeros((1,))],
"second": mx.zeros((3, 3)),
"conv": mx.zeros((16, 8, 3, 3)),
}
grads = tree_map(lambda x: mx.ones_like(x), params)
# Explicit init
optim = opt.Muon(learning_rate=1e-2, momentum=0.95, nesterov=True)
optim.init(params)
self.assertTrue(
tree_equal(
lambda p, s: mx.array_equal(s["v"], mx.zeros_like(p)),
params,
optim.state,
)
)
# Test update
updated_params = optim.apply_gradients(grads, params)
# Check that shapes are preserved
self.assertTrue(
tree_equal(
lambda p, u: p.shape == u.shape,
params,
updated_params,
)
)
# Check that parameters actually changed
self.assertFalse(
tree_equal(
lambda p, u: mx.array_equal(p, u),
params,
updated_params,
)
)
# Test with different configurations
optim_no_nesterov = opt.Muon(learning_rate=1e-2, momentum=0.95, nesterov=False)
optim_no_nesterov.apply_gradients(grads, params)
optim_no_momentum = opt.Muon(learning_rate=1e-2, momentum=0.0)
optim_no_momentum.apply_gradients(grads, params)
def test_compiled_optimizer(self):
model = nn.Linear(10, 10)
x = mx.random.uniform(shape=(2, 10))

View File

@@ -39,6 +39,14 @@ target_sources(
linalg_tests.cpp
${METAL_TEST_SOURCES})
if(MLX_BUILD_CUDA)
# Find the CCCL headers in install dir.
target_compile_definitions(
mlx
PRIVATE
MLX_CCCL_DIR="${CMAKE_INSTALL_PREFIX}/${CMAKE_INSTALL_INCLUDEDIR}/cccl")
endif()
target_link_libraries(tests PRIVATE mlx doctest)
doctest_discover_tests(tests)
add_test(NAME tests COMMAND tests)