Compare commits

...

5 Commits

Author  SHA1        Message                                         Date
Cheng   940f4c7818  Fix building with CUDA < 12.8 (#2782)           2025-11-18 12:55:19 +09:00
Cheng   35f81728f1  Remove unneeded tests in nightly build (#2786)  2025-11-18 08:09:58 +09:00
Cheng   4442ed86c1  Fix nightly build (#2785)                       2025-11-18 08:07:51 +09:00
Cheng   698559c231  Test every commit in main branch (#2781)        2025-11-18 08:07:22 +09:00
Cheng   ecc4879b07  Do not run CPU tests in CUDA builds (#2784)     2025-11-18 07:27:09 +09:00
42 changed files with 96 additions and 156 deletions

View File

@@ -35,7 +35,7 @@ runs:
       run: |
         python -m venv .venv
         source .venv/bin/activate
-        pip install cmake nanobind==2.4.0
+        pip install setuptools cmake nanobind==2.4.0
         echo PATH=$PATH >> $GITHUB_ENV
         # Make cmake search .venv for nanobind
         echo PYTHONPATH=`python -c 'import sys; print(sys.path[-1])'` >> $GITHUB_ENV
@@ -56,7 +56,6 @@ runs:
       PACKAGES: |
         {
           "cuda-12.6": "libcudnn9-dev-cuda-12 cuda-toolkit-12-6",
-          "cuda-12.8": "libcudnn9-dev-cuda-12 cuda-toolkit-12-8",
           "cuda-12.9": "libcudnn9-dev-cuda-12 cuda-toolkit-12-9",
           "cuda-13.0": "libcudnn9-dev-cuda-13 cuda-toolkit-13-0"
         }

View File

@@ -30,6 +30,7 @@ runs:
         echo "::endgroup::"
     - name: Run Python tests - CPU
+      if: ${{ inputs.cpu-only == 'true' }}
       shell: bash
       env:
         DEVICE: cpu

View File

@@ -40,7 +40,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
+        python_version: ["3.11", "3.12", "3.13", "3.14"]
         runner:
           - ubuntu-22.04
           - ubuntu-22.04-arm
@@ -78,23 +78,6 @@ jobs:
           macos-target: 14.0
           build-backend: ${{ matrix.python-version == '3.10' }}
-  build_cuda_with_tests:
-    if: github.repository == 'ml-explore/mlx'
-    strategy:
-      fail-fast: false
-      matrix:
-        toolkit: ['cuda-12.8', 'cuda-12.9']
-    runs-on: gpu-t4-4-core
-    steps:
-      - uses: actions/checkout@v5
-      - uses: ./.github/actions/setup-linux
-        with:
-          toolkit: ${{ matrix.toolkit }}
-      - uses: ./.github/actions/build-cuda
-        with:
-          toolkit: ${{ matrix.toolkit }}
-      - uses: ./.github/actions/test-linux
   build_cuda_release:
     if: github.repository == 'ml-explore/mlx'
     runs-on: ubuntu-22-large
@@ -113,25 +96,3 @@ jobs:
           name: mlx-cuda
           path: wheelhouse/mlx_cuda-*.whl
           retention-days: 7
-  linux_fedora_build_cpp:
-    name: Linux Fedora CPP Build (${{ matrix.arch }})
-    strategy:
-      fail-fast: false
-      matrix:
-        include:
-          - host: ubuntu-22.04
-            arch: x86_64
-          - host: ubuntu-22.04-arm
-            arch: aarch64
-    runs-on: ${{ matrix.host }}
-    container:
-      image: fedora:42
-    steps:
-      - name: Checkout code
-        uses: actions/checkout@v5
-      - name: CPP Build Test - No Release
-        run: |
-          bash ./.github/scripts/setup+build-cpp-linux-fedora-container.sh

View File

@@ -7,7 +7,7 @@ permissions:
 concurrency:
   group: ${{ github.workflow }}-${{ github.ref }}
-  cancel-in-progress: true
+  cancel-in-progress: ${{ github.ref != 'refs/head/main' }}
 jobs:
   check_lint:
@@ -52,7 +52,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        toolkit: ['cuda-12.8', 'cuda-12.9']
+        toolkit: ['cuda-12.6', 'cuda-12.9']
     runs-on: gpu-t4-4-core
     needs: check_lint
     steps:

View File

@@ -142,6 +142,7 @@ FetchContent_Declare(
   URL "https://github.com/NVIDIA/cccl/releases/download/v2.8.1/cccl-v2.8.1.zip")
 FetchContent_MakeAvailable(cccl)
 target_include_directories(mlx BEFORE PRIVATE "${cccl_SOURCE_DIR}/include")
+set_target_properties(mlx PROPERTIES CCCL_DIR "${cccl_SOURCE_DIR}/include")
 # Use fixed version of NVTX.
 FetchContent_Declare(

View File

@@ -119,7 +119,8 @@ void copy_to_managed(CudaBuffer& buf) {
   buf.data = new_data;
 }
-Buffer CudaAllocator::malloc_impl(size_t size, cudaStream_t stream) {
+Buffer
+CudaAllocator::malloc_async(size_t size, int device, cudaStream_t stream) {
   if (size == 0) {
     return Buffer{new CudaBuffer{nullptr, 0, -1}};
   }
@@ -134,9 +135,8 @@ Buffer CudaAllocator::malloc_impl(size_t size, cudaStream_t stream) {
     size = page_size * ((size + page_size - 1) / page_size);
   }
-  int device = -1;
-  if (size > small_block_size && stream != nullptr) {
-    CHECK_CUDA_ERROR(cudaStreamGetDevice(stream, &device));
+  if (size <= small_block_size || stream == nullptr) {
+    device = -1;
   }
   CudaBuffer* buf = buffer_cache_.reuse_from_cache(size);
@@ -182,12 +182,8 @@ Buffer CudaAllocator::malloc_impl(size_t size, cudaStream_t stream) {
   return Buffer{buf};
 }
-Buffer CudaAllocator::malloc_async(size_t size, cudaStream_t stream) {
-  return malloc_impl(size, stream);
-}
 Buffer CudaAllocator::malloc(size_t size) {
-  return malloc_impl(size, nullptr);
+  return malloc_async(size, -1, nullptr);
 }
 void CudaAllocator::free(Buffer buffer) {
@@ -277,8 +273,9 @@ CudaAllocator& allocator() {
   return *allocator_;
 }
-Buffer malloc_async(size_t size, cudaStream_t stream) {
-  auto buffer = allocator().malloc_async(size, stream);
+Buffer malloc_async(size_t size, CommandEncoder& encoder) {
+  auto buffer = allocator().malloc_async(
+      size, encoder.device().cuda_device(), encoder.stream());
   if (size && !buffer.ptr()) {
     std::ostringstream msg;
     msg << "[malloc_async] Unable to allocate " << size << " bytes.";

View File

@@ -13,6 +13,8 @@
 namespace mlx::core::cu {
+class CommandEncoder;
+
 using allocator::Buffer;
 // Stores cuda-managed unified memory.
@@ -48,7 +50,7 @@ class SmallSizePool {
 class CudaAllocator : public allocator::Allocator {
  public:
   Buffer malloc(size_t size) override;
-  Buffer malloc_async(size_t size, cudaStream_t stream);
+  Buffer malloc_async(size_t size, int device, cudaStream_t stream);
   void free(Buffer buffer) override;
   size_t size(Buffer buffer) const override;
@@ -62,7 +64,6 @@ class CudaAllocator : public allocator::Allocator {
   void clear_cache();
  private:
-  Buffer malloc_impl(size_t size, cudaStream_t stream);
   void cuda_free(CudaBuffer* buf);
   CudaAllocator();
@@ -80,6 +81,6 @@ class CudaAllocator : public allocator::Allocator {
 CudaAllocator& allocator();
-Buffer malloc_async(size_t size, cudaStream_t stream);
+Buffer malloc_async(size_t size, CommandEncoder& encoder);
 } // namespace mlx::core::cu
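Taken together, the .cpp and header hunks above replace the old stream-keyed entry point with a device-aware one. The following is a condensed sketch of the refactored interface, assembled from the diff for readability; buffer bookkeeping, caching, and error handling are elided, so it is not a verbatim copy of the header.

// Condensed sketch of the refactored allocator API (from the hunks above).
namespace mlx::core::cu {

class CommandEncoder; // forward declaration keeps this header lightweight

class CudaAllocator : public allocator::Allocator {
 public:
  // Now forwards to malloc_async(size, -1, nullptr).
  Buffer malloc(size_t size) override;
  // The target device is passed explicitly; the old malloc_impl queried it
  // from the stream with cudaStreamGetDevice on every call.
  Buffer malloc_async(size_t size, int device, cudaStream_t stream);
};

// Convenience overload used at kernel call sites: the encoder already
// knows both its CUDA device and its stream.
Buffer malloc_async(size_t size, CommandEncoder& encoder);

} // namespace mlx::core::cu

Every call site in the files below then shrinks from cu::malloc_async(n, encoder.stream()) to cu::malloc_async(n, encoder).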

View File

@@ -42,7 +42,7 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
     return;
   }
   auto& encoder = cu::get_command_encoder(stream());
-  out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
+  out.set_data(cu::malloc_async(out.nbytes(), encoder));
   encoder.set_output_array(out);
   dispatch_int_float_types(out.dtype(), "Arange", [&](auto type_tag) {

View File

@@ -143,7 +143,7 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
   auto& s = stream();
   auto& encoder = cu::get_command_encoder(s);
-  out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
+  out.set_data(cu::malloc_async(out.nbytes(), encoder));
   // Prepare the shapes, strides and axis arguments.
   Shape shape = remove_index(in.shape(), axis_);

View File

@@ -367,9 +367,8 @@ void binary_op_gpu(
   auto bopt = get_binary_op_type(a, b);
   auto& encoder = cu::get_command_encoder(s);
-  set_binary_op_output_data(a, b, out, bopt, [&](auto n) {
-    return cu::malloc_async(n, encoder.stream());
-  });
+  set_binary_op_output_data(
+      a, b, out, bopt, [&](auto n) { return cu::malloc_async(n, encoder); });
   binary_op_gpu_inplace<Op>(inputs, out, op, s);
 }
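The same mechanical rewrite repeats through the remaining kernels: wherever an output-allocation helper accepts a callback, the lambda now closes over the encoder instead of its raw stream. Below is a self-contained illustration of that callback pattern with hypothetical stand-in types (Encoder here is not the real cu::CommandEncoder, and plain malloc stands in for the CUDA allocation so the sketch runs anywhere).

// Illustration of the allocation-callback pattern used in the hunks above.
#include <cstdio>
#include <cstdlib>

struct Encoder {           // stand-in for cu::CommandEncoder
  int device = 0;          // the encoder knows its device...
  void* stream = nullptr;  // ...and its stream
};

void* malloc_async(size_t n, Encoder& enc) {
  // A real backend would allocate asynchronously on enc.stream for
  // enc.device; plain malloc keeps this sketch portable.
  std::printf("alloc %zu bytes on device %d\n", n, enc.device);
  return std::malloc(n);
}

// Output helpers take an allocation callback so the caller decides where
// the memory comes from.
template <typename F>
void* set_output_data(size_t nbytes, F&& alloc) {
  return alloc(nbytes);
}

int main() {
  Encoder enc;
  // Call sites in the diff now read like this: the lambda captures the
  // encoder rather than a bare cudaStream_t.
  void* out =
      set_output_data(256, [&](auto n) { return malloc_async(n, enc); });
  std::free(out);
}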

View File

@@ -246,12 +246,10 @@ void binary_two_op_gpu_inplace(
   auto& out_b = outputs[1];
   auto bopt = get_binary_op_type(a, b);
   auto& encoder = cu::get_command_encoder(s);
-  set_binary_op_output_data(a, b, out_a, bopt, [&](auto n) {
-    return cu::malloc_async(n, encoder.stream());
-  });
-  set_binary_op_output_data(a, b, out_b, bopt, [&](auto n) {
-    return cu::malloc_async(n, encoder.stream());
-  });
+  set_binary_op_output_data(
+      a, b, out_a, bopt, [&](auto n) { return cu::malloc_async(n, encoder); });
+  set_binary_op_output_data(
+      a, b, out_b, bopt, [&](auto n) { return cu::malloc_async(n, encoder); });
   if (out_a.size() == 0) {
     return;

View File

@@ -298,7 +298,7 @@ void Compiled::eval_gpu(
   // Put outputs.
   compiled_allocate_outputs(
       inputs, outputs, is_constant_, contiguous, [&](auto n) {
-        return cu::malloc_async(n, encoder.stream());
+        return cu::malloc_async(n, encoder);
       });
   for (auto& x : outputs) {
     args.append(x);

View File

@@ -277,7 +277,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
   array in = inputs[0];
   array wt = inputs[1];
   array out = out_;
-  out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
+  out.set_data(cu::malloc_async(out.nbytes(), encoder));
   Dtype dtype = out.dtype();
   // Search cache.

View File

@@ -86,7 +86,7 @@ array unfold_inputs_nd(
     int mat_N,
     ConvParams<NDIM>& params) {
   array unfolded({mat_M, mat_K}, in.dtype(), nullptr, {});
-  unfolded.set_data(cu::malloc_async(unfolded.nbytes(), encoder.stream()));
+  unfolded.set_data(cu::malloc_async(unfolded.nbytes(), encoder));
   encoder.add_temporary(unfolded);
   int filter_size = params.C;

View File

@@ -89,7 +89,7 @@ array grouped_unfold_transpose_inputs_nd(
     int mat_N,
     ConvParams<NDIM>& params) {
   array unfolded({mat_M, mat_K * params.groups}, in.dtype(), nullptr, {});
-  unfolded.set_data(cu::malloc_async(unfolded.nbytes(), encoder.stream()));
+  unfolded.set_data(cu::malloc_async(unfolded.nbytes(), encoder));
   encoder.add_temporary(unfolded);
   int filter_size = params.C;

View File

@@ -7,9 +7,8 @@ namespace mlx::core {
 void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
   auto& encoder = cu::get_command_encoder(s);
-  bool donated = set_copy_output_data(in, out, ctype, [&](auto n) {
-    return cu::malloc_async(n, encoder.stream());
-  });
+  bool donated = set_copy_output_data(
+      in, out, ctype, [&](auto n) { return cu::malloc_async(n, encoder); });
   if (donated && in.dtype() == out.dtype()) {
     // If the output has the same type as the input then there is nothing to
     // copy, just use the buffer.
@@ -104,7 +103,7 @@ void fill_gpu(const array& in, array& out, const Stream& s) {
     return;
   }
   auto& encoder = cu::get_command_encoder(s);
-  out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
+  out.set_data(cu::malloc_async(out.nbytes(), encoder));
   encoder.set_input_array(in);
   encoder.set_output_array(out);
   copy_contiguous(encoder, CopyType::Scalar, in, out, 0, 0);
@@ -114,7 +113,7 @@ void reshape_gpu(const array& in, array& out, Stream s) {
   auto [copy_necessary, out_strides] = prepare_reshape(in, out);
   if (copy_necessary) {
     auto& encoder = cu::get_command_encoder(s);
-    out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
+    out.set_data(cu::malloc_async(out.nbytes(), encoder));
     copy_gpu_inplace(
         in,
         out,

View File

@@ -135,9 +135,7 @@ bool prepare_cudnn_plan(
   void* workspace_ptr = nullptr;
   if (workspace_size > 0) {
     array workspace(
-        cu::malloc_async(workspace_size, encoder.stream()),
-        {workspace_size},
-        uint8);
+        cu::malloc_async(workspace_size, encoder), {workspace_size}, uint8);
     encoder.add_temporary(workspace);
     workspace_ptr = gpu_ptr<void>(workspace);
   }

View File

@@ -289,7 +289,7 @@ void CustomKernel::eval_gpu(
       copies.emplace_back(init_value_.value(), out.dtype());
       fill_gpu(copies.back(), out, s);
     } else {
-      out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
+      out.set_data(cu::malloc_async(out.nbytes(), encoder));
     }
   }

View File

@@ -26,7 +26,7 @@ void AllReduce::eval_gpu(
       out.copy_shared_buffer(in);
       return {in, out};
     } else {
-      out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
+      out.set_data(cu::malloc_async(out.nbytes(), encoder));
       return {in, out};
     }
   };
@@ -74,7 +74,7 @@ void AllGather::eval_gpu(
   };
   auto input = ensure_contiguous(inputs[0]);
-  outputs[0].set_data(cu::malloc_async(outputs[0].nbytes(), encoder.stream()));
+  outputs[0].set_data(cu::malloc_async(outputs[0].nbytes(), encoder));
   encoder.set_input_array(input);
   encoder.set_output_array(outputs[0]);
@@ -103,7 +103,7 @@ void ReduceScatter::eval_gpu(
   };
   auto input = ensure_contiguous(inputs[0]);
-  outputs[0].set_data(cu::malloc_async(outputs[0].nbytes(), encoder.stream()));
+  outputs[0].set_data(cu::malloc_async(outputs[0].nbytes(), encoder));
   encoder.set_input_array(input);
   encoder.set_output_array(outputs[0]);

View File

@@ -370,7 +370,7 @@ void CublasGemm::execute(
   // Ensure workspace is 256-byte aligned
   int nbytes = cuda::ceil_div(heuristic_.workspaceSize, 256) * 256;
   array workspace(
-      cu::malloc_async(nbytes, encoder.stream()),
+      cu::malloc_async(nbytes, encoder),
       {static_cast<int>(heuristic_.workspaceSize)},
       int8);
   encoder.add_temporary(workspace);

View File

@@ -163,7 +163,7 @@ void CublasGemm::run_batched(
   // Launch kernel to set device offsets
   auto pointers = array(
-      cu::malloc_async(batch_count * sizeof(void*) * 3, encoder.stream()),
+      cu::malloc_async(batch_count * sizeof(void*) * 3, encoder),
       {batch_count * 3},
       uint64);
@@ -251,7 +251,7 @@ void CublasGemm::run_batched(
   // Launch kernel to set device offsets
   auto pointers = array(
-      cu::malloc_async(batch_count * sizeof(uint64_t) * 4, encoder.stream()),
+      cu::malloc_async(batch_count * sizeof(uint64_t) * 4, encoder),
      {batch_count * 4},
      uint64);

View File

@@ -61,7 +61,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
   auto& s = stream();
   auto& encoder = cu::get_command_encoder(s);
-  out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
+  out.set_data(cu::malloc_async(out.nbytes(), encoder));
   if (out.size() == 0) {
     return;
   }
@@ -241,7 +241,7 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
   auto& s = stream();
   auto& encoder = cu::get_command_encoder(s);
-  out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
+  out.set_data(cu::malloc_async(out.nbytes(), encoder));
   if (out.size() == 0) {
     return;
   }

View File

@@ -244,7 +244,7 @@ void LayerNorm::eval_gpu(
     out.copy_shared_buffer(x);
   } else {
     out.set_data(
-        cu::malloc_async(x.data_size() * x.itemsize(), encoder.stream()),
+        cu::malloc_async(x.data_size() * x.itemsize(), encoder),
         x.data_size(),
         x.strides(),
         x.flags());
@@ -335,7 +335,7 @@ void LayerNormVJP::eval_gpu(
     gx.copy_shared_buffer(g);
     g_in_gx = true;
   } else {
-    gx.set_data(cu::malloc_async(gx.nbytes(), encoder.stream()));
+    gx.set_data(cu::malloc_async(gx.nbytes(), encoder));
   }
   if (g_copied && !g_in_gx) {
     encoder.add_temporary(g);
@@ -355,7 +355,7 @@ void LayerNormVJP::eval_gpu(
       g_in_gw = true;
       gw_temp.copy_shared_buffer(g);
     } else {
-      gw_temp.set_data(cu::malloc_async(gw_temp.nbytes(), encoder.stream()));
+      gw_temp.set_data(cu::malloc_async(gw_temp.nbytes(), encoder));
       encoder.add_temporary(gw_temp);
     }
   }

View File

@@ -32,7 +32,7 @@ void Load::eval_gpu(const std::vector<array>& inputs, array& out) {
   auto& encoder = cu::get_command_encoder(stream());
   auto size = out.size();
   auto nbytes = size * out.itemsize();
-  out.set_data(cu::malloc_async(nbytes, encoder.stream()));
+  out.set_data(cu::malloc_async(nbytes, encoder));
   auto out_ptr = malloc(nbytes);
   reader_->read(static_cast<char*>(out_ptr), nbytes, offset_);
   if (swap_endianness_) {

View File

@@ -115,7 +115,7 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
   auto in = ensure_contiguous(inputs[0]);
   if (in.flags().row_contiguous) {
-    out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
+    out.set_data(cu::malloc_async(out.nbytes(), encoder));
   } else {
     auto n = in.shape(-1);
     auto flags = in.flags();
@@ -130,7 +130,7 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
     }
     flags.col_contiguous = col_contig;
     out.set_data(
-        cu::malloc_async(in.nbytes() / n, encoder.stream()),
+        cu::malloc_async(in.nbytes() / n, encoder),
         in.data_size() / n,
         std::move(strides),
         flags);

View File

@@ -121,7 +121,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
     return;
   }
-  out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
+  out.set_data(cu::malloc_async(out.nbytes(), encoder));
   int M = a_pre.shape(-2);
   int N = b_pre.shape(-1);
@@ -163,7 +163,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   if (beta_ == 1 && a.dtype() != complex64 && c.strides(-1) == 1 &&
       c.data_size() == out.shape(-1)) {
-    out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
+    out.set_data(cu::malloc_async(out.nbytes(), encoder));
     gemm_and_bias(
         encoder,
         M,
@@ -187,10 +187,10 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   auto sty = c.strides()[c.ndim() - 1];
   if (sty == 1 && stx == c.shape(-1)) {
     ldc = stx;
-    out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
+    out.set_data(cu::malloc_async(out.nbytes(), encoder));
   } else if (sty == 1 && stx == 0) {
     ldc = 0;
-    out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
+    out.set_data(cu::malloc_async(out.nbytes(), encoder));
   } else {
     // Copy C into out and set C to out
     ldc = c.shape(-1);

View File

@@ -59,7 +59,7 @@ void fast::Quantize::eval_gpu(
     auto scales = ensure_row_contiguous(inputs[1], enc, s);
     auto& w = outputs[0];
-    w.set_data(cu::malloc_async(w.nbytes(), enc.stream()));
+    w.set_data(cu::malloc_async(w.nbytes(), enc));
     if (mode_ == QuantizationMode::Affine) {
       auto biases = ensure_row_contiguous(inputs[2], enc, s);
@@ -72,11 +72,11 @@ void fast::Quantize::eval_gpu(
     auto& wq = outputs[0];
     auto& scales = outputs[1];
-    wq.set_data(cu::malloc_async(wq.nbytes(), enc.stream()));
-    scales.set_data(cu::malloc_async(scales.nbytes(), enc.stream()));
+    wq.set_data(cu::malloc_async(wq.nbytes(), enc));
+    scales.set_data(cu::malloc_async(scales.nbytes(), enc));
     if (mode_ == QuantizationMode::Affine) {
       auto& biases = outputs[2];
-      biases.set_data(cu::malloc_async(biases.nbytes(), enc.stream()));
+      biases.set_data(cu::malloc_async(biases.nbytes(), enc));
       affine_quantize(w, wq, scales, biases, group_size_, bits_, enc, s);
     } else {
       fp_quantize(w, wq, scales, group_size_, bits_, enc, s);

View File

@@ -145,7 +145,7 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
   uint32_t bytes_per_key = out.itemsize() * elems_per_key;
   auto& s = stream();
   auto& encoder = cu::get_command_encoder(s);
-  out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
+  out.set_data(cu::malloc_async(out.nbytes(), encoder));
   if (out.size() == 0) {
     return;
   }

View File

@@ -66,7 +66,7 @@ void all_reduce(
     Reduce::ReduceType reduce_type) {
   constexpr int N_READS = 8;
-  out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
+  out.set_data(cu::malloc_async(out.nbytes(), encoder));
   auto get_args = [](size_t size, int N) {
     int threads = std::min(512UL, (size + N - 1) / N);
@@ -107,8 +107,7 @@ void all_reduce(
   encoder.set_input_array(in);
   if (blocks > 1) {
     array intermediate({blocks}, out.dtype(), nullptr, {});
-    intermediate.set_data(
-        cu::malloc_async(intermediate.nbytes(), encoder.stream()));
+    intermediate.set_data(cu::malloc_async(intermediate.nbytes(), encoder));
     encoder.add_temporary(intermediate);
     encoder.set_output_array(intermediate);
     dispatch_all_types(dt, [&](auto type_tag) {

View File

@@ -28,7 +28,7 @@ void init_reduce(
     Reduce::ReduceType reduce_type) {
   // Allocate if needed
   if (out.data_shared_ptr() == nullptr) {
-    out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
+    out.set_data(cu::malloc_async(out.nbytes(), encoder));
   }
   encoder.set_output_array(out);

View File

@@ -96,7 +96,7 @@ inline void allocate_same_layout(
     const std::vector<int>& axes,
     cu::CommandEncoder& encoder) {
   if (in.flags().row_contiguous) {
-    out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
+    out.set_data(cu::malloc_async(out.nbytes(), encoder));
     return;
   }
@@ -135,7 +135,7 @@ inline void allocate_same_layout(
   fl.col_contiguous = cc;
   fl.contiguous = true;
   out.set_data(
-      cu::malloc_async(out.nbytes(), encoder.stream()),
+      cu::malloc_async(out.nbytes(), encoder),
       data_size,
       final_strides,
       fl,

View File

@@ -190,7 +190,7 @@ void RMSNorm::eval_gpu(
     out.copy_shared_buffer(x);
   } else {
     out.set_data(
-        cu::malloc_async(x.data_size() * x.itemsize(), encoder.stream()),
+        cu::malloc_async(x.data_size() * x.itemsize(), encoder),
         x.data_size(),
         x.strides(),
         x.flags());
@@ -274,7 +274,7 @@ void RMSNormVJP::eval_gpu(
     gx.copy_shared_buffer(g);
     g_in_gx = true;
   } else {
-    gx.set_data(cu::malloc_async(gx.nbytes(), encoder.stream()));
+    gx.set_data(cu::malloc_async(gx.nbytes(), encoder));
   }
   if (g_copied && !g_in_gx) {
     encoder.add_temporary(g);
@@ -292,7 +292,7 @@ void RMSNormVJP::eval_gpu(
     if (!g_in_gx && donate_g) {
       gw_temp.copy_shared_buffer(g);
     } else {
-      gw_temp.set_data(cu::malloc_async(gw_temp.nbytes(), encoder.stream()));
+      gw_temp.set_data(cu::malloc_async(gw_temp.nbytes(), encoder));
       encoder.add_temporary(gw_temp);
     }
   }

View File

@@ -292,14 +292,14 @@ void RoPE::eval_gpu(
       donated = true;
       out.copy_shared_buffer(in);
     } else {
-      out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
+      out.set_data(cu::malloc_async(out.nbytes(), encoder));
     }
     strides[0] = mat_size;
     strides[1] = in.strides()[ndim - 2];
     strides[2] = in.strides()[ndim - 1];
   } else if (dispatch_ndim == 3) {
     // Handle non-contiguous 3D inputs
-    out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
+    out.set_data(cu::malloc_async(out.nbytes(), encoder));
     strides[0] = in.strides()[ndim - 3];
     strides[1] = in.strides()[ndim - 2];
     strides[2] = in.strides()[ndim - 1];

View File

@@ -196,7 +196,7 @@ void sdpa_cudnn(
   auto& encoder = cu::get_command_encoder(s);
   // TODO: Handle donation.
   // TODO: Make O use same memory layout with Q.
-  o.set_data(cu::malloc_async(o.nbytes(), encoder.stream()));
+  o.set_data(cu::malloc_async(o.nbytes(), encoder));
   encoder.set_input_array(q);
   encoder.set_input_array(k);
@@ -240,7 +240,7 @@ void sdpa_cudnn(
   void* workspace_ptr = nullptr;
   if (workspace_size > 0) {
     array workspace(
-        cu::malloc_async(workspace_size, encoder.stream()),
+        cu::malloc_async(workspace_size, encoder),
         {static_cast<int>(workspace_size)},
         uint8);
     encoder.add_temporary(workspace);

View File

@@ -561,10 +561,9 @@ void sdpa_vector_2pass_fallback(
   array sums(intermediate_shape, float32, nullptr, {});
   array maxs(std::move(intermediate_shape), float32, nullptr, {});
-  intermediate.set_data(
-      cu::malloc_async(intermediate.nbytes(), encoder.stream()));
-  sums.set_data(cu::malloc_async(sums.nbytes(), encoder.stream()));
-  maxs.set_data(cu::malloc_async(maxs.nbytes(), encoder.stream()));
+  intermediate.set_data(cu::malloc_async(intermediate.nbytes(), encoder));
+  sums.set_data(cu::malloc_async(sums.nbytes(), encoder));
+  maxs.set_data(cu::malloc_async(maxs.nbytes(), encoder));
   encoder.add_temporary(intermediate);
   encoder.add_temporary(sums);
@@ -769,7 +768,7 @@ void sdpa_vector(
   };
   o.set_data(
-      cu::malloc_async(o.nbytes(), encoder.stream()),
+      cu::malloc_async(o.nbytes(), encoder),
       o.size(),
       {str_oB, str_oH, str_oL, str_oD},
       flags);

View File

@@ -374,7 +374,7 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
     out.copy_shared_buffer(in);
   } else {
     out.set_data(
-        cu::malloc_async(in.data_size() * out.itemsize(), encoder.stream()),
+        cu::malloc_async(in.data_size() * out.itemsize(), encoder),
        in.data_size(),
        in.strides(),
        in.flags());

View File

@@ -24,7 +24,7 @@ void concatenate_gpu(
   std::partial_sum(sizes.cbegin(), sizes.cend(), sizes.begin());
   auto& encoder = cu::get_command_encoder(s);
-  out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
+  out.set_data(cu::malloc_async(out.nbytes(), encoder));
   auto strides = out.strides();
   auto flags = out.flags();
@@ -89,7 +89,7 @@ array compute_dynamic_offset(
   if (donate) {
     offset.copy_shared_buffer(indices);
   } else {
-    offset.set_data(cu::malloc_async(offset.itemsize(), encoder.stream()));
+    offset.set_data(cu::malloc_async(offset.itemsize(), encoder));
   }
   encoder.add_temporary(offset);

View File

@@ -118,7 +118,7 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
     out.copy_shared_buffer(x);
   } else {
     out.set_data(
-        cu::malloc_async(x.data_size() * x.itemsize(), encoder.stream()),
+        cu::malloc_async(x.data_size() * x.itemsize(), encoder),
         x.data_size(),
         x.strides(),
         x.flags());

View File

@@ -49,14 +49,12 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
     array trans = swapaxes_in_eval(in, axis, last_dim);
     in = contiguous_copy_gpu(trans, s);
     encoder.add_temporary(in);
-    out = array(
-        cu::malloc_async(out.nbytes(), encoder.stream()),
-        in.shape(),
-        out.dtype());
+    out =
+        array(cu::malloc_async(out.nbytes(), encoder), in.shape(), out.dtype());
     encoder.add_temporary(out);
   } else {
     out.set_data(
-        cu::malloc_async(in.data_size() * out.itemsize(), encoder.stream()),
+        cu::malloc_async(in.data_size() * out.itemsize(), encoder),
         in.data_size(),
         in.strides(),
         in.flags());
@@ -74,17 +72,13 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
   if (argsort) {
     // Indices in the sorted dimension.
     array indices(
-        cu::malloc_async(out.nbytes(), encoder.stream()),
-        in.shape(),
-        out.dtype());
+        cu::malloc_async(out.nbytes(), encoder), in.shape(), out.dtype());
     encoder.add_temporary(indices);
     // In argsort though we don't need the result of sorted values, the
     // API requires us to provide an array to store it.
     array discard(
-        cu::malloc_async(in.nbytes(), encoder.stream()),
-        in.shape(),
-        in.dtype());
+        cu::malloc_async(in.nbytes(), encoder), in.shape(), in.dtype());
     encoder.add_temporary(discard);
     size_t size;
@@ -104,9 +98,7 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
         stream));
     array temp(
-        cu::malloc_async(size, encoder.stream()),
-        {static_cast<int>(size)},
-        uint8);
+        cu::malloc_async(size, encoder), {static_cast<int>(size)}, uint8);
     encoder.add_temporary(temp);
     // Start capturing after allocations
@@ -148,9 +140,7 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
         stream));
     array temp(
-        cu::malloc_async(size, encoder.stream()),
-        {static_cast<int>(size)},
-        uint8);
+        cu::malloc_async(size, encoder), {static_cast<int>(size)}, uint8);
     encoder.add_temporary(temp);
     // Start capturing after allocations

View File

@@ -257,9 +257,8 @@ void ternary_op_gpu(
   auto& c = inputs[2];
   auto topt = get_ternary_op_type(a, b, c);
   auto& encoder = cu::get_command_encoder(s);
-  set_ternary_op_output_data(a, b, c, out, topt, [&](auto n) {
-    return cu::malloc_async(n, encoder.stream());
-  });
+  set_ternary_op_output_data(
+      a, b, c, out, topt, [&](auto n) { return cu::malloc_async(n, encoder); });
   ternary_op_gpu_inplace<Op>(inputs, out, s);
 }

View File

@@ -208,9 +208,8 @@ void unary_op_gpu(
     const char* op,
     const Stream& s) {
   auto& encoder = cu::get_command_encoder(s);
-  set_unary_output_data(inputs[0], out, [&](auto n) {
-    return cu::malloc_async(n, encoder.stream());
-  });
+  set_unary_output_data(
+      inputs[0], out, [&](auto n) { return cu::malloc_async(n, encoder); });
   unary_op_gpu_inplace<Op>(inputs, out, op, s);
 }

View File

@@ -37,11 +37,11 @@ target_sources(
   ${METAL_TEST_SOURCES})
 if(MLX_BUILD_CUDA)
-  # Find the CCCL headers in install dir.
-  target_compile_definitions(
-    mlx
-    PRIVATE
-    MLX_CCCL_DIR="${CMAKE_INSTALL_PREFIX}/${CMAKE_INSTALL_INCLUDEDIR}/cccl")
+  # C++ tests are always built from source, so we have to specify where to find
+  # CCCL headers for JIT as they are not installed in system.
+  get_target_property(MLX_CCCL_DIR mlx CCCL_DIR)
+  target_compile_definitions(mlx PRIVATE MLX_CCCL_DIR="${MLX_CCCL_DIR}")
+  message(STATUS MLX_CCCL_DIR="${MLX_CCCL_DIR}")
 endif()
 target_link_libraries(tests PRIVATE mlx doctest)
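The MLX_CCCL_DIR value read back here is the CCCL_DIR target property stored by the main CMakeLists hunk earlier in this compare, pointing at the headers fetched from the CCCL release zip. As a rough illustration of why the define matters, a JIT compile step could consume it along these lines; this is a hypothetical sketch, not the actual MLX JIT plumbing, which this diff does not show.

// Hypothetical sketch: turning the MLX_CCCL_DIR compile definition into a
// JIT include flag. A from-source build has not installed CCCL into a
// system include path, so the JIT compiler is pointed at the fetched
// headers directly.
#include <string>
#include <vector>

std::vector<std::string> jit_include_flags() {
  std::vector<std::string> flags;
#ifdef MLX_CCCL_DIR
  flags.push_back(std::string("-I") + MLX_CCCL_DIR);
#endif
  return flags;
}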