Awni Hannun
2025-11-03 12:48:22 -08:00
parent 742033fefe
commit cc6df9fc8a
15 changed files with 39 additions and 39 deletions

View File

@@ -14,7 +14,7 @@ class Buffer {
void* ptr_;
public:
- Buffer(void* ptr) : ptr_(ptr) {};
+ explicit Buffer(void* ptr) : ptr_(ptr) {};
// Get the raw data pointer from the buffer
void* raw_ptr();
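Marking the converting constructor explicit removes the implicit void* to Buffer conversion; this is likely why the Metal fence change further down now spells out allocator::Buffer{...} before calling allocator::free. A minimal standalone sketch, not MLX code, of what the change means for callers:

// Standalone sketch (not MLX code): with an explicit converting constructor,
// a raw pointer can no longer be passed where a Buffer is expected.
struct Buffer {
  explicit Buffer(void* ptr) : ptr_(ptr) {}
  void* ptr_;
};

void free_buffer(Buffer) {}  // stand-in for allocator::free

void demo(void* raw) {
  // free_buffer(raw);       // no longer compiles: implicit void* -> Buffer is gone
  free_buffer(Buffer{raw});  // the wrapper must be constructed explicitly,
                             // mirroring the allocator::Buffer{...} call in the
                             // Metal fence change below
}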

View File

@@ -364,6 +364,10 @@ class array {
return const_cast<array&>(*this).data<T>();
}
+ int64_t offset() const {
+   return array_desc_->offset;
+ }
enum Status {
// The output of a computation which has not been scheduled.
// For example, the status of `x` in `auto x = a + b`.

View File

@@ -6,7 +6,7 @@ namespace mlx::core {
void broadcast(const array& in, array& out) {
if (out.size() == 0) {
- out.set_data(nullptr);
+ out.set_data(allocator::malloc(0));
return;
}
Strides strides(out.ndim(), 0);
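This swap from out.set_data(nullptr) to out.set_data(allocator::malloc(0)) recurs in the CPU and GPU slice paths below. A plausible reading, given the other changes in this commit: downstream code now derives device addresses from the buffer record plus array::offset() without null checks (the Metal encoder change even drops one), so an empty output still needs a real, zero-byte buffer behind it. A small sketch of that hazard, with hypothetical names:

#include <cstdint>

// Hypothetical sketch: pointer derivation that dereferences the buffer record
// unconditionally. Passing a null record would be undefined behavior even if
// the resulting address is never read because the output is empty, so every
// output, including a zero-size one, wants a real record behind it.
struct BufferRecord {
  void* data;  // device allocation; may be a valid zero-byte allocation
};

char* device_address(BufferRecord* record, std::int64_t byte_offset) {
  return static_cast<char*>(record->data) + byte_offset;
}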

View File

@@ -45,7 +45,7 @@ void slice(
const Shape& start_indices,
const Shape& strides) {
if (out.size() == 0) {
- out.set_data(nullptr);
+ out.set_data(allocator::malloc(0));
return;
}

View File

@@ -333,7 +333,7 @@ void Reshape::eval_cpu(const std::vector<array>& inputs, array& out) {
void DynamicSlice::eval_cpu(const std::vector<array>& inputs, array& out) {
if (out.size() == 0) {
- out.set_data(nullptr);
+ out.set_data(allocator::malloc(0));
return;
}
auto& in = inputs[0];
@@ -361,7 +361,7 @@ void DynamicSliceUpdate::eval_cpu(
const std::vector<array>& inputs,
array& out) {
if (out.size() == 0) {
- out.set_data(nullptr);
+ out.set_data(allocator::malloc(0));
return;
}
@@ -396,7 +396,7 @@ void DynamicSliceUpdate::eval_cpu(
void SliceUpdate::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
if (out.size() == 0) {
- out.set_data(nullptr);
+ out.set_data(allocator::malloc(0));
return;
}

View File

@@ -97,11 +97,11 @@ CudaAllocator::CudaAllocator()
int device_count = 0;
CHECK_CUDA_ERROR(cudaGetDeviceCount(&device_count));
+ int curr_device = 0;
+ CHECK_CUDA_ERROR(cudaGetDevice(&curr_device));
for (int i = 0; i < device_count; ++i) {
- free_streams_.emplace_back(
-     cu::device(mlx::core::Device{mlx::core::Device::gpu, i}));
+ CHECK_CUDA_ERROR(cudaSetDevice(i));
+ cudaStream_t s;
+ CHECK_CUDA_ERROR(cudaStreamCreateWithFlags(&s, cudaStreamNonBlocking));
+ free_streams_.push_back(s);
}
}
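The free streams are now raw cudaStream_t handles created directly with the runtime API (the header below changes the member type accordingly). A stream is bound to whichever device is current when it is created, hence the cudaSetDevice call inside the loop; saving curr_device first suggests the previously current device is restored once the loop is done. A general, self-contained illustration of that pattern, not the MLX implementation:

#include <cuda_runtime.h>
#include <vector>

// General illustration (not MLX code): create one non-blocking stream per
// visible device, then restore the device that was current beforehand.
// Error checking is omitted for brevity.
std::vector<cudaStream_t> make_per_device_streams() {
  int device_count = 0;
  cudaGetDeviceCount(&device_count);
  int prev_device = 0;
  cudaGetDevice(&prev_device);

  std::vector<cudaStream_t> streams;
  for (int i = 0; i < device_count; ++i) {
    cudaSetDevice(i);  // a stream belongs to the device current at creation time
    cudaStream_t s;
    cudaStreamCreateWithFlags(&s, cudaStreamNonBlocking);
    streams.push_back(s);
  }
  cudaSetDevice(prev_device);
  return streams;
}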

View File

@@ -22,11 +22,6 @@ struct CudaBuffer {
int device; // -1 for managed
};
- template <typename T>
- inline T* gpu_ptr(Buffer buf) {
-   return static_cast<T*>(static_cast<cu::CudaBuffer*>(buf.ptr())->data);
- }
class SmallSizePool {
private:
union Block {
@@ -79,7 +74,7 @@ class CudaAllocator : public allocator::Allocator {
BufferCache<CudaBuffer> buffer_cache_;
size_t active_memory_{0};
size_t peak_memory_{0};
- std::vector<CudaStream> free_streams_;
+ std::vector<cudaStream_t> free_streams_;
SmallSizePool scalar_pool_;
};

View File

@@ -31,7 +31,7 @@ struct KernelArgs {
}
void append(const array& a) {
- append(reinterpret_cast<CUdeviceptr>(gpu_ptr<void>(a.buffer())));
+ append(reinterpret_cast<CUdeviceptr>(gpu_ptr<void>(a)));
}
template <typename T>

View File

@@ -25,12 +25,15 @@ inline uint max_occupancy_block_dim(T kernel) {
template <typename T>
inline T* gpu_ptr(array& arr) {
- return cu::gpu_ptr<T>(arr.buffer());
+ return reinterpret_cast<T*>(
+     static_cast<char*>(
+         static_cast<cu::CudaBuffer*>(arr.buffer().ptr())->data) +
+     arr.offset());
}
template <typename T>
inline const T* gpu_ptr(const array& arr) {
- return cu::gpu_ptr<T>(arr.buffer());
+ return gpu_ptr<T>(const_cast<array&>(arr));
}
struct Dtype;
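Together with the new array::offset() accessor, gpu_ptr now computes the device address as the buffer's base pointer plus an offset carried by the array itself, so an array no longer has to start at the base of its buffer. Both this helper and the Metal setBuffer change below add offset() to a byte address, so the accessor appears to be a byte offset. A hedged sketch of the same idea with hypothetical types, where a view shares its parent's allocation and differs only in where its data starts:

#include <cstdint>

// Hypothetical sketch (not MLX types): two views over one allocation are
// distinguished only by a byte offset into the shared buffer, which is the
// situation the base-plus-offset derivation above has to handle.
struct View {
  void* buffer_base;         // shared device allocation
  std::int64_t byte_offset;  // where this view's data starts within the buffer
};

template <typename T>
T* typed_ptr(const View& v) {
  return reinterpret_cast<T*>(static_cast<char*>(v.buffer_base) + v.byte_offset);
}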

View File

@@ -83,7 +83,7 @@ void Depends::eval_gpu(
void DynamicSlice::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("DynamicSlice::eval_gpu");
if (out.size() == 0) {
- out.set_data(nullptr);
+ out.set_data(allocator::malloc(0));
return;
}
@@ -112,7 +112,7 @@ void DynamicSliceUpdate::eval_gpu(
array& out) {
MLX_PROFILER_RANGE("DynamicSliceUpdate::eval_gpu");
if (out.size() == 0) {
- out.set_data(nullptr);
+ out.set_data(allocator::malloc(0));
return;
}
@@ -209,7 +209,7 @@ void Slice::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("Slice::eval_gpu");
assert(inputs.size() == 1);
if (out.size() == 0) {
- out.set_data(nullptr);
+ out.set_data(allocator::malloc(0));
return;
}
@@ -220,7 +220,7 @@ void Slice::eval_gpu(const std::vector<array>& inputs, array& out) {
void SliceUpdate::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
if (out.size() == 0) {
- out.set_data(nullptr);
+ out.set_data(allocator::malloc(0));
return;
}

View File

@@ -259,10 +259,7 @@ void CommandEncoder::set_input_array(
needs_barrier_ =
needs_barrier_ | (prev_outputs_.find(r_buf) != prev_outputs_.end());
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
- auto base_offset = a.data<char>() -
-     static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
- base_offset += offset;
- enc_->setBuffer(a_buf, base_offset, idx);
+ enc_->setBuffer(a_buf, a.offset() + offset, idx);
}
void CommandEncoder::set_output_array(
@@ -446,11 +443,9 @@ void Device::end_encoding(int index) {
auto& enc = *stream.encoder;
// Remove temporaries from inputs and outputs
for (auto& t : stream.temporaries) {
- if (t.data<void>() != nullptr) {
enc.outputs().erase(t.buffer().ptr());
enc.inputs().erase(t.buffer().ptr());
- }
}
// Keep references to the fences we waited on and put them
// in the completion handler so they are not prematurely released

View File

@@ -31,7 +31,7 @@ struct FenceImpl {
auto p = metal::new_scoped_memory_pool();
static_cast<MTL::SharedEvent*>(fence)->release();
} else {
- allocator::free(static_cast<MTL::Buffer*>(fence));
+ allocator::free(allocator::Buffer{static_cast<MTL::Buffer*>(fence)});
}
}
bool use_fast{false};

View File

@@ -768,9 +768,12 @@ void nd_fft_op(
const Stream& s) {
// Perform ND FFT on GPU as a series of 1D FFTs
auto temp_shape = inverse ? in.shape() : out.shape();
- array temp1(temp_shape, complex64, nullptr, {});
- array temp2(temp_shape, complex64, nullptr, {});
- std::vector<array> temp_arrs = {temp1, temp2};
+ std::vector<array> temp_arrs;
+ temp_arrs.emplace_back(temp_shape, complex64, nullptr, std::vector<array>{});
+ if (axes.size() > 2) {
+   temp_arrs.emplace_back(
+       temp_shape, complex64, nullptr, std::vector<array>{});
+ }
for (int i = axes.size() - 1; i >= 0; i--) {
int reverse_index = axes.size() - i - 1;
// For 5D and above, we don't want to reallocate our two temporary arrays
@@ -781,8 +784,8 @@ void nd_fft_op(
// Mirror np.fft.(i)rfftn and perform a real transform
// only on the final axis.
bool step_real = (real && index == axes.size() - 1);
- const array& in_arr = i == axes.size() - 1 ? in : temp_arrs[1 - i % 2];
- array& out_arr = i == 0 ? out : temp_arrs[i % 2];
+ const array& in_arr = i == axes.size() - 1 ? in : temp_arrs[i % 2];
+ array& out_arr = i == 0 ? out : temp_arrs[1 - i % 2];
fft_op(in_arr, out_arr, axis, inverse, step_real, inplace, s);
}
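With the scratch arrays now created lazily (one always, a second only when more than two axes are involved), the in/out index expressions are swapped so that the two-axis case uses temp_arrs[0], the only scratch array that exists, while longer axis lists still ping-pong between the two temporaries: the array written at step i is the one read at step i - 1. A small standalone sketch of that indexing, with a made-up pass function standing in for the 1D FFT:

#include <vector>

// Standalone sketch of the ping-pong indexing above. With n passes, pass i
// reads the buffer that pass i + 1 wrote, so two scratch buffers suffice for
// any chain length, and a chain of exactly two passes only touches scratch[0].
void run_passes(int n, const std::vector<float>& in, std::vector<float>& out) {
  std::vector<std::vector<float>> scratch(2, std::vector<float>(in.size()));
  for (int i = n - 1; i >= 0; --i) {
    const std::vector<float>& src = (i == n - 1) ? in : scratch[i % 2];
    std::vector<float>& dst = (i == 0) ? out : scratch[1 - i % 2];
    dst = src;  // stand-in for one 1D pass along axis i
  }
}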

View File

@@ -227,7 +227,7 @@ array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s) {
throw std::runtime_error("[load] Failed to open " + in_stream->label());
}
- auto stream = to_stream(s, cu::is_available() ? Device::gpu : Device::cpu);
+ auto stream = cu::is_available() ? to_stream(s) : to_stream(s, Device::cpu);
////////////////////////////////////////////////////////
// Read header and prepare array details

View File

@@ -114,7 +114,7 @@ SafetensorsLoad load_safetensors(
"[load_safetensors] Failed to open " + in_stream->label());
}
- auto stream = to_stream(s, cu::is_available() ? Device::gpu : Device::cpu);
+ auto stream = cu::is_available() ? to_stream(s) : to_stream(s, Device::cpu);
uint64_t jsonHeaderLength = 0;
// This is the same limit as in the original Rust Safetensors code.