diff --git a/mlx/allocator.h b/mlx/allocator.h
index 362f4f08a..67a9245c4 100644
--- a/mlx/allocator.h
+++ b/mlx/allocator.h
@@ -14,7 +14,7 @@ class Buffer {
   void* ptr_;
 
  public:
-  Buffer(void* ptr) : ptr_(ptr) {};
+  explicit Buffer(void* ptr) : ptr_(ptr) {};
 
   // Get the raw data pointer from the buffer
   void* raw_ptr();
diff --git a/mlx/array.h b/mlx/array.h
index 40a7aa87c..fda63b926 100644
--- a/mlx/array.h
+++ b/mlx/array.h
@@ -364,6 +364,10 @@ class array {
     return const_cast<array&>(*this).data<T>();
   }
 
+  int64_t offset() const {
+    return array_desc_->offset;
+  }
+
   enum Status {
     // The output of a computation which has not been scheduled.
     // For example, the status of `x` in `auto x = a + b`.
diff --git a/mlx/backend/common/broadcasting.cpp b/mlx/backend/common/broadcasting.cpp
index 49bc75b8f..0bb52e096 100644
--- a/mlx/backend/common/broadcasting.cpp
+++ b/mlx/backend/common/broadcasting.cpp
@@ -6,7 +6,7 @@ namespace mlx::core {
 
 void broadcast(const array& in, array& out) {
   if (out.size() == 0) {
-    out.set_data(nullptr);
+    out.set_data(allocator::malloc(0));
     return;
   }
   Strides strides(out.ndim(), 0);
diff --git a/mlx/backend/common/slicing.cpp b/mlx/backend/common/slicing.cpp
index 6f5736d63..38f3c1ba0 100644
--- a/mlx/backend/common/slicing.cpp
+++ b/mlx/backend/common/slicing.cpp
@@ -45,7 +45,7 @@ void slice(
     const Shape& start_indices,
     const Shape& strides) {
   if (out.size() == 0) {
-    out.set_data(nullptr);
+    out.set_data(allocator::malloc(0));
     return;
   }
 
diff --git a/mlx/backend/cpu/primitives.cpp b/mlx/backend/cpu/primitives.cpp
index f2cb12fdd..d5b917b84 100644
--- a/mlx/backend/cpu/primitives.cpp
+++ b/mlx/backend/cpu/primitives.cpp
@@ -333,7 +333,7 @@ void Reshape::eval_cpu(const std::vector<array>& inputs, array& out) {
 
 void DynamicSlice::eval_cpu(const std::vector<array>& inputs, array& out) {
   if (out.size() == 0) {
-    out.set_data(nullptr);
+    out.set_data(allocator::malloc(0));
     return;
   }
   auto& in = inputs[0];
@@ -361,7 +361,7 @@ void DynamicSliceUpdate::eval_cpu(
     const std::vector<array>& inputs,
     array& out) {
   if (out.size() == 0) {
-    out.set_data(nullptr);
+    out.set_data(allocator::malloc(0));
     return;
   }
 
@@ -396,7 +396,7 @@ void DynamicSliceUpdate::eval_cpu(
 void SliceUpdate::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 2);
   if (out.size() == 0) {
-    out.set_data(nullptr);
+    out.set_data(allocator::malloc(0));
     return;
   }
 
diff --git a/mlx/backend/cuda/allocator.cpp b/mlx/backend/cuda/allocator.cpp
index 67ff40e1e..7d1207cfa 100644
--- a/mlx/backend/cuda/allocator.cpp
+++ b/mlx/backend/cuda/allocator.cpp
@@ -97,11 +97,11 @@ CudaAllocator::CudaAllocator()
 
   int device_count = 0;
   CHECK_CUDA_ERROR(cudaGetDeviceCount(&device_count));
-  int curr_device = 0;
-  CHECK_CUDA_ERROR(cudaGetDevice(&curr_device));
   for (int i = 0; i < device_count; ++i) {
-    free_streams_.emplace_back(
-        cu::device(mlx::core::Device{mlx::core::Device::gpu, i}));
+    CHECK_CUDA_ERROR(cudaSetDevice(i));
+    cudaStream_t s;
+    CHECK_CUDA_ERROR(cudaStreamCreateWithFlags(&s, cudaStreamNonBlocking));
+    free_streams_.push_back(s);
   }
 }
 
diff --git a/mlx/backend/cuda/allocator.h b/mlx/backend/cuda/allocator.h
index df89d0982..0b62d8529 100644
--- a/mlx/backend/cuda/allocator.h
+++ b/mlx/backend/cuda/allocator.h
@@ -22,11 +22,6 @@ struct CudaBuffer {
   int device; // -1 for managed
 };
 
-template <typename T>
-inline T* gpu_ptr(Buffer buf) {
-  return static_cast<T*>(static_cast<CudaBuffer*>(buf.ptr())->data);
-}
-
 class SmallSizePool {
  private:
   union Block {
@@ -79,7 +74,7 @@ class CudaAllocator : public allocator::Allocator {
   BufferCache<CudaBuffer> buffer_cache_;
   size_t active_memory_{0};
   size_t peak_memory_{0};
-  std::vector<CudaStream> free_streams_;
+  std::vector<cudaStream_t> free_streams_;
   SmallSizePool scalar_pool_;
 };
 
diff --git a/mlx/backend/cuda/jit_module.h b/mlx/backend/cuda/jit_module.h
index 801a699d5..ac2886306 100644
--- a/mlx/backend/cuda/jit_module.h
+++ b/mlx/backend/cuda/jit_module.h
@@ -31,7 +31,7 @@ struct KernelArgs {
   }
 
   void append(const array& a) {
-    append(reinterpret_cast<CUdeviceptr>(gpu_ptr<void>(a.buffer())));
+    append(reinterpret_cast<CUdeviceptr>(gpu_ptr<void>(a)));
   }
 
   template <typename T>
diff --git a/mlx/backend/cuda/utils.h b/mlx/backend/cuda/utils.h
index d52299fcf..417f7c8aa 100644
--- a/mlx/backend/cuda/utils.h
+++ b/mlx/backend/cuda/utils.h
@@ -25,12 +25,15 @@ inline uint max_occupancy_block_dim(T kernel) {
 
 template <typename T>
 inline T* gpu_ptr(array& arr) {
-  return cu::gpu_ptr<T>(arr.buffer());
+  return reinterpret_cast<T*>(
+      static_cast<char*>(
+          static_cast<cu::CudaBuffer*>(arr.buffer().ptr())->data) +
+      arr.offset());
 }
 
 template <typename T>
 inline const T* gpu_ptr(const array& arr) {
-  return cu::gpu_ptr<T>(arr.buffer());
+  return gpu_ptr<T>(const_cast<array&>(arr));
 }
 
 struct Dtype;
diff --git a/mlx/backend/gpu/primitives.cpp b/mlx/backend/gpu/primitives.cpp
index ee40799df..0138928c0 100644
--- a/mlx/backend/gpu/primitives.cpp
+++ b/mlx/backend/gpu/primitives.cpp
@@ -83,7 +83,7 @@ void Depends::eval_gpu(
 void DynamicSlice::eval_gpu(const std::vector<array>& inputs, array& out) {
   MLX_PROFILER_RANGE("DynamicSlice::eval_gpu");
   if (out.size() == 0) {
-    out.set_data(nullptr);
+    out.set_data(allocator::malloc(0));
     return;
   }
 
@@ -112,7 +112,7 @@ void DynamicSliceUpdate::eval_gpu(
     array& out) {
   MLX_PROFILER_RANGE("DynamicSliceUpdate::eval_gpu");
   if (out.size() == 0) {
-    out.set_data(nullptr);
+    out.set_data(allocator::malloc(0));
     return;
   }
 
@@ -209,7 +209,7 @@ void Slice::eval_gpu(const std::vector<array>& inputs, array& out) {
   MLX_PROFILER_RANGE("Slice::eval_gpu");
   assert(inputs.size() == 1);
   if (out.size() == 0) {
-    out.set_data(nullptr);
+    out.set_data(allocator::malloc(0));
     return;
   }
 
@@ -220,7 +220,7 @@ void Slice::eval_gpu(const std::vector<array>& inputs, array& out) {
 void SliceUpdate::eval_gpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 2);
   if (out.size() == 0) {
-    out.set_data(nullptr);
+    out.set_data(allocator::malloc(0));
     return;
   }
 
diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp
index e82d734a2..e4d33c190 100644
--- a/mlx/backend/metal/device.cpp
+++ b/mlx/backend/metal/device.cpp
@@ -259,10 +259,7 @@ void CommandEncoder::set_input_array(
   needs_barrier_ =
       needs_barrier_ | (prev_outputs_.find(r_buf) != prev_outputs_.end());
   auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
-  auto base_offset = a.data<char>() -
-      static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
-  base_offset += offset;
-  enc_->setBuffer(a_buf, base_offset, idx);
+  enc_->setBuffer(a_buf, a.offset() + offset, idx);
 }
 
 void CommandEncoder::set_output_array(
@@ -446,10 +443,8 @@ void Device::end_encoding(int index) {
   auto& enc = *stream.encoder;
   // Remove temporaries from inputs and outputs
   for (auto& t : stream.temporaries) {
-    if (t.data<void>() != nullptr) {
-      enc.outputs().erase(t.buffer().ptr());
-      enc.inputs().erase(t.buffer().ptr());
-    }
+    enc.outputs().erase(t.buffer().ptr());
+    enc.inputs().erase(t.buffer().ptr());
   }
 
   // Keep references to the fences we waited on and put them
diff --git a/mlx/backend/metal/fence.cpp b/mlx/backend/metal/fence.cpp
index 5abdf7309..1ae3f9c22 100644
--- a/mlx/backend/metal/fence.cpp
+++ b/mlx/backend/metal/fence.cpp
@@ -31,7 +31,7 @@ struct FenceImpl {
       auto p = metal::new_scoped_memory_pool();
       static_cast<MTL::SharedEvent*>(fence)->release();
     } else {
-      allocator::free(static_cast<void*>(fence));
+      allocator::free(allocator::Buffer{static_cast<void*>(fence)});
     }
   }
   bool use_fast{false};
diff --git a/mlx/backend/metal/fft.cpp b/mlx/backend/metal/fft.cpp
index d329a4685..d99e1badb 100644
--- a/mlx/backend/metal/fft.cpp
+++ b/mlx/backend/metal/fft.cpp
@@ -768,9 +768,12 @@ void nd_fft_op(
     const Stream& s) {
   // Perform ND FFT on GPU as a series of 1D FFTs
   auto temp_shape = inverse ? in.shape() : out.shape();
-  array temp1(temp_shape, complex64, nullptr, {});
-  array temp2(temp_shape, complex64, nullptr, {});
-  std::vector<array> temp_arrs = {temp1, temp2};
+  std::vector<array> temp_arrs;
+  temp_arrs.emplace_back(temp_shape, complex64, nullptr, std::vector<array>{});
+  if (axes.size() > 2) {
+    temp_arrs.emplace_back(
+        temp_shape, complex64, nullptr, std::vector<array>{});
+  }
   for (int i = axes.size() - 1; i >= 0; i--) {
     int reverse_index = axes.size() - i - 1;
     // For 5D and above, we don't want to reallocate our two temporary arrays
@@ -781,8 +784,8 @@ void nd_fft_op(
     // Mirror np.fft.(i)rfftn and perform a real transform
     // only on the final axis.
     bool step_real = (real && index == axes.size() - 1);
-    const array& in_arr = i == axes.size() - 1 ? in : temp_arrs[1 - i % 2];
-    array& out_arr = i == 0 ? out : temp_arrs[i % 2];
+    const array& in_arr = i == axes.size() - 1 ? in : temp_arrs[i % 2];
+    array& out_arr = i == 0 ? out : temp_arrs[1 - i % 2];
     fft_op(in_arr, out_arr, axis, inverse, step_real, inplace, s);
   }
 
diff --git a/mlx/io/load.cpp b/mlx/io/load.cpp
index 8a6ca0852..92366be78 100644
--- a/mlx/io/load.cpp
+++ b/mlx/io/load.cpp
@@ -227,7 +227,7 @@ array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s) {
     throw std::runtime_error("[load] Failed to open " + in_stream->label());
   }
 
-  auto stream = to_stream(s, cu::is_available() ? Device::gpu : Device::cpu);
+  auto stream = cu::is_available() ? to_stream(s) : to_stream(s, Device::cpu);
 
   ////////////////////////////////////////////////////////
   // Read header and prepare array details
diff --git a/mlx/io/safetensors.cpp b/mlx/io/safetensors.cpp
index e726fc359..738370f35 100644
--- a/mlx/io/safetensors.cpp
+++ b/mlx/io/safetensors.cpp
@@ -114,7 +114,7 @@ SafetensorsLoad load_safetensors(
         "[load_safetensors] Failed to open " + in_stream->label());
   }
 
-  auto stream = to_stream(s, cu::is_available() ? Device::gpu : Device::cpu);
+  auto stream = cu::is_available() ? to_stream(s) : to_stream(s, Device::cpu);
 
   uint64_t jsonHeaderLength = 0;
   // This is the same limit as in the original Rust Safetensors code.