mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 10:38:10 +08:00 
			
		
		
		
	load eval gpu for cuda
This commit is contained in:
		@@ -32,6 +32,7 @@ target_sources(
 | 
			
		||||
          ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
 | 
			
		||||
          ${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu
 | 
			
		||||
          ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
 | 
			
		||||
          ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
 | 
			
		||||
          ${CMAKE_CURRENT_SOURCE_DIR}/layer_norm.cu
 | 
			
		||||
          ${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cu
 | 
			
		||||
          ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
 | 
			
		||||
 
 | 
			
		||||
@@ -90,18 +90,14 @@ CudaAllocator::CudaAllocator()
 | 
			
		||||
          page_size,
 | 
			
		||||
          [](CudaBuffer* buf) { return buf->size; },
 | 
			
		||||
          [this](CudaBuffer* buf) { cuda_free(buf); }) {
 | 
			
		||||
  // TODO: Set memory limit for multi-device.
 | 
			
		||||
  size_t free, total;
 | 
			
		||||
  CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
 | 
			
		||||
  memory_limit_ = total * 0.95;
 | 
			
		||||
  max_pool_size_ = memory_limit_;
 | 
			
		||||
  int loc = 0;
 | 
			
		||||
  cudaDeviceGetDefaultMemPool(&cuda_pool_, loc);
 | 
			
		||||
 | 
			
		||||
  // TODO need a strategy for that
 | 
			
		||||
  uint64_t threshold = UINT64_MAX;
 | 
			
		||||
  cudaMemPoolSetAttribute(
 | 
			
		||||
      cuda_pool_, cudaMemPoolAttrReleaseThreshold, &threshold);
 | 
			
		||||
  CHECK_CUDA_ERROR(cudaDeviceGetDefaultMemPool(&cuda_pool_, loc));
 | 
			
		||||
  CHECK_CUDA_ERROR(cudaMemPoolSetAttribute(
 | 
			
		||||
      cuda_pool_, cudaMemPoolAttrReleaseThreshold, &memory_limit_));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
Buffer CudaAllocator::malloc_impl(size_t size, cudaStream_t stream) {
 | 
			
		||||
@@ -217,6 +213,9 @@ size_t CudaAllocator::get_memory_limit() {
 | 
			
		||||
size_t CudaAllocator::set_memory_limit(size_t limit) {
 | 
			
		||||
  std::lock_guard lock(mutex_);
 | 
			
		||||
  std::swap(limit, memory_limit_);
 | 
			
		||||
  CHECK_CUDA_ERROR(cudaMemPoolTrimTo(cuda_pool_, memory_limit_));
 | 
			
		||||
  CHECK_CUDA_ERROR(cudaMemPoolSetAttribute(
 | 
			
		||||
      cuda_pool_, cudaMemPoolAttrReleaseThreshold, &memory_limit_));
 | 
			
		||||
  return limit;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -267,6 +266,8 @@ void* Buffer::raw_ptr() {
 | 
			
		||||
  }
 | 
			
		||||
  auto& cbuf = *static_cast<cu::CudaBuffer*>(ptr_);
 | 
			
		||||
  if (!cbuf.managed) {
 | 
			
		||||
    // TODO maybe make this async on a i/o stream to avoid synchronizing the
 | 
			
		||||
    // device on malloc/and free
 | 
			
		||||
    void* new_data;
 | 
			
		||||
    CHECK_CUDA_ERROR(cudaMallocManaged(&new_data, cbuf.size));
 | 
			
		||||
    cbuf.managed = true;
 | 
			
		||||
 
 | 
			
		||||
@@ -293,8 +293,13 @@ void Compiled::eval_gpu(
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  auto& encoder = cu::get_command_encoder(s);
 | 
			
		||||
 | 
			
		||||
  // Put outputs.
 | 
			
		||||
  compiled_allocate_outputs(inputs, outputs, is_constant_, contiguous);
 | 
			
		||||
  compiled_allocate_outputs(
 | 
			
		||||
      inputs, outputs, is_constant_, contiguous, [&](auto n) {
 | 
			
		||||
        return cu::malloc_async(n, encoder.stream());
 | 
			
		||||
      });
 | 
			
		||||
  for (auto& x : outputs) {
 | 
			
		||||
    args.append(x);
 | 
			
		||||
  }
 | 
			
		||||
@@ -324,7 +329,6 @@ void Compiled::eval_gpu(
 | 
			
		||||
    kernel_name += fmt::format(
 | 
			
		||||
        "_strided<{}, {}, {}>", shape.size(), index_type, work_per_thread);
 | 
			
		||||
  }
 | 
			
		||||
  auto& encoder = cu::get_command_encoder(s);
 | 
			
		||||
  for (const auto& in : inputs) {
 | 
			
		||||
    encoder.set_input_array(in);
 | 
			
		||||
  }
 | 
			
		||||
 
 | 
			
		||||
@@ -110,4 +110,24 @@ void fill_gpu(const array& in, array& out, const Stream& s) {
 | 
			
		||||
  copy_contiguous(encoder, CopyType::Scalar, in, out, 0, 0);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void reshape_gpu(const array& in, array& out, Stream s) {
 | 
			
		||||
  auto [copy_necessary, out_strides] = prepare_reshape(in, out);
 | 
			
		||||
  if (copy_necessary) {
 | 
			
		||||
    auto& encoder = cu::get_command_encoder(s);
 | 
			
		||||
    out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
 | 
			
		||||
    copy_gpu_inplace(
 | 
			
		||||
        in,
 | 
			
		||||
        out,
 | 
			
		||||
        in.shape(),
 | 
			
		||||
        in.strides(),
 | 
			
		||||
        make_contiguous_strides(in.shape()),
 | 
			
		||||
        0,
 | 
			
		||||
        0,
 | 
			
		||||
        CopyType::General,
 | 
			
		||||
        s);
 | 
			
		||||
  } else {
 | 
			
		||||
    shared_buffer_reshape(in, out_strides, out);
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
} // namespace mlx::core
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										60
									
								
								mlx/backend/cuda/load.cpp
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										60
									
								
								mlx/backend/cuda/load.cpp
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,60 @@
 | 
			
		||||
// Copyright © 2023 Apple Inc.
 | 
			
		||||
 | 
			
		||||
#include <algorithm>
 | 
			
		||||
#include <utility>
 | 
			
		||||
 | 
			
		||||
#include "mlx/backend/cuda/device.h"
 | 
			
		||||
#include "mlx/backend/cuda/utils.h"
 | 
			
		||||
#include "mlx/primitives.h"
 | 
			
		||||
 | 
			
		||||
namespace {
 | 
			
		||||
 | 
			
		||||
template <const uint8_t scalar_size>
 | 
			
		||||
void swap_endianness(uint8_t* data_bytes, size_t N) {
 | 
			
		||||
  struct Elem {
 | 
			
		||||
    uint8_t bytes[scalar_size];
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  Elem* data = reinterpret_cast<Elem*>(data_bytes);
 | 
			
		||||
 | 
			
		||||
  for (size_t i = 0; i < N; i++) {
 | 
			
		||||
    for (size_t j = 0; j < (scalar_size / 2); j++) {
 | 
			
		||||
      std::swap(data[i].bytes[j], data[i].bytes[scalar_size - j - 1]);
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
} // namespace
 | 
			
		||||
 | 
			
		||||
namespace mlx::core {
 | 
			
		||||
 | 
			
		||||
void Load::eval_gpu(const std::vector<array>& inputs, array& out) {
 | 
			
		||||
  auto& encoder = cu::get_command_encoder(stream());
 | 
			
		||||
  auto size = out.size();
 | 
			
		||||
  auto nbytes = size * out.itemsize();
 | 
			
		||||
  out.set_data(cu::malloc_async(nbytes, encoder.stream()));
 | 
			
		||||
  auto out_ptr = malloc(nbytes);
 | 
			
		||||
  reader_->read(static_cast<char*>(out_ptr), nbytes, offset_);
 | 
			
		||||
  if (swap_endianness_) {
 | 
			
		||||
    switch (out.itemsize()) {
 | 
			
		||||
      case 2:
 | 
			
		||||
        swap_endianness<2>(reinterpret_cast<uint8_t*>(out_ptr), size);
 | 
			
		||||
        break;
 | 
			
		||||
      case 4:
 | 
			
		||||
        swap_endianness<4>(reinterpret_cast<uint8_t*>(out_ptr), size);
 | 
			
		||||
        break;
 | 
			
		||||
      case 8:
 | 
			
		||||
        swap_endianness<8>(reinterpret_cast<uint8_t*>(out_ptr), size);
 | 
			
		||||
        break;
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
  CHECK_CUDA_ERROR(cudaMemcpyAsync(
 | 
			
		||||
      gpu_ptr<void>(out),
 | 
			
		||||
      out_ptr,
 | 
			
		||||
      nbytes,
 | 
			
		||||
      cudaMemcpyDefault,
 | 
			
		||||
      encoder.stream()));
 | 
			
		||||
  CHECK_CUDA_ERROR(cudaLaunchHostFunc(encoder.stream(), free, out_ptr));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
} // namespace mlx::core
 | 
			
		||||
@@ -28,7 +28,6 @@ NO_GPU(FFT)
 | 
			
		||||
NO_GPU(GatherMM)
 | 
			
		||||
NO_GPU(GatherQMM)
 | 
			
		||||
NO_GPU(Hadamard)
 | 
			
		||||
NO_GPU(Load)
 | 
			
		||||
NO_GPU_MULTI(LUF)
 | 
			
		||||
NO_GPU_MULTI(QRF)
 | 
			
		||||
NO_GPU(QuantizedMatmul)
 | 
			
		||||
 
 | 
			
		||||
@@ -76,7 +76,7 @@ void fast::Quantize::eval_gpu(
 | 
			
		||||
    scales.set_data(cu::malloc_async(scales.nbytes(), enc.stream()));
 | 
			
		||||
    if (mode_ == QuantizationMode::Affine) {
 | 
			
		||||
      auto& biases = outputs[2];
 | 
			
		||||
      biases.set_data(allocator::malloc(biases.nbytes()));
 | 
			
		||||
      biases.set_data(cu::malloc_async(biases.nbytes(), enc.stream()));
 | 
			
		||||
      affine_quantize(w, wq, scales, biases, group_size_, bits_, enc, s);
 | 
			
		||||
    } else {
 | 
			
		||||
      fp_quantize(w, wq, scales, group_size_, bits_, enc, s);
 | 
			
		||||
 
 | 
			
		||||
@@ -41,25 +41,6 @@ array contiguous_copy_gpu(const array& arr, const Stream& s) {
 | 
			
		||||
  return arr_copy;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void reshape_gpu(const array& in, array& out, Stream s) {
 | 
			
		||||
  auto [copy_necessary, out_strides] = prepare_reshape(in, out);
 | 
			
		||||
  if (copy_necessary) {
 | 
			
		||||
    out.set_data(allocator::malloc(out.nbytes()));
 | 
			
		||||
    copy_gpu_inplace(
 | 
			
		||||
        in,
 | 
			
		||||
        out,
 | 
			
		||||
        in.shape(),
 | 
			
		||||
        in.strides(),
 | 
			
		||||
        make_contiguous_strides(in.shape()),
 | 
			
		||||
        0,
 | 
			
		||||
        0,
 | 
			
		||||
        CopyType::General,
 | 
			
		||||
        s);
 | 
			
		||||
  } else {
 | 
			
		||||
    shared_buffer_reshape(in, out_strides, out);
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
array flatten_in_eval(const array& x, int start_axis, int end_axis, Stream s) {
 | 
			
		||||
  int ndim = x.ndim();
 | 
			
		||||
  if (start_axis < 0) {
 | 
			
		||||
 
 | 
			
		||||
@@ -214,4 +214,23 @@ void fill_gpu(const array& val, array& out, const Stream& s) {
 | 
			
		||||
  compute_encoder.dispatch_threads(grid_dims, group_dims);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void reshape_gpu(const array& in, array& out, Stream s) {
 | 
			
		||||
  auto [copy_necessary, out_strides] = prepare_reshape(in, out);
 | 
			
		||||
  if (copy_necessary) {
 | 
			
		||||
    out.set_data(allocator::malloc(out.nbytes()));
 | 
			
		||||
    copy_gpu_inplace(
 | 
			
		||||
        in,
 | 
			
		||||
        out,
 | 
			
		||||
        in.shape(),
 | 
			
		||||
        in.strides(),
 | 
			
		||||
        make_contiguous_strides(in.shape()),
 | 
			
		||||
        0,
 | 
			
		||||
        0,
 | 
			
		||||
        CopyType::General,
 | 
			
		||||
        s);
 | 
			
		||||
  } else {
 | 
			
		||||
    shared_buffer_reshape(in, out_strides, out);
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
} // namespace mlx::core
 | 
			
		||||
 
 | 
			
		||||
@@ -13,6 +13,7 @@
 | 
			
		||||
#include <windows.h>
 | 
			
		||||
#endif // _WIN32
 | 
			
		||||
 | 
			
		||||
#include "mlx/backend/cuda/cuda.h"
 | 
			
		||||
#include "mlx/io/load.h"
 | 
			
		||||
#include "mlx/ops.h"
 | 
			
		||||
#include "mlx/primitives.h"
 | 
			
		||||
@@ -226,10 +227,7 @@ array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s) {
 | 
			
		||||
    throw std::runtime_error("[load] Failed to open " + in_stream->label());
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  auto stream = to_stream(s, Device::cpu);
 | 
			
		||||
  if (stream.device != Device::cpu) {
 | 
			
		||||
    throw std::runtime_error("[load] Must run on a CPU stream.");
 | 
			
		||||
  }
 | 
			
		||||
  auto stream = to_stream(s, cu::is_available() ? Device::gpu : Device::cpu);
 | 
			
		||||
 | 
			
		||||
  ////////////////////////////////////////////////////////
 | 
			
		||||
  // Read header and prepare array details
 | 
			
		||||
 
 | 
			
		||||
@@ -4,6 +4,7 @@
 | 
			
		||||
#include <memory>
 | 
			
		||||
#include <stack>
 | 
			
		||||
 | 
			
		||||
#include "mlx/backend/cuda/cuda.h"
 | 
			
		||||
#include "mlx/io.h"
 | 
			
		||||
#include "mlx/io/load.h"
 | 
			
		||||
#include "mlx/ops.h"
 | 
			
		||||
@@ -113,10 +114,7 @@ SafetensorsLoad load_safetensors(
 | 
			
		||||
        "[load_safetensors] Failed to open " + in_stream->label());
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  auto stream = to_stream(s, Device::cpu);
 | 
			
		||||
  if (stream.device != Device::cpu) {
 | 
			
		||||
    throw std::runtime_error("[load_safetensors] Must run on a CPU stream.");
 | 
			
		||||
  }
 | 
			
		||||
  auto stream = to_stream(s, cu::is_available() ? Device::gpu : Device::cpu);
 | 
			
		||||
 | 
			
		||||
  uint64_t jsonHeaderLength = 0;
 | 
			
		||||
  // This is the same limit as in the original Rust Safetensors code.
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user