Compare commits

...

2 Commits

Author SHA1 Message Date
Awni Hannun
529842fed9 fix 2025-11-03 16:43:19 -08:00
Awni Hannun
cc6df9fc8a fix 2025-11-03 15:07:01 -08:00
16 changed files with 49 additions and 46 deletions

View File

@@ -14,7 +14,7 @@ class Buffer {
void* ptr_; void* ptr_;
public: public:
Buffer(void* ptr) : ptr_(ptr) {}; explicit Buffer(void* ptr) : ptr_(ptr) {};
// Get the raw data pointer from the buffer // Get the raw data pointer from the buffer
void* raw_ptr(); void* raw_ptr();

View File

@@ -364,6 +364,10 @@ class array {
return const_cast<array&>(*this).data<T>(); return const_cast<array&>(*this).data<T>();
} }
int64_t offset() const {
return array_desc_->offset;
}
enum Status { enum Status {
// The output of a computation which has not been scheduled. // The output of a computation which has not been scheduled.
// For example, the status of `x` in `auto x = a + b`. // For example, the status of `x` in `auto x = a + b`.

View File

@@ -6,7 +6,7 @@ namespace mlx::core {
void broadcast(const array& in, array& out) { void broadcast(const array& in, array& out) {
if (out.size() == 0) { if (out.size() == 0) {
out.set_data(nullptr); out.set_data(allocator::malloc(0));
return; return;
} }
Strides strides(out.ndim(), 0); Strides strides(out.ndim(), 0);

View File

@@ -45,7 +45,7 @@ void slice(
const Shape& start_indices, const Shape& start_indices,
const Shape& strides) { const Shape& strides) {
if (out.size() == 0) { if (out.size() == 0) {
out.set_data(nullptr); out.set_data(allocator::malloc(0));
return; return;
} }

View File

@@ -333,7 +333,7 @@ void Reshape::eval_cpu(const std::vector<array>& inputs, array& out) {
void DynamicSlice::eval_cpu(const std::vector<array>& inputs, array& out) { void DynamicSlice::eval_cpu(const std::vector<array>& inputs, array& out) {
if (out.size() == 0) { if (out.size() == 0) {
out.set_data(nullptr); out.set_data(allocator::malloc(0));
return; return;
} }
auto& in = inputs[0]; auto& in = inputs[0];
@@ -361,7 +361,7 @@ void DynamicSliceUpdate::eval_cpu(
const std::vector<array>& inputs, const std::vector<array>& inputs,
array& out) { array& out) {
if (out.size() == 0) { if (out.size() == 0) {
out.set_data(nullptr); out.set_data(allocator::malloc(0));
return; return;
} }
@@ -396,7 +396,7 @@ void DynamicSliceUpdate::eval_cpu(
void SliceUpdate::eval_cpu(const std::vector<array>& inputs, array& out) { void SliceUpdate::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); assert(inputs.size() == 2);
if (out.size() == 0) { if (out.size() == 0) {
out.set_data(nullptr); out.set_data(allocator::malloc(0));
return; return;
} }

View File

@@ -97,11 +97,11 @@ CudaAllocator::CudaAllocator()
int device_count = 0; int device_count = 0;
CHECK_CUDA_ERROR(cudaGetDeviceCount(&device_count)); CHECK_CUDA_ERROR(cudaGetDeviceCount(&device_count));
int curr_device = 0;
CHECK_CUDA_ERROR(cudaGetDevice(&curr_device));
for (int i = 0; i < device_count; ++i) { for (int i = 0; i < device_count; ++i) {
free_streams_.emplace_back( CHECK_CUDA_ERROR(cudaSetDevice(i));
cu::device(mlx::core::Device{mlx::core::Device::gpu, i})); cudaStream_t s;
CHECK_CUDA_ERROR(cudaStreamCreateWithFlags(&s, cudaStreamNonBlocking));
free_streams_.push_back(s);
} }
} }

View File

@@ -22,11 +22,6 @@ struct CudaBuffer {
int device; // -1 for managed int device; // -1 for managed
}; };
template <typename T>
inline T* gpu_ptr(Buffer buf) {
return static_cast<T*>(static_cast<cu::CudaBuffer*>(buf.ptr())->data);
}
class SmallSizePool { class SmallSizePool {
private: private:
union Block { union Block {
@@ -79,7 +74,7 @@ class CudaAllocator : public allocator::Allocator {
BufferCache<CudaBuffer> buffer_cache_; BufferCache<CudaBuffer> buffer_cache_;
size_t active_memory_{0}; size_t active_memory_{0};
size_t peak_memory_{0}; size_t peak_memory_{0};
std::vector<CudaStream> free_streams_; std::vector<cudaStream_t> free_streams_;
SmallSizePool scalar_pool_; SmallSizePool scalar_pool_;
}; };

View File

@@ -132,14 +132,18 @@ bool prepare_cudnn_plan(
void** data_ptrs, void** data_ptrs,
F&& execute) { F&& execute) {
int workspace_size = plan.getWorkspaceSize(); int workspace_size = plan.getWorkspaceSize();
void* workspace_ptr = nullptr;
if (workspace_size > 0) {
array workspace( array workspace(
workspace_size > 0 ? cu::malloc_async(workspace_size, encoder.stream()) cu::malloc_async(workspace_size, encoder.stream()),
: allocator::Buffer(nullptr),
{workspace_size}, {workspace_size},
uint8); uint8);
encoder.add_temporary(workspace);
workspace_ptr = gpu_ptr<void>(workspace);
}
auto args = cudnn_frontend::VariantPackBuilder() auto args = cudnn_frontend::VariantPackBuilder()
.setWorkspacePointer(gpu_ptr<void>(workspace)) .setWorkspacePointer(workspace_ptr)
.setDataPointers(num_args, data_ptrs) .setDataPointers(num_args, data_ptrs)
.setUids(num_args, uids) .setUids(num_args, uids)
.build(); .build();
@@ -151,7 +155,6 @@ bool prepare_cudnn_plan(
return false; return false;
} }
encoder.add_temporary(workspace);
return true; return true;
} }

View File

@@ -31,7 +31,7 @@ struct KernelArgs {
} }
void append(const array& a) { void append(const array& a) {
append(reinterpret_cast<CUdeviceptr>(gpu_ptr<void>(a.buffer()))); append(reinterpret_cast<CUdeviceptr>(gpu_ptr<void>(a)));
} }
template <typename T> template <typename T>

View File

@@ -25,12 +25,15 @@ inline uint max_occupancy_block_dim(T kernel) {
template <typename T> template <typename T>
inline T* gpu_ptr(array& arr) { inline T* gpu_ptr(array& arr) {
return cu::gpu_ptr<T>(arr.buffer()); return reinterpret_cast<T*>(
static_cast<char*>(
static_cast<cu::CudaBuffer*>(arr.buffer().ptr())->data) +
arr.offset());
} }
template <typename T> template <typename T>
inline const T* gpu_ptr(const array& arr) { inline const T* gpu_ptr(const array& arr) {
return cu::gpu_ptr<T>(arr.buffer()); return gpu_ptr<T>(const_cast<array&>(arr));
} }
struct Dtype; struct Dtype;

View File

@@ -83,7 +83,7 @@ void Depends::eval_gpu(
void DynamicSlice::eval_gpu(const std::vector<array>& inputs, array& out) { void DynamicSlice::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("DynamicSlice::eval_gpu"); MLX_PROFILER_RANGE("DynamicSlice::eval_gpu");
if (out.size() == 0) { if (out.size() == 0) {
out.set_data(nullptr); out.set_data(allocator::malloc(0));
return; return;
} }
@@ -112,7 +112,7 @@ void DynamicSliceUpdate::eval_gpu(
array& out) { array& out) {
MLX_PROFILER_RANGE("DynamicSliceUpdate::eval_gpu"); MLX_PROFILER_RANGE("DynamicSliceUpdate::eval_gpu");
if (out.size() == 0) { if (out.size() == 0) {
out.set_data(nullptr); out.set_data(allocator::malloc(0));
return; return;
} }
@@ -209,7 +209,7 @@ void Slice::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("Slice::eval_gpu"); MLX_PROFILER_RANGE("Slice::eval_gpu");
assert(inputs.size() == 1); assert(inputs.size() == 1);
if (out.size() == 0) { if (out.size() == 0) {
out.set_data(nullptr); out.set_data(allocator::malloc(0));
return; return;
} }
@@ -220,7 +220,7 @@ void Slice::eval_gpu(const std::vector<array>& inputs, array& out) {
void SliceUpdate::eval_gpu(const std::vector<array>& inputs, array& out) { void SliceUpdate::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); assert(inputs.size() == 2);
if (out.size() == 0) { if (out.size() == 0) {
out.set_data(nullptr); out.set_data(allocator::malloc(0));
return; return;
} }

View File

@@ -259,10 +259,7 @@ void CommandEncoder::set_input_array(
needs_barrier_ = needs_barrier_ =
needs_barrier_ | (prev_outputs_.find(r_buf) != prev_outputs_.end()); needs_barrier_ | (prev_outputs_.find(r_buf) != prev_outputs_.end());
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr()); auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
auto base_offset = a.data<char>() - enc_->setBuffer(a_buf, a.offset() + offset, idx);
static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
base_offset += offset;
enc_->setBuffer(a_buf, base_offset, idx);
} }
void CommandEncoder::set_output_array( void CommandEncoder::set_output_array(
@@ -446,11 +443,9 @@ void Device::end_encoding(int index) {
auto& enc = *stream.encoder; auto& enc = *stream.encoder;
// Remove temporaries from inputs and outputs // Remove temporaries from inputs and outputs
for (auto& t : stream.temporaries) { for (auto& t : stream.temporaries) {
if (t.data<void>() != nullptr) {
enc.outputs().erase(t.buffer().ptr()); enc.outputs().erase(t.buffer().ptr());
enc.inputs().erase(t.buffer().ptr()); enc.inputs().erase(t.buffer().ptr());
} }
}
// Keep references to the fences we waited on and put them // Keep references to the fences we waited on and put them
// in the completion handler so they are not prematurely released // in the completion handler so they are not prematurely released

View File

@@ -31,7 +31,7 @@ struct FenceImpl {
auto p = metal::new_scoped_memory_pool(); auto p = metal::new_scoped_memory_pool();
static_cast<MTL::SharedEvent*>(fence)->release(); static_cast<MTL::SharedEvent*>(fence)->release();
} else { } else {
allocator::free(static_cast<MTL::Buffer*>(fence)); allocator::free(allocator::Buffer{static_cast<MTL::Buffer*>(fence)});
} }
} }
bool use_fast{false}; bool use_fast{false};

View File

@@ -768,9 +768,12 @@ void nd_fft_op(
const Stream& s) { const Stream& s) {
// Perform ND FFT on GPU as a series of 1D FFTs // Perform ND FFT on GPU as a series of 1D FFTs
auto temp_shape = inverse ? in.shape() : out.shape(); auto temp_shape = inverse ? in.shape() : out.shape();
array temp1(temp_shape, complex64, nullptr, {}); std::vector<array> temp_arrs;
array temp2(temp_shape, complex64, nullptr, {}); temp_arrs.emplace_back(temp_shape, complex64, nullptr, std::vector<array>{});
std::vector<array> temp_arrs = {temp1, temp2}; if (axes.size() > 2) {
temp_arrs.emplace_back(
temp_shape, complex64, nullptr, std::vector<array>{});
}
for (int i = axes.size() - 1; i >= 0; i--) { for (int i = axes.size() - 1; i >= 0; i--) {
int reverse_index = axes.size() - i - 1; int reverse_index = axes.size() - i - 1;
// For 5D and above, we don't want to reallocate our two temporary arrays // For 5D and above, we don't want to reallocate our two temporary arrays
@@ -781,8 +784,8 @@ void nd_fft_op(
// Mirror np.fft.(i)rfftn and perform a real transform // Mirror np.fft.(i)rfftn and perform a real transform
// only on the final axis. // only on the final axis.
bool step_real = (real && index == axes.size() - 1); bool step_real = (real && index == axes.size() - 1);
const array& in_arr = i == axes.size() - 1 ? in : temp_arrs[1 - i % 2]; const array& in_arr = i == axes.size() - 1 ? in : temp_arrs[i % 2];
array& out_arr = i == 0 ? out : temp_arrs[i % 2]; array& out_arr = i == 0 ? out : temp_arrs[1 - i % 2];
fft_op(in_arr, out_arr, axis, inverse, step_real, inplace, s); fft_op(in_arr, out_arr, axis, inverse, step_real, inplace, s);
} }

View File

@@ -227,7 +227,7 @@ array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s) {
throw std::runtime_error("[load] Failed to open " + in_stream->label()); throw std::runtime_error("[load] Failed to open " + in_stream->label());
} }
auto stream = to_stream(s, cu::is_available() ? Device::gpu : Device::cpu); auto stream = cu::is_available() ? to_stream(s) : to_stream(s, Device::cpu);
//////////////////////////////////////////////////////// ////////////////////////////////////////////////////////
// Read header and prepare array details // Read header and prepare array details

View File

@@ -114,7 +114,7 @@ SafetensorsLoad load_safetensors(
"[load_safetensors] Failed to open " + in_stream->label()); "[load_safetensors] Failed to open " + in_stream->label());
} }
auto stream = to_stream(s, cu::is_available() ? Device::gpu : Device::cpu); auto stream = cu::is_available() ? to_stream(s) : to_stream(s, Device::cpu);
uint64_t jsonHeaderLength = 0; uint64_t jsonHeaderLength = 0;
// This is the same limit as in the original Rust Safetensors code. // This is the same limit as in the original Rust Safetensors code.