load eval gpu for cuda

Awni Hannun
2025-11-01 06:08:58 -07:00
parent d378567cc6
commit c27a0647a3
11 changed files with 119 additions and 38 deletions


@@ -32,6 +32,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
${CMAKE_CURRENT_SOURCE_DIR}/layer_norm.cu
${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cu
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp


@@ -90,18 +90,14 @@ CudaAllocator::CudaAllocator()
page_size,
[](CudaBuffer* buf) { return buf->size; },
[this](CudaBuffer* buf) { cuda_free(buf); }) {
// TODO: Set memory limit for multi-device.
size_t free, total;
CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
memory_limit_ = total * 0.95;
max_pool_size_ = memory_limit_;
int loc = 0;
cudaDeviceGetDefaultMemPool(&cuda_pool_, loc);
// TODO need a strategy for that
uint64_t threshold = UINT64_MAX;
cudaMemPoolSetAttribute(
cuda_pool_, cudaMemPoolAttrReleaseThreshold, &threshold);
CHECK_CUDA_ERROR(cudaDeviceGetDefaultMemPool(&cuda_pool_, loc));
CHECK_CUDA_ERROR(cudaMemPoolSetAttribute(
cuda_pool_, cudaMemPoolAttrReleaseThreshold, &memory_limit_));
}
Buffer CudaAllocator::malloc_impl(size_t size, cudaStream_t stream) {
@@ -217,6 +213,9 @@ size_t CudaAllocator::get_memory_limit() {
size_t CudaAllocator::set_memory_limit(size_t limit) {
std::lock_guard lock(mutex_);
std::swap(limit, memory_limit_);
CHECK_CUDA_ERROR(cudaMemPoolTrimTo(cuda_pool_, memory_limit_));
CHECK_CUDA_ERROR(cudaMemPoolSetAttribute(
cuda_pool_, cudaMemPoolAttrReleaseThreshold, &memory_limit_));
return limit;
}
@@ -267,6 +266,8 @@ void* Buffer::raw_ptr() {
}
auto& cbuf = *static_cast<cu::CudaBuffer*>(ptr_);
if (!cbuf.managed) {
// TODO maybe make this async on an i/o stream to avoid synchronizing the
// device on malloc and free
void* new_data;
CHECK_CUDA_ERROR(cudaMallocManaged(&new_data, cbuf.size));
cbuf.managed = true;
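
Note (not part of the diff): the allocator now routes through the device's default CUDA memory pool, setting its release threshold to the memory limit so freed blocks stay cached up to that limit, and set_memory_limit additionally trims the pool and refreshes the threshold. A standalone sketch of that pool configuration, assuming device 0 and with error handling reduced to asserts:

#include <cuda_runtime.h>
#include <cassert>
#include <cstdint>

int main() {
  size_t free_mem = 0, total_mem = 0;
  assert(cudaMemGetInfo(&free_mem, &total_mem) == cudaSuccess);
  uint64_t limit = static_cast<uint64_t>(total_mem * 0.95);

  cudaMemPool_t pool;
  assert(cudaDeviceGetDefaultMemPool(&pool, /* device */ 0) == cudaSuccess);

  // Freed allocations stay cached in the pool until its size exceeds the
  // threshold; trimming releases cached memory back to the driver.
  assert(cudaMemPoolSetAttribute(pool, cudaMemPoolAttrReleaseThreshold, &limit) == cudaSuccess);
  assert(cudaMemPoolTrimTo(pool, limit) == cudaSuccess);
  return 0;
}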


@@ -293,8 +293,13 @@ void Compiled::eval_gpu(
}
}
auto& encoder = cu::get_command_encoder(s);
// Put outputs.
compiled_allocate_outputs(inputs, outputs, is_constant_, contiguous);
compiled_allocate_outputs(
inputs, outputs, is_constant_, contiguous, [&](auto n) {
return cu::malloc_async(n, encoder.stream());
});
for (auto& x : outputs) {
args.append(x);
}
@@ -324,7 +329,6 @@ void Compiled::eval_gpu(
kernel_name += fmt::format(
"_strided<{}, {}, {}>", shape.size(), index_type, work_per_thread);
}
auto& encoder = cu::get_command_encoder(s);
for (const auto& in : inputs) {
encoder.set_input_array(in);
}
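
Note (not part of the diff): fused-kernel outputs are now allocated through cu::malloc_async on the command encoder's stream rather than the default allocator, and the encoder lookup is hoisted above the allocation. Assuming cu::malloc_async is a thin wrapper over CUDA's stream-ordered allocator, the underlying pattern is roughly:

#include <cuda_runtime.h>
#include <cassert>

// Allocation, use, and free are all ordered on the same stream, so the host
// never has to synchronize around them.
void stream_ordered_alloc(cudaStream_t stream, size_t nbytes) {
  void* buf = nullptr;
  assert(cudaMallocAsync(&buf, nbytes, stream) == cudaSuccess);
  assert(cudaMemsetAsync(buf, 0, nbytes, stream) == cudaSuccess); // stand-in for the fused kernel
  assert(cudaFreeAsync(buf, stream) == cudaSuccess);
}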


@@ -110,4 +110,24 @@ void fill_gpu(const array& in, array& out, const Stream& s) {
copy_contiguous(encoder, CopyType::Scalar, in, out, 0, 0);
}
void reshape_gpu(const array& in, array& out, Stream s) {
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
if (copy_necessary) {
auto& encoder = cu::get_command_encoder(s);
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
copy_gpu_inplace(
in,
out,
in.shape(),
in.strides(),
make_contiguous_strides(in.shape()),
0,
0,
CopyType::General,
s);
} else {
shared_buffer_reshape(in, out_strides, out);
}
}
} // namespace mlx::core
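
Note (not part of the diff): reshape_gpu moves out of the shared GPU code; this CUDA copy allocates the materialized output with cu::malloc_async, while the Metal backend keeps the allocator::malloc path (added further down). For reference, make_contiguous_strides presumably performs the usual row-major stride computation; a hypothetical helper, not the library's implementation:

#include <cstddef>
#include <vector>

// Row-major strides for a shape: the innermost axis has stride 1.
std::vector<size_t> row_major_strides(const std::vector<int>& shape) {
  std::vector<size_t> strides(shape.size(), 1);
  for (int i = static_cast<int>(shape.size()) - 2; i >= 0; --i) {
    strides[i] = strides[i + 1] * static_cast<size_t>(shape[i + 1]);
  }
  return strides;
}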

mlx/backend/cuda/load.cpp (new file, +60 lines)

@@ -0,0 +1,60 @@
// Copyright © 2023 Apple Inc.

#include <algorithm>
#include <utility>

#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/utils.h"
#include "mlx/primitives.h"

namespace {

template <const uint8_t scalar_size>
void swap_endianness(uint8_t* data_bytes, size_t N) {
  struct Elem {
    uint8_t bytes[scalar_size];
  };

  Elem* data = reinterpret_cast<Elem*>(data_bytes);

  for (size_t i = 0; i < N; i++) {
    for (size_t j = 0; j < (scalar_size / 2); j++) {
      std::swap(data[i].bytes[j], data[i].bytes[scalar_size - j - 1]);
    }
  }
}

} // namespace

namespace mlx::core {

void Load::eval_gpu(const std::vector<array>& inputs, array& out) {
  auto& encoder = cu::get_command_encoder(stream());
  auto size = out.size();
  auto nbytes = size * out.itemsize();
  out.set_data(cu::malloc_async(nbytes, encoder.stream()));

  auto out_ptr = malloc(nbytes);
  reader_->read(static_cast<char*>(out_ptr), nbytes, offset_);
  if (swap_endianness_) {
    switch (out.itemsize()) {
      case 2:
        swap_endianness<2>(reinterpret_cast<uint8_t*>(out_ptr), size);
        break;
      case 4:
        swap_endianness<4>(reinterpret_cast<uint8_t*>(out_ptr), size);
        break;
      case 8:
        swap_endianness<8>(reinterpret_cast<uint8_t*>(out_ptr), size);
        break;
    }
  }

  CHECK_CUDA_ERROR(cudaMemcpyAsync(
      gpu_ptr<void>(out),
      out_ptr,
      nbytes,
      cudaMemcpyDefault,
      encoder.stream()));
  CHECK_CUDA_ERROR(cudaLaunchHostFunc(encoder.stream(), free, out_ptr));
}

} // namespace mlx::core
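
Note (not part of the diff): Load::eval_gpu allocates the device buffer stream-ordered, reads the file into a plain host buffer, byte-swaps in place if requested, enqueues an async host-to-device copy, and lets cudaLaunchHostFunc free the host buffer once the copy has executed on the stream. A reduced sketch of that upload-then-free pattern with hypothetical names (with pageable host memory the driver may stage the copy; pinned memory would overlap better, but the ordering is the same):

#include <cuda_runtime.h>
#include <cassert>
#include <cstdlib>
#include <cstring>

// Runs on a driver thread only after the copy ahead of it on the stream.
static void free_host(void* p) {
  std::free(p);
}

void upload_and_release(void* device_dst, const void* src, size_t nbytes, cudaStream_t stream) {
  void* staging = std::malloc(nbytes);
  std::memcpy(staging, src, nbytes);
  assert(cudaMemcpyAsync(device_dst, staging, nbytes, cudaMemcpyDefault, stream) == cudaSuccess);
  assert(cudaLaunchHostFunc(stream, free_host, staging) == cudaSuccess);
}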


@@ -28,7 +28,6 @@ NO_GPU(FFT)
NO_GPU(GatherMM)
NO_GPU(GatherQMM)
NO_GPU(Hadamard)
NO_GPU(Load)
NO_GPU_MULTI(LUF)
NO_GPU_MULTI(QRF)
NO_GPU(QuantizedMatmul)
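
Note (not part of the diff): with a real Load::eval_gpu in place, its stub registration is removed above. As a point of reference (an assumption, not shown in this commit), NO_GPU presumably expands to a throwing eval_gpu definition, roughly:

// Hypothetical expansion of the stub macro.
#define NO_GPU(func)                                                   \
  void func::eval_gpu(const std::vector<array>& inputs, array& out) {  \
    throw std::runtime_error(#func " has no CUDA implementation.");    \
  }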


@@ -76,7 +76,7 @@ void fast::Quantize::eval_gpu(
scales.set_data(cu::malloc_async(scales.nbytes(), enc.stream()));
if (mode_ == QuantizationMode::Affine) {
auto& biases = outputs[2];
biases.set_data(allocator::malloc(biases.nbytes()));
biases.set_data(cu::malloc_async(biases.nbytes(), enc.stream()));
affine_quantize(w, wq, scales, biases, group_size_, bits_, enc, s);
} else {
fp_quantize(w, wq, scales, group_size_, bits_, enc, s);


@@ -41,25 +41,6 @@ array contiguous_copy_gpu(const array& arr, const Stream& s) {
return arr_copy;
}
void reshape_gpu(const array& in, array& out, Stream s) {
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
if (copy_necessary) {
out.set_data(allocator::malloc(out.nbytes()));
copy_gpu_inplace(
in,
out,
in.shape(),
in.strides(),
make_contiguous_strides(in.shape()),
0,
0,
CopyType::General,
s);
} else {
shared_buffer_reshape(in, out_strides, out);
}
}
array flatten_in_eval(const array& x, int start_axis, int end_axis, Stream s) {
int ndim = x.ndim();
if (start_axis < 0) {


@@ -214,4 +214,23 @@ void fill_gpu(const array& val, array& out, const Stream& s) {
compute_encoder.dispatch_threads(grid_dims, group_dims);
}
void reshape_gpu(const array& in, array& out, Stream s) {
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
if (copy_necessary) {
out.set_data(allocator::malloc(out.nbytes()));
copy_gpu_inplace(
in,
out,
in.shape(),
in.strides(),
make_contiguous_strides(in.shape()),
0,
0,
CopyType::General,
s);
} else {
shared_buffer_reshape(in, out_strides, out);
}
}
} // namespace mlx::core


@@ -13,6 +13,7 @@
#include <windows.h>
#endif // _WIN32
#include "mlx/backend/cuda/cuda.h"
#include "mlx/io/load.h"
#include "mlx/ops.h"
#include "mlx/primitives.h"
@@ -226,10 +227,7 @@ array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s) {
throw std::runtime_error("[load] Failed to open " + in_stream->label());
}
auto stream = to_stream(s, Device::cpu);
if (stream.device != Device::cpu) {
throw std::runtime_error("[load] Must run on a CPU stream.");
}
auto stream = to_stream(s, cu::is_available() ? Device::gpu : Device::cpu);
////////////////////////////////////////////////////////
// Read header and prepare array details
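
Note (not part of the diff): load() (and load_safetensors below) no longer reject non-CPU streams; when the CUDA backend is available they resolve to a GPU stream, so the Load primitive above handles the read and upload. A usage sketch with a hypothetical file name:

#include "mlx/mlx.h"

int main() {
  auto a = mlx::core::load("weights.npy"); // default stream; GPU when CUDA is available
  a.eval(); // Load::eval_gpu reads the file and uploads it on the GPU stream
  return 0;
}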


@@ -4,6 +4,7 @@
#include <memory>
#include <stack>
#include "mlx/backend/cuda/cuda.h"
#include "mlx/io.h"
#include "mlx/io/load.h"
#include "mlx/ops.h"
@@ -113,10 +114,7 @@ SafetensorsLoad load_safetensors(
"[load_safetensors] Failed to open " + in_stream->label());
}
auto stream = to_stream(s, Device::cpu);
if (stream.device != Device::cpu) {
throw std::runtime_error("[load_safetensors] Must run on a CPU stream.");
}
auto stream = to_stream(s, cu::is_available() ? Device::gpu : Device::cpu);
uint64_t jsonHeaderLength = 0;
// This is the same limit as in the original Rust Safetensors code.