diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt
index 7f8f1aade..543e9fd58 100644
--- a/mlx/backend/cuda/CMakeLists.txt
+++ b/mlx/backend/cuda/CMakeLists.txt
@@ -32,6 +32,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/layer_norm.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
diff --git a/mlx/backend/cuda/allocator.cpp b/mlx/backend/cuda/allocator.cpp
index dcc82a7f2..918d85741 100644
--- a/mlx/backend/cuda/allocator.cpp
+++ b/mlx/backend/cuda/allocator.cpp
@@ -90,18 +90,14 @@ CudaAllocator::CudaAllocator()
           page_size,
           [](CudaBuffer* buf) { return buf->size; },
           [this](CudaBuffer* buf) { cuda_free(buf); }) {
-  // TODO: Set memory limit for multi-device.
   size_t free, total;
   CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
   memory_limit_ = total * 0.95;
   max_pool_size_ = memory_limit_;
   int loc = 0;
-  cudaDeviceGetDefaultMemPool(&cuda_pool_, loc);
-
-  // TODO need a strategy for that
-  uint64_t threshold = UINT64_MAX;
-  cudaMemPoolSetAttribute(
-      cuda_pool_, cudaMemPoolAttrReleaseThreshold, &threshold);
+  CHECK_CUDA_ERROR(cudaDeviceGetDefaultMemPool(&cuda_pool_, loc));
+  CHECK_CUDA_ERROR(cudaMemPoolSetAttribute(
+      cuda_pool_, cudaMemPoolAttrReleaseThreshold, &memory_limit_));
 }
 
 Buffer CudaAllocator::malloc_impl(size_t size, cudaStream_t stream) {
@@ -217,6 +213,9 @@ size_t CudaAllocator::get_memory_limit() {
 size_t CudaAllocator::set_memory_limit(size_t limit) {
   std::lock_guard lock(mutex_);
   std::swap(limit, memory_limit_);
+  CHECK_CUDA_ERROR(cudaMemPoolTrimTo(cuda_pool_, memory_limit_));
+  CHECK_CUDA_ERROR(cudaMemPoolSetAttribute(
+      cuda_pool_, cudaMemPoolAttrReleaseThreshold, &memory_limit_));
   return limit;
 }
 
@@ -267,6 +266,8 @@ void* Buffer::raw_ptr() {
   }
   auto& cbuf = *static_cast<CudaBuffer*>(ptr_);
   if (!cbuf.managed) {
+    // TODO maybe make this async on a i/o stream to avoid synchronizing the
+    // device on malloc/and free
    void* new_data;
    CHECK_CUDA_ERROR(cudaMallocManaged(&new_data, cbuf.size));
    cbuf.managed = true;
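
Note on the allocator hunk above: the default CUDA memory pool's release threshold is now pinned to the allocator's memory limit instead of UINT64_MAX, and set_memory_limit() trims the pool and re-applies the threshold whenever the limit changes. The following is a minimal standalone sketch of that CUDA runtime pattern, not MLX code; device 0, the 0.95 fraction, and the check() helper are assumptions carried over for illustration.

    // Standalone sketch (not MLX code): cap the default stream-ordered memory
    // pool at a soft limit and trim it when the limit is lowered.
    #include <cuda_runtime.h>
    #include <cstdint>
    #include <cstdio>
    #include <cstdlib>

    static void check(cudaError_t err, const char* what) {
      if (err != cudaSuccess) {
        std::fprintf(stderr, "%s: %s\n", what, cudaGetErrorString(err));
        std::abort();
      }
    }

    int main() {
      size_t free_mem = 0, total_mem = 0;
      check(cudaMemGetInfo(&free_mem, &total_mem), "cudaMemGetInfo");

      // Assumed policy (mirrors the diff): cache at most ~95% of device memory.
      uint64_t limit = static_cast<uint64_t>(total_mem * 0.95);

      cudaMemPool_t pool;
      check(cudaDeviceGetDefaultMemPool(&pool, /*device=*/0), "get default pool");
      // Memory freed with cudaFreeAsync beyond this threshold is handed back to
      // the driver at synchronization points instead of staying cached.
      check(
          cudaMemPoolSetAttribute(pool, cudaMemPoolAttrReleaseThreshold, &limit),
          "set release threshold");

      // Lowering the limit later: shrink the cached pool, then re-apply the
      // threshold so it stays in force for future frees.
      uint64_t new_limit = limit / 2;
      check(cudaMemPoolTrimTo(pool, new_limit), "cudaMemPoolTrimTo");
      check(
          cudaMemPoolSetAttribute(
              pool, cudaMemPoolAttrReleaseThreshold, &new_limit),
          "set release threshold");
      return 0;
    }
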
diff --git a/mlx/backend/cuda/compiled.cpp b/mlx/backend/cuda/compiled.cpp
index 302ca2f99..3c5681019 100644
--- a/mlx/backend/cuda/compiled.cpp
+++ b/mlx/backend/cuda/compiled.cpp
@@ -293,8 +293,13 @@ void Compiled::eval_gpu(
     }
   }
 
+  auto& encoder = cu::get_command_encoder(s);
+
   // Put outputs.
-  compiled_allocate_outputs(inputs, outputs, is_constant_, contiguous);
+  compiled_allocate_outputs(
+      inputs, outputs, is_constant_, contiguous, [&](auto n) {
+        return cu::malloc_async(n, encoder.stream());
+      });
   for (auto& x : outputs) {
     args.append(x);
   }
@@ -324,7 +329,6 @@ void Compiled::eval_gpu(
     kernel_name += fmt::format(
         "_strided<{}, {}, {}>", shape.size(), index_type, work_per_thread);
   }
-  auto& encoder = cu::get_command_encoder(s);
   for (const auto& in : inputs) {
     encoder.set_input_array(in);
   }
diff --git a/mlx/backend/cuda/copy.cu b/mlx/backend/cuda/copy.cu
index e36bcbdf6..f559076d2 100644
--- a/mlx/backend/cuda/copy.cu
+++ b/mlx/backend/cuda/copy.cu
@@ -110,4 +110,24 @@ void fill_gpu(const array& in, array& out, const Stream& s) {
   copy_contiguous(encoder, CopyType::Scalar, in, out, 0, 0);
 }
 
+void reshape_gpu(const array& in, array& out, Stream s) {
+  auto [copy_necessary, out_strides] = prepare_reshape(in, out);
+  if (copy_necessary) {
+    auto& encoder = cu::get_command_encoder(s);
+    out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
+    copy_gpu_inplace(
+        in,
+        out,
+        in.shape(),
+        in.strides(),
+        make_contiguous_strides(in.shape()),
+        0,
+        0,
+        CopyType::General,
+        s);
+  } else {
+    shared_buffer_reshape(in, out_strides, out);
+  }
+}
+
 } // namespace mlx::core
diff --git a/mlx/backend/cuda/load.cpp b/mlx/backend/cuda/load.cpp
new file mode 100644
index 000000000..a5687addb
--- /dev/null
+++ b/mlx/backend/cuda/load.cpp
@@ -0,0 +1,60 @@
+// Copyright © 2023 Apple Inc.
+
+#include <cstdlib>
+#include <utility>
+
+#include "mlx/backend/cuda/device.h"
+#include "mlx/backend/cuda/utils.h"
+#include "mlx/primitives.h"
+
+namespace {
+
+template <const uint8_t scalar_size>
+void swap_endianness(uint8_t* data_bytes, size_t N) {
+  struct Elem {
+    uint8_t bytes[scalar_size];
+  };
+
+  Elem* data = reinterpret_cast<Elem*>(data_bytes);
+
+  for (size_t i = 0; i < N; i++) {
+    for (size_t j = 0; j < (scalar_size / 2); j++) {
+      std::swap(data[i].bytes[j], data[i].bytes[scalar_size - j - 1]);
+    }
+  }
+}
+
+} // namespace
+
+namespace mlx::core {
+
+void Load::eval_gpu(const std::vector<array>& inputs, array& out) {
+  auto& encoder = cu::get_command_encoder(stream());
+  auto size = out.size();
+  auto nbytes = size * out.itemsize();
+  out.set_data(cu::malloc_async(nbytes, encoder.stream()));
+  auto out_ptr = malloc(nbytes);
+  reader_->read(static_cast<char*>(out_ptr), nbytes, offset_);
+  if (swap_endianness_) {
+    switch (out.itemsize()) {
+      case 2:
+        swap_endianness<2>(reinterpret_cast<uint8_t*>(out_ptr), size);
+        break;
+      case 4:
+        swap_endianness<4>(reinterpret_cast<uint8_t*>(out_ptr), size);
+        break;
+      case 8:
+        swap_endianness<8>(reinterpret_cast<uint8_t*>(out_ptr), size);
+        break;
+    }
+  }
+  CHECK_CUDA_ERROR(cudaMemcpyAsync(
+      gpu_ptr(out),
+      out_ptr,
+      nbytes,
+      cudaMemcpyDefault,
+      encoder.stream()));
+  CHECK_CUDA_ERROR(cudaLaunchHostFunc(encoder.stream(), free, out_ptr));
+}
+
+} // namespace mlx::core
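
The new load.cpp stages file bytes through a plain host buffer: the reader fills it on the CPU, cudaMemcpyAsync enqueues the upload on the encoder's stream, and cudaLaunchHostFunc queues free() so the staging buffer is released only once the stream has consumed the copy. Below is a reduced sketch of just that staging pattern; read_bytes(), the buffer size, and the raw stream handling are illustrative assumptions, not the MLX API, and error checking is elided.

    // Sketch (assumptions noted above): upload host file data to the device on
    // a stream and free the pageable staging buffer from a stream-ordered
    // host callback.
    #include <cuda_runtime.h>
    #include <cstdlib>
    #include <cstring>

    // Hypothetical stand-in for the reader in the diff: fill dst with file bytes.
    static void read_bytes(void* dst, size_t nbytes) {
      std::memset(dst, 0, nbytes);
    }

    static void stage_and_upload(void* device_dst, size_t nbytes, cudaStream_t stream) {
      void* staging = std::malloc(nbytes);
      read_bytes(staging, nbytes);
      // cudaMemcpyDefault lets the runtime infer the direction (host -> device).
      cudaMemcpyAsync(device_dst, staging, nbytes, cudaMemcpyDefault, stream);
      // cudaHostFn_t is void(*)(void*), so std::free can run as the callback
      // once the stream reaches this point; the caller never blocks.
      cudaLaunchHostFunc(stream, std::free, staging);
    }

    int main() {
      const size_t nbytes = size_t(1) << 20;
      cudaStream_t stream;
      cudaStreamCreate(&stream);
      void* device_dst = nullptr;
      cudaMallocAsync(&device_dst, nbytes, stream);
      stage_and_upload(device_dst, nbytes, stream);
      cudaFreeAsync(device_dst, stream);
      cudaStreamSynchronize(stream);
      cudaStreamDestroy(stream);
      return 0;
    }
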
diff --git a/mlx/backend/cuda/primitives.cpp b/mlx/backend/cuda/primitives.cpp
index 77c295665..b2f727e0e 100644
--- a/mlx/backend/cuda/primitives.cpp
+++ b/mlx/backend/cuda/primitives.cpp
@@ -28,7 +28,6 @@ NO_GPU(FFT)
 NO_GPU(GatherMM)
 NO_GPU(GatherQMM)
 NO_GPU(Hadamard)
-NO_GPU(Load)
 NO_GPU_MULTI(LUF)
 NO_GPU_MULTI(QRF)
 NO_GPU(QuantizedMatmul)
diff --git a/mlx/backend/cuda/quantized/quantized.cpp b/mlx/backend/cuda/quantized/quantized.cpp
index f221abc50..f75064d4e 100644
--- a/mlx/backend/cuda/quantized/quantized.cpp
+++ b/mlx/backend/cuda/quantized/quantized.cpp
@@ -76,7 +76,7 @@ void fast::Quantize::eval_gpu(
   scales.set_data(cu::malloc_async(scales.nbytes(), enc.stream()));
   if (mode_ == QuantizationMode::Affine) {
     auto& biases = outputs[2];
-    biases.set_data(allocator::malloc(biases.nbytes()));
+    biases.set_data(cu::malloc_async(biases.nbytes(), enc.stream()));
     affine_quantize(w, wq, scales, biases, group_size_, bits_, enc, s);
   } else {
     fp_quantize(w, wq, scales, group_size_, bits_, enc, s);
diff --git a/mlx/backend/gpu/copy.cpp b/mlx/backend/gpu/copy.cpp
index f8caf09cd..1ed6e2345 100644
--- a/mlx/backend/gpu/copy.cpp
+++ b/mlx/backend/gpu/copy.cpp
@@ -41,25 +41,6 @@ array contiguous_copy_gpu(const array& arr, const Stream& s) {
   return arr_copy;
 }
 
-void reshape_gpu(const array& in, array& out, Stream s) {
-  auto [copy_necessary, out_strides] = prepare_reshape(in, out);
-  if (copy_necessary) {
-    out.set_data(allocator::malloc(out.nbytes()));
-    copy_gpu_inplace(
-        in,
-        out,
-        in.shape(),
-        in.strides(),
-        make_contiguous_strides(in.shape()),
-        0,
-        0,
-        CopyType::General,
-        s);
-  } else {
-    shared_buffer_reshape(in, out_strides, out);
-  }
-}
-
 array flatten_in_eval(const array& x, int start_axis, int end_axis, Stream s) {
   int ndim = x.ndim();
   if (start_axis < 0) {
diff --git a/mlx/backend/metal/copy.cpp b/mlx/backend/metal/copy.cpp
index c58e7d3c2..6b791289c 100644
--- a/mlx/backend/metal/copy.cpp
+++ b/mlx/backend/metal/copy.cpp
@@ -214,4 +214,23 @@ void fill_gpu(const array& val, array& out, const Stream& s) {
   compute_encoder.dispatch_threads(grid_dims, group_dims);
 }
 
+void reshape_gpu(const array& in, array& out, Stream s) {
+  auto [copy_necessary, out_strides] = prepare_reshape(in, out);
+  if (copy_necessary) {
+    out.set_data(allocator::malloc(out.nbytes()));
+    copy_gpu_inplace(
+        in,
+        out,
+        in.shape(),
+        in.strides(),
+        make_contiguous_strides(in.shape()),
+        0,
+        0,
+        CopyType::General,
+        s);
+  } else {
+    shared_buffer_reshape(in, out_strides, out);
+  }
+}
+
 } // namespace mlx::core
diff --git a/mlx/io/load.cpp b/mlx/io/load.cpp
index f3d71e860..8a6ca0852 100644
--- a/mlx/io/load.cpp
+++ b/mlx/io/load.cpp
@@ -13,6 +13,7 @@
 #include <unistd.h>
 #endif // _WIN32
 
+#include "mlx/backend/cuda/cuda.h"
 #include "mlx/io/load.h"
 #include "mlx/ops.h"
 #include "mlx/primitives.h"
@@ -226,10 +227,7 @@ array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s) {
     throw std::runtime_error("[load] Failed to open " + in_stream->label());
   }
 
-  auto stream = to_stream(s, Device::cpu);
-  if (stream.device != Device::cpu) {
-    throw std::runtime_error("[load] Must run on a CPU stream.");
-  }
+  auto stream = to_stream(s, cu::is_available() ? Device::gpu : Device::cpu);
 
   ////////////////////////////////////////////////////////
   // Read header and prepare array details
diff --git a/mlx/io/safetensors.cpp b/mlx/io/safetensors.cpp
index d9b9e9e40..e726fc359 100644
--- a/mlx/io/safetensors.cpp
+++ b/mlx/io/safetensors.cpp
@@ -4,6 +4,7 @@
 #include <json.hpp>
 #include <stack>
 
+#include "mlx/backend/cuda/cuda.h"
 #include "mlx/io.h"
 #include "mlx/io/load.h"
 #include "mlx/ops.h"
@@ -113,10 +114,7 @@ SafetensorsLoad load_safetensors(
         "[load_safetensors] Failed to open " + in_stream->label());
   }
 
-  auto stream = to_stream(s, Device::cpu);
-  if (stream.device != Device::cpu) {
-    throw std::runtime_error("[load_safetensors] Must run on a CPU stream.");
-  }
+  auto stream = to_stream(s, cu::is_available() ? Device::gpu : Device::cpu);
 
   uint64_t jsonHeaderLength = 0;
   // This is the same limit as in the original Rust Safetensors code.
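
With the io/load.cpp and safetensors.cpp hunks above, load() and load_safetensors() resolve to a GPU stream whenever the CUDA backend reports cu::is_available(), so deserialization goes through Load::eval_gpu instead of being forced onto a CPU stream. A small usage sketch follows; the file paths are placeholders, and the falls-back-to-CPU behavior is unchanged when CUDA is absent.

    #include "mlx/mlx.h"

    namespace mx = mlx::core;

    int main() {
      // With CUDA available these run on the default GPU stream; otherwise
      // they fall back to a CPU stream exactly as before.
      auto arr = mx::load("weights.npy"); // placeholder path
      auto [tensors, metadata] = mx::load_safetensors("model.safetensors");
      mx::eval({arr});
      return 0;
    }
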