refactor for regular cuda malloc

This commit is contained in:
Awni Hannun
2025-10-31 14:12:15 -07:00
parent b84fc978d3
commit d378567cc6
55 changed files with 370 additions and 294 deletions

View File

@@ -64,7 +64,7 @@ array array::unsafe_weak_copy(const array& other) {
other.strides(),
other.flags(),
[](auto) {});
cpy.array_desc_->data_ptr = other.array_desc_->data_ptr;
cpy.array_desc_->offset = other.array_desc_->offset;
return cpy;
}
@@ -141,7 +141,7 @@ bool array::is_tracer() const {
void array::set_data(allocator::Buffer buffer, Deleter d) {
array_desc_->data = std::make_shared<Data>(buffer, d);
array_desc_->data_ptr = buffer.raw_ptr();
array_desc_->offset = 0;
array_desc_->data_size = size();
array_desc_->flags.contiguous = true;
array_desc_->flags.row_contiguous = true;
@@ -156,7 +156,7 @@ void array::set_data(
Flags flags,
Deleter d) {
array_desc_->data = std::make_shared<Data>(buffer, d);
array_desc_->data_ptr = buffer.raw_ptr();
array_desc_->offset = 0;
array_desc_->data_size = data_size;
array_desc_->strides = std::move(strides);
array_desc_->flags = flags;
@@ -172,9 +172,8 @@ void array::copy_shared_buffer(
array_desc_->strides = strides;
array_desc_->flags = flags;
array_desc_->data_size = data_size;
auto char_offset = sizeof(char) * itemsize() * offset;
array_desc_->data_ptr = static_cast<void*>(
static_cast<char*>(other.array_desc_->data_ptr) + char_offset);
array_desc_->offset =
sizeof(char) * itemsize() * offset + other.array_desc_->offset;
}
void array::copy_shared_buffer(const array& other) {

View File

@@ -352,12 +352,13 @@ class array {
// Return a raw pointer to the arrays data
template <typename T>
T* data() {
return static_cast<T*>(array_desc_->data_ptr);
return reinterpret_cast<T*>(
(static_cast<char*>(buffer().raw_ptr()) + array_desc_->offset));
}
template <typename T>
const T* data() const {
return static_cast<T*>(array_desc_->data_ptr);
return const_cast<array&>(*this).data<T>();
}
enum Status {
@@ -461,8 +462,8 @@ class array {
// can share the underlying data buffer.
std::shared_ptr<Data> data;
// Properly offset data pointer
void* data_ptr{nullptr};
// Offset from beginning of data pointer
int64_t offset{0};
// The size in elements of the data buffer the array accesses
size_t data_size;

View File

@@ -38,20 +38,20 @@ inline void set_binary_op_output_data(
const array& a,
const array& b,
array& out,
BinaryOpType bopt) {
BinaryOpType bopt,
std::function<allocator::Buffer(size_t)> mallocfn = allocator::malloc) {
bool b_donatable = is_donatable(b, out);
bool a_donatable = is_donatable(a, out);
switch (bopt) {
case BinaryOpType::ScalarScalar:
out.set_data(
allocator::malloc(out.itemsize()), 1, a.strides(), a.flags());
out.set_data(mallocfn(out.itemsize()), 1, a.strides(), a.flags());
break;
case BinaryOpType::ScalarVector:
if (b_donatable) {
out.copy_shared_buffer(b);
} else {
out.set_data(
allocator::malloc(b.data_size() * out.itemsize()),
mallocfn(b.data_size() * out.itemsize()),
b.data_size(),
b.strides(),
b.flags());
@@ -62,7 +62,7 @@ inline void set_binary_op_output_data(
out.copy_shared_buffer(a);
} else {
out.set_data(
allocator::malloc(a.data_size() * out.itemsize()),
mallocfn(a.data_size() * out.itemsize()),
a.data_size(),
a.strides(),
a.flags());
@@ -75,7 +75,7 @@ inline void set_binary_op_output_data(
out.copy_shared_buffer(b);
} else {
out.set_data(
allocator::malloc(a.data_size() * out.itemsize()),
mallocfn(a.data_size() * out.itemsize()),
a.data_size(),
a.strides(),
a.flags());
@@ -88,7 +88,7 @@ inline void set_binary_op_output_data(
b_donatable && b.flags().row_contiguous && b.size() == out.size()) {
out.copy_shared_buffer(b);
} else {
out.set_data(allocator::malloc(out.nbytes()));
out.set_data(mallocfn(out.nbytes()));
}
break;
}

View File

@@ -114,7 +114,9 @@ void compiled_allocate_outputs(
const std::vector<array>& inputs,
std::vector<array>& outputs,
const std::function<bool(size_t)>& is_constant,
bool contiguous) {
bool contiguous,
const std::function<allocator::Buffer(size_t)>&
mallocfn /* = allocator::malloc */) {
if (contiguous) {
int o = 0;
Strides strides;
@@ -140,7 +142,7 @@ void compiled_allocate_outputs(
}
for (; o < outputs.size(); ++o) {
outputs[o].set_data(
allocator::malloc(data_size * outputs[o].itemsize()),
mallocfn(data_size * outputs[o].itemsize()),
data_size,
strides,
flags);
@@ -163,7 +165,7 @@ void compiled_allocate_outputs(
}
}
for (; o < outputs.size(); ++o) {
outputs[o].set_data(allocator::malloc(outputs[o].nbytes()));
outputs[o].set_data(mallocfn(outputs[o].nbytes()));
}
}
}

View File

@@ -58,7 +58,9 @@ void compiled_allocate_outputs(
const std::vector<array>& inputs,
std::vector<array>& outputs,
const std::function<bool(size_t)>& is_constant,
bool contiguous);
bool contiguous,
const std::function<allocator::Buffer(size_t)>& mallocfn =
allocator::malloc);
// Collapse contiguous dims ignoring scalars and constants.
std::tuple<bool, Shape, std::vector<Strides>> compiled_collapse_contiguous_dims(

View File

@@ -22,7 +22,11 @@ enum class CopyType {
GeneralGeneral
};
inline bool set_copy_output_data(const array& in, array& out, CopyType ctype) {
inline bool set_copy_output_data(
const array& in,
array& out,
CopyType ctype,
std::function<allocator::Buffer(size_t)> mallocfn = allocator::malloc) {
if (ctype == CopyType::Vector) {
// If the input is donateable, we are doing a vector copy and the types
// have the same size, then the input buffer can hold the output.
@@ -31,14 +35,14 @@ inline bool set_copy_output_data(const array& in, array& out, CopyType ctype) {
return true;
} else {
out.set_data(
allocator::malloc(in.data_size() * out.itemsize()),
mallocfn(in.data_size() * out.itemsize()),
in.data_size(),
in.strides(),
in.flags());
return false;
}
} else {
out.set_data(allocator::malloc(out.nbytes()));
out.set_data(mallocfn(out.nbytes()));
return false;
}
}

View File

@@ -46,7 +46,8 @@ inline void set_ternary_op_output_data(
const array& b,
const array& c,
array& out,
TernaryOpType topt) {
TernaryOpType topt,
std::function<allocator::Buffer(size_t)> mallocfn = allocator::malloc) {
auto maybe_donate = [&out](const array& x) {
if (is_donatable(x, out)) {
out.copy_shared_buffer(x);
@@ -57,13 +58,12 @@ inline void set_ternary_op_output_data(
switch (topt) {
case TernaryOpType::ScalarScalarScalar:
out.set_data(
allocator::malloc(out.itemsize()), 1, b.strides(), b.flags());
out.set_data(mallocfn(out.itemsize()), 1, b.strides(), b.flags());
break;
case TernaryOpType::VectorVectorVector:
if (!(maybe_donate(a) || maybe_donate(b) || maybe_donate(c))) {
out.set_data(
allocator::malloc(out.itemsize() * b.data_size()),
mallocfn(out.itemsize() * b.data_size()),
b.data_size(),
b.strides(),
b.flags());
@@ -76,7 +76,7 @@ inline void set_ternary_op_output_data(
if (!((a.flags().row_contiguous && maybe_donate(a)) ||
(b.flags().row_contiguous && maybe_donate(b)) ||
(c.flags().row_contiguous && maybe_donate(c)))) {
out.set_data(allocator::malloc(out.nbytes()));
out.set_data(mallocfn(out.nbytes()));
}
break;
}

View File

@@ -7,19 +7,22 @@
namespace mlx::core {
inline void set_unary_output_data(const array& in, array& out) {
inline void set_unary_output_data(
const array& in,
array& out,
std::function<allocator::Buffer(size_t)> mallocfn = allocator::malloc) {
if (in.flags().contiguous) {
if (is_donatable(in, out)) {
out.copy_shared_buffer(in);
} else {
out.set_data(
allocator::malloc(in.data_size() * out.itemsize()),
mallocfn(in.data_size() * out.itemsize()),
in.data_size(),
in.strides(),
in.flags());
}
} else {
out.set_data(allocator::malloc(out.nbytes()));
out.set_data(mallocfn(out.nbytes()));
}
}

View File

@@ -68,6 +68,7 @@ CudaBuffer* SmallSizePool::malloc() {
next_free_ = next_free_->next;
b->buf.data = static_cast<char*>(data_) + i * small_block_size;
b->buf.size = small_block_size;
b->buf.managed = true;
return &b->buf;
}
@@ -94,16 +95,13 @@ CudaAllocator::CudaAllocator()
CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
memory_limit_ = total * 0.95;
max_pool_size_ = memory_limit_;
#if CUDART_VERSION >= 13000
cudaMemLocation loc;
loc.id = 0;
loc.type = cudaMemLocationTypeNone;
cudaMemGetDefaultMemPool(&cuda_pool_, &loc, cudaMemAllocationTypeManaged);
int loc = 0;
cudaDeviceGetDefaultMemPool(&cuda_pool_, loc);
// TODO need a strategy for that
uint64_t threshold = UINT64_MAX;
cudaMemPoolSetAttribute(
cuda_pool_, cudaMemPoolAttrReleaseThreshold, &threshold);
#endif
}
Buffer CudaAllocator::malloc_impl(size_t size, cudaStream_t stream) {
@@ -133,12 +131,13 @@ Buffer CudaAllocator::malloc_impl(size_t size, cudaStream_t stream) {
}
lock.unlock();
if (!buf) {
buf = new CudaBuffer{nullptr, size};
bool managed = stream == nullptr;
buf = new CudaBuffer{nullptr, size, managed};
cudaError_t err;
if (stream != nullptr && cuda_pool_ != nullptr) {
err = cudaMallocFromPoolAsync(&buf->data, size, cuda_pool_, stream);
} else {
if (managed) {
err = cudaMallocManaged(&buf->data, size);
} else {
err = cudaMallocFromPoolAsync(&buf->data, size, cuda_pool_, stream);
}
if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
throw std::runtime_error(fmt::format(
@@ -266,7 +265,17 @@ void* Buffer::raw_ptr() {
if (!ptr_) {
return nullptr;
}
return static_cast<cu::CudaBuffer*>(ptr_)->data;
auto& cbuf = *static_cast<cu::CudaBuffer*>(ptr_);
if (!cbuf.managed) {
void* new_data;
CHECK_CUDA_ERROR(cudaMallocManaged(&new_data, cbuf.size));
cbuf.managed = true;
CHECK_CUDA_ERROR(
cudaMemcpy(new_data, cbuf.data, cbuf.size, cudaMemcpyDefault));
CHECK_CUDA_ERROR(cudaFree(cbuf.data));
cbuf.data = new_data;
}
return cbuf.data;
}
} // namespace allocator

View File

@@ -18,8 +18,14 @@ using allocator::Buffer;
struct CudaBuffer {
void* data;
size_t size;
bool managed;
};
template <typename T>
T* gpu_ptr(Buffer buf) {
return static_cast<T*>(static_cast<cu::CudaBuffer*>(buf.ptr())->data);
}
class SmallSizePool {
private:
union Block {

View File

@@ -57,7 +57,7 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
num_blocks,
block_dims,
0,
out.data<OutType>(),
gpu_ptr<OutType>(out),
out.data_size(),
static_cast<CTYPE>(start_),
static_cast<CTYPE>(start_ + step_) - static_cast<CTYPE>(start_));

View File

@@ -173,8 +173,8 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
num_blocks,
block_dim(),
0,
in.data<T>(),
out.data<uint32_t>(),
gpu_ptr<T>(in),
gpu_ptr<uint32_t>(out),
out.size(),
const_param(shape),
const_param(in_strides),

View File

@@ -292,9 +292,9 @@ void binary_op_gpu_inplace(
{num_blocks_x, num_blocks_y},
block_dims,
0,
a.data<InType>(),
b.data<InType>(),
out.data<OutType>(),
gpu_ptr<InType>(a),
gpu_ptr<InType>(b),
gpu_ptr<OutType>(out),
rest,
const_param<dims_constant()>(shape),
const_param<dims_constant()>(a_strides),
@@ -310,9 +310,9 @@ void binary_op_gpu_inplace(
{num_blocks_x, num_blocks_y},
block_dims,
0,
a.data<InType>(),
b.data<InType>(),
out.data<OutType>(),
gpu_ptr<InType>(a),
gpu_ptr<InType>(b),
gpu_ptr<OutType>(out),
rest,
const_param(shape),
const_param(a_strides),
@@ -339,9 +339,9 @@ void binary_op_gpu_inplace(
num_blocks,
block_dims,
0,
a.data<InType>(),
b.data<InType>(),
out.data<OutType>(),
gpu_ptr<InType>(a),
gpu_ptr<InType>(b),
gpu_ptr<OutType>(out),
out.data_size());
});
}
@@ -365,7 +365,11 @@ void binary_op_gpu(
auto& a = inputs[0];
auto& b = inputs[1];
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
auto& encoder = cu::get_command_encoder(s);
set_binary_op_output_data(a, b, out, bopt, [&](auto n) {
return cu::malloc_async(n, encoder.stream());
});
binary_op_gpu_inplace<Op>(inputs, out, op, s);
}

View File

@@ -245,14 +245,18 @@ void binary_two_op_gpu_inplace(
auto& out_a = outputs[0];
auto& out_b = outputs[1];
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out_a, bopt);
set_binary_op_output_data(a, b, out_b, bopt);
auto& encoder = cu::get_command_encoder(s);
set_binary_op_output_data(a, b, out_a, bopt, [&](auto n) {
return cu::malloc_async(n, encoder.stream());
});
set_binary_op_output_data(a, b, out_b, bopt, [&](auto n) {
return cu::malloc_async(n, encoder.stream());
});
if (out_a.size() == 0) {
return;
}
auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out_a);
@@ -313,10 +317,10 @@ void binary_two_op_gpu_inplace(
{num_blocks_x, num_blocks_y},
block_dims,
0,
a.data<InType>(),
b.data<InType>(),
out_a.data<OutType>(),
out_b.data<OutType>(),
gpu_ptr<InType>(a),
gpu_ptr<InType>(b),
gpu_ptr<OutType>(out_a),
gpu_ptr<OutType>(out_b),
rest,
const_param<dims_constant()>(shape),
const_param<dims_constant()>(a_strides),
@@ -332,10 +336,10 @@ void binary_two_op_gpu_inplace(
{num_blocks_x, num_blocks_y},
block_dims,
0,
a.data<InType>(),
b.data<InType>(),
out_a.data<OutType>(),
out_b.data<OutType>(),
gpu_ptr<InType>(a),
gpu_ptr<InType>(b),
gpu_ptr<OutType>(out_a),
gpu_ptr<OutType>(out_b),
rest,
const_param(shape),
const_param(a_strides),
@@ -366,10 +370,10 @@ void binary_two_op_gpu_inplace(
num_blocks,
block_dims,
0,
a.data<InType>(),
b.data<InType>(),
out_a.data<OutType>(),
out_b.data<OutType>(),
gpu_ptr<InType>(a),
gpu_ptr<InType>(b),
gpu_ptr<OutType>(out_a),
gpu_ptr<OutType>(out_b),
out_a.data_size());
});
}

View File

@@ -270,17 +270,16 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
if (out_.size() == 0) {
return;
}
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
assert(inputs.size() == 2);
array in = inputs[0];
array wt = inputs[1];
array out = out_;
out.set_data(allocator::malloc(out.nbytes()));
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
Dtype dtype = out.dtype();
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
// Search cache.
ConvCacheKey cache_key{
encoder.device().cuda_device(),

View File

@@ -86,7 +86,7 @@ array unfold_inputs_nd(
int mat_N,
ConvParams<NDIM>& params) {
array unfolded({mat_M, mat_K}, in.dtype(), nullptr, {});
unfolded.set_data(allocator::malloc(unfolded.nbytes()));
unfolded.set_data(cu::malloc_async(unfolded.nbytes(), encoder.stream()));
encoder.add_temporary(unfolded);
int filter_size = params.C;
@@ -118,8 +118,8 @@ array unfold_inputs_nd(
num_blocks,
block_dims,
0,
in.data<DataType>(),
unfolded.data<DataType>(),
gpu_ptr<DataType>(in),
gpu_ptr<DataType>(unfolded),
filter_size,
out_pixels,
params);

View File

@@ -89,7 +89,7 @@ array grouped_unfold_transpose_inputs_nd(
int mat_N,
ConvParams<NDIM>& params) {
array unfolded({mat_M, mat_K * params.groups}, in.dtype(), nullptr, {});
unfolded.set_data(allocator::malloc(unfolded.nbytes()));
unfolded.set_data(cu::malloc_async(unfolded.nbytes(), encoder.stream()));
encoder.add_temporary(unfolded);
int filter_size = params.C;
@@ -121,8 +121,8 @@ array grouped_unfold_transpose_inputs_nd(
num_blocks,
block_dims,
0,
in.data<DataType>(),
unfolded.data<DataType>(),
gpu_ptr<DataType>(in),
gpu_ptr<DataType>(unfolded),
filter_size,
out_pixels,
params);

View File

@@ -5,6 +5,22 @@
namespace mlx::core {
void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
auto& encoder = cu::get_command_encoder(s);
bool donated = set_copy_output_data(in, out, ctype, [&](auto n) {
return cu::malloc_async(n, encoder.stream());
});
if (donated && in.dtype() == out.dtype()) {
// If the output has the same type as the input then there is nothing to
// copy, just use the buffer.
return;
}
if (ctype == CopyType::GeneralGeneral) {
ctype = CopyType::General;
}
copy_gpu_inplace(in, out, ctype, s);
}
void copy_gpu_inplace(
const array& in,
array& out,

View File

@@ -77,8 +77,8 @@ void copy_contiguous(
num_blocks,
block_dims,
0,
in.data<InType>() + in_offset,
out.data<OutType>() + out_offset,
gpu_ptr<InType>(in) + in_offset,
gpu_ptr<OutType>(out) + out_offset,
out.data_size());
});
});

View File

@@ -106,8 +106,8 @@ void copy_general(
using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
using IdxT = std::conditional_t<large(), int64_t, int32_t>;
const InType* in_ptr = in.data<InType>() + offset_in;
OutType* out_ptr = out.data<OutType>() + offset_out;
const InType* in_ptr = gpu_ptr<InType>(in) + offset_in;
OutType* out_ptr = gpu_ptr<OutType>(out) + offset_out;
int ndim = shape.size();
size_t data_size = 1;
for (auto& s : shape)

View File

@@ -69,8 +69,8 @@ void copy_general_dynamic(
using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
using IdxT = std::conditional_t<large(), int64_t, int32_t>;
const InType* in_ptr = in.data<InType>() + offset_in;
OutType* out_ptr = out.data<OutType>() + offset_out;
const InType* in_ptr = gpu_ptr<InType>(in) + offset_in;
OutType* out_ptr = gpu_ptr<OutType>(out) + offset_out;
int ndim = shape.size();
if (ndim <= 3) {
dispatch_1_2_3(ndim, [&](auto dims_constant) {
@@ -90,8 +90,8 @@ void copy_general_dynamic(
const_param<dims_constant()>(shape),
const_param<dims_constant()>(strides_in),
const_param<dims_constant()>(strides_out),
dynamic_offset_in.data<int64_t>(),
dynamic_offset_out.data<int64_t>());
gpu_ptr<int64_t>(dynamic_offset_in),
gpu_ptr<int64_t>(dynamic_offset_out));
});
} else { // ndim >= 4
auto [num_blocks, block_dims] = get_launch_args(out, large());
@@ -107,8 +107,8 @@ void copy_general_dynamic(
const_param(strides_in),
const_param(strides_out),
ndim,
dynamic_offset_in.data<int64_t>(),
dynamic_offset_out.data<int64_t>());
gpu_ptr<int64_t>(dynamic_offset_in),
gpu_ptr<int64_t>(dynamic_offset_out));
}
});
});

View File

@@ -92,8 +92,8 @@ void copy_general_input(
using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
using IdxT = std::conditional_t<large(), int64_t, int32_t>;
const InType* in_ptr = in.data<InType>() + offset_in;
OutType* out_ptr = out.data<OutType>() + offset_out;
const InType* in_ptr = gpu_ptr<InType>(in) + offset_in;
OutType* out_ptr = gpu_ptr<OutType>(out) + offset_out;
int ndim = shape.size();
int work_per_thread = 1;
auto dim0 = ndim > 0 ? shape.back() : 1;

View File

@@ -133,13 +133,13 @@ bool prepare_cudnn_plan(
F&& execute) {
int workspace_size = plan.getWorkspaceSize();
array workspace(
workspace_size > 0 ? allocator::malloc(workspace_size)
workspace_size > 0 ? cu::malloc_async(workspace_size, encoder.stream())
: allocator::Buffer(nullptr),
{workspace_size},
uint8);
auto args = cudnn_frontend::VariantPackBuilder()
.setWorkspacePointer(workspace.data<void>())
.setWorkspacePointer(gpu_ptr<void>(workspace))
.setDataPointers(num_args, data_ptrs)
.setUids(num_args, uids)
.build();

View File

@@ -3,6 +3,7 @@
#pragma once
#include "mlx/array.h"
#include "mlx/backend/cuda/allocator.h"
#include "mlx/backend/cuda/device/config.h"
#include "mlx/backend/cuda/utils.h"
#include "mlx/dtype_utils.h"
@@ -23,7 +24,7 @@ class CommandEncoder;
// Return pointer alignment of |x|'s data.
inline uint8_t get_alignment(const array& x) {
uint8_t alignment = 1;
uintptr_t address = reinterpret_cast<uintptr_t>(x.data<void>());
uintptr_t address = reinterpret_cast<uintptr_t>(gpu_ptr<void>(x));
for (; alignment < 32; alignment *= 2) {
if (address % (alignment * 2)) {
return alignment;
@@ -56,7 +57,7 @@ inline std::array<T, MAX_NDIM> vector_key(const Vec<T>& vec) {
// Helpers used by get_data_ptrs to get pointers.
inline void* get_data_ptr(const array& arr) {
return const_cast<void*>(arr.data<void>());
return const_cast<void*>(gpu_ptr<void>(arr));
}
template <typename T, typename = std::enable_if_t<std::is_scalar_v<T>>>

View File

@@ -279,6 +279,7 @@ void CustomKernel::eval_gpu(
std::vector<array>& outputs) {
nvtx3::scoped_range r("CustomKernel::eval_gpu");
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
std::vector<array> copies;
@@ -288,7 +289,7 @@ void CustomKernel::eval_gpu(
copies.emplace_back(init_value_.value(), out.dtype());
fill_gpu(copies.back(), out, s);
} else {
out.set_data(allocator::malloc(out.nbytes()));
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
}
}
@@ -356,7 +357,6 @@ void CustomKernel::eval_gpu(
dim3 grid((gx + tx - 1) / tx, (gy + ty - 1) / ty, (gz + tz - 1) / tz);
// Call the kernel
auto& encoder = cu::get_command_encoder(s);
for (const auto& in : checked_inputs) {
encoder.set_input_array(in);
}

View File

@@ -15,8 +15,10 @@ void AllReduce::eval_gpu(
assert(inputs.size() == 1);
assert(outputs.size() == 1);
auto set_input_output =
[s = stream()](const array& in, array& out) -> std::pair<array, array> {
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
auto set_input_output = [&](const array& in,
array& out) -> std::pair<array, array> {
if (!in.flags().row_contiguous) {
copy_gpu(in, out, CopyType::General, s);
return {out, out};
@@ -24,19 +26,17 @@ void AllReduce::eval_gpu(
out.copy_shared_buffer(in);
return {in, out};
} else {
out.set_data(allocator::malloc(out.nbytes()));
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
return {in, out};
}
};
auto [input, output] = set_input_output(inputs[0], outputs[0]);
auto& encoder = cu::get_command_encoder(stream());
encoder.set_input_array(input);
encoder.set_output_array(output);
auto capture = encoder.capture_context();
auto& s = stream();
switch (reduce_type_) {
case Sum:

View File

@@ -241,7 +241,7 @@ void CublasGemm::set_bias(cu::CommandEncoder& encoder, const array& bias) {
CUBLASLT_MATMUL_DESC_EPILOGUE,
&epilogue,
sizeof(epilogue)));
auto* bias_ptr = bias.data<void>();
auto* bias_ptr = gpu_ptr<void>(bias);
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
matmul_desc_,
CUBLASLT_MATMUL_DESC_BIAS_POINTER,
@@ -278,9 +278,9 @@ void CublasGemm::run(
execute(
encoder,
out.data<void>(),
a.data<void>(),
b.data<void>(),
gpu_ptr<void>(out),
gpu_ptr<void>(a),
gpu_ptr<void>(b),
nullptr,
alpha);
}
@@ -321,10 +321,10 @@ void CublasGemm::run(
execute(
encoder,
out.data<void>(),
a.data<void>(),
b.data<void>(),
c.data<void>(),
gpu_ptr<void>(out),
gpu_ptr<void>(a),
gpu_ptr<void>(b),
gpu_ptr<void>(c),
alpha,
beta);
}
@@ -374,7 +374,7 @@ void CublasGemm::execute(
{static_cast<int>(heuristic_.workspaceSize)},
int8);
encoder.add_temporary(workspace);
workspace_ptr = workspace.data<void>();
workspace_ptr = gpu_ptr<void>(workspace);
}
auto capture = encoder.capture_context();

View File

@@ -25,9 +25,10 @@ void CublasGemm::run_batched(
for (size_t i = 0; i < nbatch; ++i) {
execute(
encoder,
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M_ * N_,
a.data<int8_t>() + a.itemsize() * a_it.loc,
b.data<int8_t>() + b.itemsize() * b_it.loc,
gpu_ptr<int8_t>(out) +
out.itemsize() * i * batch_shape.back() * M_ * N_,
gpu_ptr<int8_t>(a) + a.itemsize() * a_it.loc,
gpu_ptr<int8_t>(b) + b.itemsize() * b_it.loc,
nullptr,
alpha);
a_it.step();
@@ -60,10 +61,11 @@ void CublasGemm::run_batched(
for (size_t i = 0; i < nbatch; ++i) {
execute(
encoder,
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M_ * N_,
a.data<int8_t>() + a.itemsize() * a_it.loc,
b.data<int8_t>() + b.itemsize() * b_it.loc,
c.data<int8_t>() + c.itemsize() * c_it.loc,
gpu_ptr<int8_t>(out) +
out.itemsize() * i * batch_shape.back() * M_ * N_,
gpu_ptr<int8_t>(a) + a.itemsize() * a_it.loc,
gpu_ptr<int8_t>(b) + b.itemsize() * b_it.loc,
gpu_ptr<int8_t>(c) + c.itemsize() * c_it.loc,
alpha,
beta);
a_it.step();

View File

@@ -183,10 +183,10 @@ void CublasGemm::run_batched(
num_blocks,
block_dims,
0,
pointers.data<int8_t*>(),
a.data<int8_t>(),
b.data<int8_t>(),
out.data<int8_t>(),
gpu_ptr<int8_t*>(pointers),
gpu_ptr<int8_t>(a),
gpu_ptr<int8_t>(b),
gpu_ptr<int8_t>(out),
item_size,
const_param<ndim_constant()>(batch_shape),
const_param<ndim_constant()>(a_batch_strides),
@@ -200,10 +200,10 @@ void CublasGemm::run_batched(
num_blocks,
block_dims,
0,
pointers.data<int8_t*>(),
a.data<int8_t>(),
b.data<int8_t>(),
out.data<int8_t>(),
gpu_ptr<int8_t*>(pointers),
gpu_ptr<int8_t>(a),
gpu_ptr<int8_t>(b),
gpu_ptr<int8_t>(out),
item_size,
const_param(batch_shape),
const_param(a_batch_strides),
@@ -219,7 +219,7 @@ void CublasGemm::run_batched(
encoder.set_input_array(b);
encoder.set_output_array(out);
auto a_pointers = pointers.data<int8_t*>();
auto a_pointers = gpu_ptr<int8_t*>(pointers);
auto b_pointers = a_pointers + batch_count;
auto out_pointers = b_pointers + batch_count;
execute(
@@ -271,11 +271,11 @@ void CublasGemm::run_batched(
num_blocks,
block_dims,
0,
pointers.data<int8_t*>(),
a.data<int8_t>(),
b.data<int8_t>(),
c.data<int8_t>(),
out.data<int8_t>(),
gpu_ptr<int8_t*>(pointers),
gpu_ptr<int8_t>(a),
gpu_ptr<int8_t>(b),
gpu_ptr<int8_t>(c),
gpu_ptr<int8_t>(out),
item_size,
const_param<ndim_constant()>(batch_shape),
const_param<ndim_constant()>(a_batch_strides),
@@ -290,11 +290,11 @@ void CublasGemm::run_batched(
num_blocks,
block_dims,
0,
pointers.data<int8_t*>(),
a.data<int8_t>(),
b.data<int8_t>(),
c.data<int8_t>(),
out.data<int8_t>(),
gpu_ptr<int8_t*>(pointers),
gpu_ptr<int8_t>(a),
gpu_ptr<int8_t>(b),
gpu_ptr<int8_t>(c),
gpu_ptr<int8_t>(out),
item_size,
const_param(batch_shape),
const_param(a_batch_strides),
@@ -312,7 +312,7 @@ void CublasGemm::run_batched(
encoder.set_input_array(c);
encoder.set_output_array(out);
auto a_pointers = pointers.data<int8_t*>();
auto a_pointers = gpu_ptr<int8_t*>(pointers);
auto b_pointers = a_pointers + batch_count;
auto c_pointers = b_pointers + batch_count;
auto out_pointers = c_pointers + batch_count;

View File

@@ -149,13 +149,13 @@ void gemv(
auto vec_strides = const_param(b_batch_strides);
if (M == 1) {
mat = b.data<DataType>();
vec = a.data<DataType>();
mat = gpu_ptr<DataType>(b);
vec = gpu_ptr<DataType>(a);
rows = N;
std::swap(mat_strides, vec_strides);
} else {
mat = a.data<DataType>();
vec = b.data<DataType>();
mat = gpu_ptr<DataType>(a);
vec = gpu_ptr<DataType>(b);
rows = M;
}
uint32_t num_blocks_x = (rows + rows_per_block - 1) / rows_per_block;
@@ -177,7 +177,7 @@ void gemv(
0,
mat,
vec,
out.data<DataType>(),
gpu_ptr<DataType>(out),
rows,
cols);
} else {
@@ -189,7 +189,7 @@ void gemv(
0,
mat,
vec,
out.data<DataType>(),
gpu_ptr<DataType>(out),
rows,
cols,
const_param(batch_shape),

View File

@@ -31,7 +31,7 @@ void append_indices_arg(
int idx_ndim) {
SmallVector<const void*> indices(nidx);
for (int i = 0; i < nidx; ++i) {
indices[i] = inputs[i + 1].data<void>();
indices[i] = gpu_ptr<void>(inputs[i + 1]);
}
args.append(std::move(indices));
SmallVector<int32_t> indices_shape(nidx * idx_ndim);

View File

@@ -31,7 +31,7 @@ struct KernelArgs {
}
void append(const array& a) {
append(reinterpret_cast<CUdeviceptr>(a.data<void>()));
append(reinterpret_cast<CUdeviceptr>(gpu_ptr<void>(a.buffer())));
}
template <typename T>

View File

@@ -9,6 +9,7 @@
#include <type_traits>
#include "mlx/array.h"
#include "mlx/backend/cuda/allocator.h"
#include "mlx/backend/cuda/device/utils.cuh"
#include <cuda.h>

View File

@@ -280,10 +280,10 @@ void LayerNorm::eval_gpu(
n_rows,
block_dim(),
0,
x.data<DataType>(),
w.data<DataType>(),
b.data<DataType>(),
out.data<DataType>(),
gpu_ptr<DataType>(x),
gpu_ptr<DataType>(w),
gpu_ptr<DataType>(b),
gpu_ptr<DataType>(out),
eps_,
axis_size,
w_stride,
@@ -393,11 +393,11 @@ void LayerNormVJP::eval_gpu(
n_rows,
block_dim(),
0,
x.data<DataType>(),
w.data<DataType>(),
g.data<DataType>(),
gx.data<DataType>(),
gw_temp.data<DataType>(),
gpu_ptr<DataType>(x),
gpu_ptr<DataType>(w),
gpu_ptr<DataType>(g),
gpu_ptr<DataType>(gx),
gpu_ptr<DataType>(gw_temp),
eps_,
axis_size,
w_stride);

View File

@@ -151,8 +151,8 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
n_rows,
block_dim(),
0,
in.data<DataType>(),
out.data<DataType>(),
gpu_ptr<DataType>(in),
gpu_ptr<DataType>(out),
axis_size);
});
});

View File

@@ -262,10 +262,10 @@ void affine_quantize(
num_blocks,
block_dims,
0,
w.data<T>(),
wq.data<uint8_t>(),
scales.data<T>(),
biases.data<T>(),
gpu_ptr<T>(w),
gpu_ptr<uint8_t>(wq),
gpu_ptr<T>(scales),
gpu_ptr<T>(biases),
w.size());
});
});
@@ -318,10 +318,10 @@ void affine_dequantize(
num_blocks,
block_dims,
0,
wq.data<uint8_t>(),
scales.data<T>(),
biases.data<T>(),
w.data<T>(),
gpu_ptr<uint8_t>(wq),
gpu_ptr<T>(scales),
gpu_ptr<T>(biases),
gpu_ptr<T>(w),
w.size());
});
});

View File

@@ -156,9 +156,9 @@ void fp_quantize(
num_blocks,
block_dims,
0,
w.data<T>(),
wq.data<uint8_t>(),
scales.data<uint8_t>(),
gpu_ptr<T>(w),
gpu_ptr<uint8_t>(wq),
gpu_ptr<uint8_t>(scales),
w.size());
} else {
throw std::runtime_error(
@@ -202,9 +202,9 @@ void fp_dequantize(
num_blocks,
block_dims,
0,
wq.data<uint8_t>(),
scales.data<T>(),
w.data<T>(),
gpu_ptr<uint8_t>(wq),
gpu_ptr<uint8_t>(scales),
gpu_ptr<T>(w),
w.size());
} else {
throw std::runtime_error(

View File

@@ -59,7 +59,7 @@ void fast::Quantize::eval_gpu(
auto scales = ensure_row_contiguous(inputs[1], enc, s);
auto& w = outputs[0];
w.set_data(allocator::malloc(w.nbytes()));
w.set_data(cu::malloc_async(w.nbytes(), enc.stream()));
if (mode_ == QuantizationMode::Affine) {
auto biases = ensure_row_contiguous(inputs[2], enc, s);
@@ -72,8 +72,8 @@ void fast::Quantize::eval_gpu(
auto& wq = outputs[0];
auto& scales = outputs[1];
wq.set_data(allocator::malloc(wq.nbytes()));
scales.set_data(allocator::malloc(scales.nbytes()));
wq.set_data(cu::malloc_async(wq.nbytes(), enc.stream()));
scales.set_data(cu::malloc_async(scales.nbytes(), enc.stream()));
if (mode_ == QuantizationMode::Affine) {
auto& biases = outputs[2];
biases.set_data(allocator::malloc(biases.nbytes()));

View File

@@ -143,7 +143,9 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
uint32_t elems_per_key = out.size() / num_keys;
uint32_t bytes_per_key = out.itemsize() * elems_per_key;
out.set_data(allocator::malloc(out.nbytes()));
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
if (out.size() == 0) {
return;
}
@@ -152,8 +154,6 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
uint32_t half_size = out_per_key / 2;
bool odd = out_per_key % 2;
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(keys);
encoder.set_output_array(out);
dim3 grid_dims{num_keys, half_size + odd};
@@ -171,8 +171,8 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
grid,
block,
0,
keys.data<uint32_t>(),
out.data<uint8_t>(),
gpu_ptr<uint32_t>(keys),
gpu_ptr<uint8_t>(out),
grid_dims,
odd,
bytes_per_key);
@@ -182,8 +182,8 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
grid,
block,
0,
keys.data<uint32_t>(),
out.data<uint8_t>(),
gpu_ptr<uint32_t>(keys),
gpu_ptr<uint8_t>(out),
grid_dims,
odd,
bytes_per_key,

View File

@@ -66,7 +66,7 @@ void all_reduce(
Reduce::ReduceType reduce_type) {
constexpr int N_READS = 8;
out.set_data(allocator::malloc(out.nbytes()));
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
auto get_args = [](size_t size, int N) {
int threads = std::min(512UL, (size + N - 1) / N);
@@ -100,14 +100,15 @@ void all_reduce(
Dtype dt = in.dtype();
// Cub doesn't like const pointers for load (sigh).
void* indata = const_cast<void*>(in.data<void>());
void* indata = const_cast<void*>(gpu_ptr<void>(in));
// Large array so allocate an intermediate and accumulate there
std::tie(blocks, threads, block_step) = get_args(insize, N_READS);
encoder.set_input_array(in);
if (blocks > 1) {
array intermediate({blocks}, out.dtype(), nullptr, {});
intermediate.set_data(allocator::malloc(intermediate.nbytes()));
intermediate.set_data(
cu::malloc_async(intermediate.nbytes(), encoder.stream()));
encoder.add_temporary(intermediate);
encoder.set_output_array(intermediate);
dispatch_all_types(dt, [&](auto type_tag) {
@@ -122,14 +123,14 @@ void all_reduce(
threads,
0,
static_cast<T*>(indata),
intermediate.data<U>(),
gpu_ptr<U>(intermediate),
block_step,
insize);
});
});
// Set the input for the next step and recalculate the blocks
indata = intermediate.data<void>();
indata = gpu_ptr<void>(intermediate);
dt = intermediate.dtype();
insize = intermediate.size();
std::tie(blocks, threads, block_step) = get_args(insize, N_READS);
@@ -149,7 +150,7 @@ void all_reduce(
threads,
0,
static_cast<T*>(indata),
out.data<U>(),
gpu_ptr<U>(out),
block_step,
insize);
});

View File

@@ -250,7 +250,7 @@ void col_reduce_looped(
const cu::ColReduceArgs& args) {
// Allocate data for the output using in's layout to access them as
// contiguously as possible.
allocate_same_layout(out, in, axes);
allocate_same_layout(out, in, axes, encoder);
encoder.set_input_array(in);
encoder.set_output_array(out);
@@ -261,7 +261,7 @@ void col_reduce_looped(
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
using U = typename cu::ReduceResult<OP, T>::type;
// Cub doesn't like const pointers for vectorized loads. (sigh)
T* indata = const_cast<T*>(in.data<T>());
T* indata = const_cast<T*>(gpu_ptr<T>(in));
constexpr int N_READS = 4;
constexpr int BM = 32;
@@ -276,7 +276,7 @@ void col_reduce_looped(
blocks,
0,
indata,
out.data<U>(),
gpu_ptr<U>(out),
static_cast<cu::ColReduceArgs>(args));
});
});
@@ -293,7 +293,7 @@ void col_reduce_small(
const cu::ColReduceArgs& args) {
// Allocate data for the output using in's layout to access them as
// contiguously as possible.
allocate_same_layout(out, in, axes);
allocate_same_layout(out, in, axes, encoder);
encoder.set_input_array(in);
encoder.set_output_array(out);
@@ -312,8 +312,8 @@ void col_reduce_small(
grid,
block,
0,
in.data<T>(),
out.data<U>(),
gpu_ptr<T>(in),
gpu_ptr<U>(out),
static_cast<cu::ColReduceArgs>(args),
out.size());
});

View File

@@ -28,7 +28,7 @@ void init_reduce(
Reduce::ReduceType reduce_type) {
// Allocate if needed
if (out.data_shared_ptr() == nullptr) {
out.set_data(allocator::malloc(out.nbytes()));
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
}
encoder.set_output_array(out);
@@ -42,7 +42,7 @@ void init_reduce(
dim3 block(grid.x < 1024 ? grid.x : 1024, 1, 1);
grid.x = (grid.x + 1023) / 1024;
encoder.add_kernel_node(
kernel, grid, block, 0, out.data<U>(), out.size());
kernel, grid, block, 0, gpu_ptr<U>(out), out.size());
});
});
}

View File

@@ -5,6 +5,7 @@
#include <numeric>
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/utils.cuh"
#include <cooperative_groups.h>
@@ -92,9 +93,10 @@ block_reduce(Block block, Warp warp, T (&vals)[N], T* smem, Op op, T init) {
inline void allocate_same_layout(
array& out,
const array& in,
const std::vector<int>& axes) {
const std::vector<int>& axes,
cu::CommandEncoder& encoder) {
if (in.flags().row_contiguous) {
out.set_data(allocator::malloc(out.nbytes()));
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
return;
}
@@ -133,7 +135,7 @@ inline void allocate_same_layout(
fl.col_contiguous = cc;
fl.contiguous = true;
out.set_data(
allocator::malloc(out.nbytes()),
cu::malloc_async(out.nbytes(), encoder.stream()),
data_size,
final_strides,
fl,

View File

@@ -238,7 +238,7 @@ void row_reduce_simple(
const ReductionPlan& plan) {
// Allocate data for the output using in's layout to avoid elem_to_loc in the
// kernel.
allocate_same_layout(out, in, axes);
allocate_same_layout(out, in, axes, encoder);
// TODO: If out.size() < 1024 which will be a common case then write this in
// 2 passes. Something like 32 * out.size() and then do a warp reduce.
@@ -268,10 +268,10 @@ void row_reduce_simple(
kernel = cu::row_reduce_simple<T, U, OP, N_READS, 2>;
}
T* indata = const_cast<T*>(in.data<T>());
T* indata = const_cast<T*>(gpu_ptr<T>(in));
int size = plan.shape.back();
encoder.add_kernel_node(
kernel, grid, block, 0, indata, out.data<U>(), out.size(), size);
kernel, grid, block, 0, indata, gpu_ptr<U>(out), out.size(), size);
});
});
}
@@ -286,7 +286,7 @@ void row_reduce_looped(
cu::RowReduceArgs args) {
// Allocate data for the output using in's layout to access them as
// contiguously as possible.
allocate_same_layout(out, in, axes);
allocate_same_layout(out, in, axes, encoder);
encoder.set_input_array(in);
encoder.set_output_array(out);
@@ -315,7 +315,7 @@ void row_reduce_looped(
});
encoder.add_kernel_node(
kernel, grid, block, 0, in.data<T>(), out.data<U>(), args);
kernel, grid, block, 0, gpu_ptr<T>(in), gpu_ptr<U>(out), args);
});
});
}

View File

@@ -223,9 +223,9 @@ void RMSNorm::eval_gpu(
n_rows,
block_dim(),
0,
x.data<DataType>(),
w.data<DataType>(),
out.data<DataType>(),
gpu_ptr<DataType>(x),
gpu_ptr<DataType>(w),
gpu_ptr<DataType>(out),
eps_,
axis_size,
w_stride);
@@ -318,11 +318,11 @@ void RMSNormVJP::eval_gpu(
n_rows,
block_dim(),
0,
x.data<DataType>(),
w.data<DataType>(),
g.data<DataType>(),
gx.data<DataType>(),
gw_temp.data<DataType>(),
gpu_ptr<DataType>(x),
gpu_ptr<DataType>(w),
gpu_ptr<DataType>(g),
gpu_ptr<DataType>(gx),
gpu_ptr<DataType>(gw_temp),
eps_,
axis_size,
w_stride);

View File

@@ -340,9 +340,9 @@ void RoPE::eval_gpu(
grid,
block,
0,
(donated ? out : in).data<DataType>(),
out.data<DataType>(),
offset.data<int32_t>(),
gpu_ptr<DataType>(donated ? out : in),
gpu_ptr<DataType>(out),
gpu_ptr<int32_t>(offset),
scale_,
std::log2(base_),
mat_size,
@@ -357,10 +357,10 @@ void RoPE::eval_gpu(
grid,
block,
0,
(donated ? out : in).data<DataType>(),
out.data<DataType>(),
offset.data<int32_t>(),
inputs[2].data<float>(),
gpu_ptr<DataType>(donated ? out : in),
gpu_ptr<DataType>(out),
gpu_ptr<int32_t>(offset),
gpu_ptr<float>(inputs[2]),
scale_,
mat_size,
dims,
@@ -381,10 +381,10 @@ void RoPE::eval_gpu(
grid,
block,
0,
(donated ? out : in).data<DataType>(),
out.data<DataType>(),
offset.data<int32_t>(),
inputs[2].data<float>(),
gpu_ptr<DataType>(donated ? out : in),
gpu_ptr<DataType>(out),
gpu_ptr<int32_t>(offset),
gpu_ptr<float>(inputs[2]),
scale_,
std::log2(base_),
strides,
@@ -408,9 +408,9 @@ void RoPE::eval_gpu(
grid,
block,
0,
(donated ? out : in).data<DataType>(),
out.data<DataType>(),
offset.data<int32_t>(),
gpu_ptr<DataType>(donated ? out : in),
gpu_ptr<DataType>(out),
gpu_ptr<int32_t>(offset),
scale_,
std::log2(base_),
strides,

View File

@@ -513,11 +513,11 @@ void sdpa_vector_1pass_fallback(
grid_dim,
block_dim,
0,
q.data<DataType>(),
k.data<DataType>(),
v.data<DataType>(),
o.data<DataType>(),
sinks ? (*sinks).data<DataType>() : nullptr,
gpu_ptr<DataType>(q),
gpu_ptr<DataType>(k),
gpu_ptr<DataType>(v),
gpu_ptr<DataType>(o),
sinks ? gpu_ptr<DataType>(*sinks) : nullptr,
params);
});
});
@@ -602,13 +602,13 @@ void sdpa_vector_2pass_fallback(
grid_dim,
block_dim,
0,
q.data<DataType>(),
k.data<DataType>(),
v.data<DataType>(),
sinks ? (*sinks).data<DataType>() : nullptr,
intermediate.data<float>(),
sums.data<float>(),
maxs.data<float>(),
gpu_ptr<DataType>(q),
gpu_ptr<DataType>(k),
gpu_ptr<DataType>(v),
sinks ? gpu_ptr<DataType>(*sinks) : nullptr,
gpu_ptr<float>(intermediate),
gpu_ptr<float>(sums),
gpu_ptr<float>(maxs),
params);
}
@@ -629,10 +629,10 @@ void sdpa_vector_2pass_fallback(
grid_dim,
block_dim,
0,
intermediate.data<float>(),
sums.data<float>(),
maxs.data<float>(),
o.data<DataType>(),
gpu_ptr<float>(intermediate),
gpu_ptr<float>(sums),
gpu_ptr<float>(maxs),
gpu_ptr<DataType>(o),
params);
}
});

View File

@@ -415,8 +415,8 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
in.data_size() / axis_size,
block_dim,
0,
in.data<T>(),
out.data<U>(),
gpu_ptr<T>(in),
gpu_ptr<U>(out),
axis_size);
} else {
constexpr int BM = WARP_SIZE;
@@ -445,8 +445,8 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
num_blocks,
block_dim,
0,
in.data<T>(),
out.data<U>(),
gpu_ptr<T>(in),
gpu_ptr<U>(out),
axis_size,
stride,
stride_blocks);

View File

@@ -152,8 +152,8 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
n_rows,
block_dim(),
0,
in.data<DataType>(),
out.data<DataType>(),
gpu_ptr<DataType>(in),
gpu_ptr<DataType>(out),
axis_size);
});
});

View File

@@ -91,10 +91,10 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
CHECK_CUDA_ERROR(cub::DeviceSegmentedRadixSort::SortPairs(
nullptr,
size,
in.data<Type>(),
discard.data<Type>(),
indices.data<uint32_t>(),
out.data<uint32_t>(),
gpu_ptr<Type>(in),
gpu_ptr<Type>(discard),
gpu_ptr<uint32_t>(indices),
gpu_ptr<uint32_t>(out),
in.data_size(),
in.data_size() / nsort,
offsets,
@@ -115,16 +115,16 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
cu::thrust_policy(stream),
thrust::counting_iterator<uint32_t>(0),
thrust::counting_iterator<uint32_t>(indices.data_size()),
thrust::device_pointer_cast(indices.data<uint32_t>()),
thrust::device_pointer_cast(gpu_ptr<uint32_t>(indices)),
ModOp<uint32_t>{static_cast<uint32_t>(nsort)});
CHECK_CUDA_ERROR(cub::DeviceSegmentedRadixSort::SortPairs(
temp.data<void>(),
gpu_ptr<void>(temp),
size,
in.data<Type>(),
discard.data<Type>(),
indices.data<uint32_t>(),
out.data<uint32_t>(),
gpu_ptr<Type>(in),
gpu_ptr<Type>(discard),
gpu_ptr<uint32_t>(indices),
gpu_ptr<uint32_t>(out),
in.data_size(),
in.data_size() / nsort,
offsets,
@@ -137,8 +137,8 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
CHECK_CUDA_ERROR(cub::DeviceSegmentedRadixSort::SortKeys(
nullptr,
size,
in.data<Type>(),
out.data<Type>(),
gpu_ptr<Type>(in),
gpu_ptr<Type>(out),
in.data_size(),
in.data_size() / nsort,
offsets,
@@ -156,10 +156,10 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
// Start capturing after allocations
auto capture = encoder.capture_context();
CHECK_CUDA_ERROR(cub::DeviceSegmentedRadixSort::SortKeys(
temp.data<void>(),
gpu_ptr<void>(temp),
size,
in.data<Type>(),
out.data<Type>(),
gpu_ptr<Type>(in),
gpu_ptr<Type>(out),
in.data_size(),
in.data_size() / nsort,
offsets,

View File

@@ -168,10 +168,10 @@ void ternary_op_gpu_inplace(
num_blocks,
block_dims,
0,
a.data<bool>(),
b.data<DType>(),
c.data<DType>(),
out.data<DType>(),
gpu_ptr<bool>(a),
gpu_ptr<DType>(b),
gpu_ptr<DType>(c),
gpu_ptr<DType>(out),
out.data_size());
});
} else {
@@ -211,10 +211,10 @@ void ternary_op_gpu_inplace(
{num_blocks_x, num_blocks_y},
block_dims,
0,
a.data<bool>(),
b.data<DType>(),
c.data<DType>(),
out.data<DType>(),
gpu_ptr<bool>(a),
gpu_ptr<DType>(b),
gpu_ptr<DType>(c),
gpu_ptr<DType>(out),
rest,
const_param<dims_constant()>(shape),
const_param<dims_constant()>(a_strides),
@@ -231,10 +231,10 @@ void ternary_op_gpu_inplace(
{num_blocks_x, num_blocks_y},
block_dims,
0,
a.data<bool>(),
b.data<DType>(),
c.data<DType>(),
out.data<DType>(),
gpu_ptr<bool>(a),
gpu_ptr<DType>(b),
gpu_ptr<DType>(c),
gpu_ptr<DType>(out),
rest,
const_param(shape),
const_param(a_strides),
@@ -256,7 +256,10 @@ void ternary_op_gpu(
auto& b = inputs[1];
auto& c = inputs[2];
auto topt = get_ternary_op_type(a, b, c);
set_ternary_op_output_data(a, b, c, out, topt);
auto& encoder = cu::get_command_encoder(s);
set_ternary_op_output_data(a, b, c, out, topt, [&](auto n) {
return cu::malloc_async(n, encoder.stream());
});
ternary_op_gpu_inplace<Op>(inputs, out, s);
}

View File

@@ -158,8 +158,8 @@ void unary_op_gpu_inplace(
num_blocks,
block_dims,
0,
in.data<InType>(),
out.data<OutType>(),
gpu_ptr<InType>(in),
gpu_ptr<OutType>(out),
out.data_size());
} else {
using IdxT = std::conditional_t<large(), int64_t, int32_t>;
@@ -182,8 +182,8 @@ void unary_op_gpu_inplace(
{num_blocks_x, num_blocks_y},
block_dims,
0,
in.data<InType>(),
out.data<OutType>(),
gpu_ptr<InType>(in),
gpu_ptr<OutType>(out),
rest,
const_param(shape),
const_param(strides),
@@ -207,7 +207,10 @@ void unary_op_gpu(
array& out,
const char* op,
const Stream& s) {
set_unary_output_data(inputs[0], out);
auto& encoder = cu::get_command_encoder(s);
set_unary_output_data(inputs[0], out, [&](auto n) {
return cu::malloc_async(n, encoder.stream());
});
unary_op_gpu_inplace<Op>(inputs, out, op, s);
}

View File

@@ -7,6 +7,8 @@
#include <cublasLt.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include "mlx/array.h"
#include "mlx/backend/cuda/allocator.h"
namespace mlx::core {
@@ -15,6 +17,16 @@ class Device;
}
template <typename T>
T* gpu_ptr(array& arr) {
return cu::gpu_ptr<T>(arr.buffer());
}
template <typename T>
const T* gpu_ptr(const array& arr) {
return cu::gpu_ptr<T>(arr.buffer());
}
struct Dtype;
// Throw exception if the cuda API does not succeed.

View File

@@ -7,18 +7,7 @@
namespace mlx::core {
void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
bool donated = set_copy_output_data(in, out, ctype);
if (donated && in.dtype() == out.dtype()) {
// If the output has the same type as the input then there is nothing to
// copy, just use the buffer.
return;
}
if (ctype == CopyType::GeneralGeneral) {
ctype = CopyType::General;
}
copy_gpu_inplace(in, out, ctype, s);
}
void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s);
void copy_gpu(const array& in, array& out, CopyType ctype) {
copy_gpu(in, out, ctype, out.primitive().stream());

View File

@@ -10,6 +10,19 @@ namespace mlx::core {
constexpr int MAX_COPY_SPECIALIZED_DIMS = 3;
void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
bool donated = set_copy_output_data(in, out, ctype);
if (donated && in.dtype() == out.dtype()) {
// If the output has the same type as the input then there is nothing to
// copy, just use the buffer.
return;
}
if (ctype == CopyType::GeneralGeneral) {
ctype = CopyType::General;
}
copy_gpu_inplace(in, out, ctype, s);
}
void copy_gpu_inplace(
const array& in,
array& out,