fix malloc or wait deadlock (#1976)

This commit is contained in:
Awni Hannun 2025-03-20 16:48:43 -07:00 committed by GitHub
parent 1177d28395
commit 7b7e2352cd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
55 changed files with 201 additions and 217 deletions

View File

@ -247,9 +247,7 @@ point-wise. This is captured in the templated function :meth:`axpby_impl`.
float alpha_,
float beta_,
mx::Stream stream) {
// Allocate the output with `malloc_or_wait` which synchronously allocates
// memory, potentially waiting if the system is under memory pressure
out.set_data(mx::allocator::malloc_or_wait(out.nbytes()));
out.set_data(mx::allocator::malloc(out.nbytes()));
// Get the CPU command encoder and register input and output arrays
auto& encoder = mx::cpu::get_command_encoder(stream);
@ -393,7 +391,7 @@ below.
auto& d = metal::device(s.device);
// Allocate output memory
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
// Resolve name of kernel
std::ostringstream kname;

View File

@ -72,9 +72,7 @@ void axpby_impl(
float alpha_,
float beta_,
mx::Stream stream) {
// Allocate the output with `malloc_or_wait` which synchronously allocates
// memory, potentially waiting if the system is under memory pressure
out.set_data(mx::allocator::malloc_or_wait(out.nbytes()));
out.set_data(mx::allocator::malloc(out.nbytes()));
// Get the CPU command encoder and register input and output arrays
auto& encoder = mx::cpu::get_command_encoder(stream);
@ -160,12 +158,12 @@ void Axpby::eval_gpu(
// Allocate output memory with strides based on specialization
if (contiguous_kernel) {
out.set_data(
mx::allocator::malloc_or_wait(x.data_size() * out.itemsize()),
mx::allocator::malloc(x.data_size() * out.itemsize()),
x.data_size(),
x.strides(),
x.flags());
} else {
out.set_data(mx::allocator::malloc_or_wait(out.nbytes()));
out.set_data(mx::allocator::malloc(out.nbytes()));
}
// Resolve name of kernel (corresponds to axpby.metal)

View File

@ -9,7 +9,7 @@
namespace mlx::core::allocator {
Buffer malloc(size_t size) {
auto buffer = allocator().malloc(size, /* allow_swap */ true);
auto buffer = allocator().malloc(size);
if (size && !buffer.ptr()) {
std::ostringstream msg;
msg << "[malloc] Unable to allocate " << size << " bytes.";
@ -22,7 +22,7 @@ void free(Buffer buffer) {
allocator().free(buffer);
}
Buffer CommonAllocator::malloc(size_t size, bool) {
Buffer CommonAllocator::malloc(size_t size) {
void* ptr = std::malloc(size + sizeof(size_t));
if (ptr != nullptr) {
*static_cast<size_t*>(ptr) = size;
@ -41,26 +41,4 @@ size_t CommonAllocator::size(Buffer buffer) const {
return *static_cast<size_t*>(buffer.ptr());
}
Buffer malloc_or_wait(size_t size) {
auto buffer = allocator().malloc(size);
while (size && !buffer.ptr() && scheduler::n_active_tasks() > 0) {
scheduler::wait_for_one();
buffer = allocator().malloc(size);
}
// Try swapping if needed
if (size && !buffer.ptr()) {
buffer = allocator().malloc(size, /* allow_swap = */ true);
}
if (size && !buffer.ptr()) {
std::ostringstream msg;
msg << "[malloc_or_wait] Unable to allocate " << size << " bytes.";
throw std::runtime_error(msg.str());
}
return buffer;
}
} // namespace mlx::core::allocator

View File

@ -32,14 +32,10 @@ Buffer malloc(size_t size);
void free(Buffer buffer);
// Wait for running tasks to finish and free up memory
// if allocation fails
Buffer malloc_or_wait(size_t size);
class Allocator {
/** Abstract base class for a memory allocator. */
public:
virtual Buffer malloc(size_t size, bool allow_swap = false) = 0;
virtual Buffer malloc(size_t size) = 0;
virtual void free(Buffer buffer) = 0;
virtual size_t size(Buffer buffer) const = 0;
@ -56,7 +52,7 @@ Allocator& allocator();
class CommonAllocator : public Allocator {
/** A general CPU allocator. */
public:
virtual Buffer malloc(size_t size, bool allow_swap = false) override;
virtual Buffer malloc(size_t size) override;
virtual void free(Buffer buffer) override;
virtual size_t size(Buffer buffer) const override;

View File

@ -44,14 +44,14 @@ inline void set_binary_op_output_data(
switch (bopt) {
case BinaryOpType::ScalarScalar:
out.set_data(
allocator::malloc_or_wait(out.itemsize()), 1, a.strides(), a.flags());
allocator::malloc(out.itemsize()), 1, a.strides(), a.flags());
break;
case BinaryOpType::ScalarVector:
if (b_donatable) {
out.copy_shared_buffer(b);
} else {
out.set_data(
allocator::malloc_or_wait(b.data_size() * out.itemsize()),
allocator::malloc(b.data_size() * out.itemsize()),
b.data_size(),
b.strides(),
b.flags());
@ -62,7 +62,7 @@ inline void set_binary_op_output_data(
out.copy_shared_buffer(a);
} else {
out.set_data(
allocator::malloc_or_wait(a.data_size() * out.itemsize()),
allocator::malloc(a.data_size() * out.itemsize()),
a.data_size(),
a.strides(),
a.flags());
@ -75,7 +75,7 @@ inline void set_binary_op_output_data(
out.copy_shared_buffer(b);
} else {
out.set_data(
allocator::malloc_or_wait(a.data_size() * out.itemsize()),
allocator::malloc(a.data_size() * out.itemsize()),
a.data_size(),
a.strides(),
a.flags());
@ -88,7 +88,7 @@ inline void set_binary_op_output_data(
b_donatable && b.flags().row_contiguous && b.size() == out.size()) {
out.copy_shared_buffer(b);
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
}
break;
}

View File

@ -103,7 +103,7 @@ void ExpandDims::eval(const std::vector<array>& inputs, array& out) {
void NumberOfElements::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
double numel = 1;
for (auto ax : axes_) {

View File

@ -188,7 +188,7 @@ void compiled_allocate_outputs(
}
for (; o < outputs.size(); ++o) {
outputs[o].set_data(
allocator::malloc_or_wait(data_size * outputs[o].itemsize()),
allocator::malloc(data_size * outputs[o].itemsize()),
data_size,
strides,
flags);
@ -211,7 +211,7 @@ void compiled_allocate_outputs(
}
}
for (; o < outputs.size(); ++o) {
outputs[o].set_data(allocator::malloc_or_wait(outputs[o].nbytes()));
outputs[o].set_data(allocator::malloc(outputs[o].nbytes()));
}
}
}

View File

@ -31,14 +31,14 @@ inline bool set_copy_output_data(const array& in, array& out, CopyType ctype) {
return true;
} else {
out.set_data(
allocator::malloc_or_wait(in.data_size() * out.itemsize()),
allocator::malloc(in.data_size() * out.itemsize()),
in.data_size(),
in.strides(),
in.flags());
return false;
}
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
return false;
}
}

View File

@ -28,7 +28,7 @@ void swap_endianness(uint8_t* data_bytes, size_t N) {
namespace mlx::core {
void Load::eval_cpu(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
auto read_task = [out_ptr = out.data<char>(),
size = out.size(),
itemsize = out.itemsize(),

View File

@ -48,12 +48,12 @@ inline void set_ternary_op_output_data(
switch (topt) {
case TernaryOpType::ScalarScalarScalar:
out.set_data(
allocator::malloc_or_wait(out.itemsize()), 1, b.strides(), b.flags());
allocator::malloc(out.itemsize()), 1, b.strides(), b.flags());
break;
case TernaryOpType::VectorVectorVector:
if (!(maybe_donate(a) || maybe_donate(b) || maybe_donate(c))) {
out.set_data(
allocator::malloc_or_wait(out.itemsize() * b.data_size()),
allocator::malloc(out.itemsize() * b.data_size()),
b.data_size(),
b.strides(),
b.flags());
@ -64,7 +64,7 @@ inline void set_ternary_op_output_data(
if (!((a.flags().row_contiguous && maybe_donate(a)) ||
(b.flags().row_contiguous && maybe_donate(b)) ||
(c.flags().row_contiguous && maybe_donate(c)))) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
}
break;
}

View File

@ -68,7 +68,7 @@ void arg_reduce_dispatch(
void ArgReduce::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_input_array(in);
encoder.set_output_array(out);

View File

@ -921,7 +921,7 @@ void explicit_gemm_conv_1D_cpu(
if (out.dtype() != float32) {
gemm_out = array(out.shape(), float32, nullptr, {});
gemm_out.set_data(allocator::malloc_or_wait(gemm_out.nbytes()));
gemm_out.set_data(allocator::malloc(gemm_out.nbytes()));
temps.push_back(gemm_out);
}
@ -1048,7 +1048,7 @@ void explicit_gemm_conv_2D_cpu(
if (out.dtype() != float32) {
gemm_out = array(out.shape(), float32, nullptr, {});
gemm_out.set_data(allocator::malloc_or_wait(gemm_out.nbytes()));
gemm_out.set_data(allocator::malloc(gemm_out.nbytes()));
temps.push_back(gemm_out);
}
@ -1214,7 +1214,7 @@ void explicit_gemm_conv_ND_cpu(
if (out.dtype() != float32) {
gemm_out = array(out.shape(), float32, nullptr, {});
gemm_out.set_data(allocator::malloc_or_wait(gemm_out.nbytes()));
gemm_out.set_data(allocator::malloc(gemm_out.nbytes()));
temps.push_back(gemm_out);
}
@ -1327,7 +1327,7 @@ void conv_3D_cpu(
} // namespace
void Convolution::eval_cpu(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
auto& in = inputs[0];
auto& wt = inputs[1];

View File

@ -30,7 +30,7 @@ void AllReduce::eval_cpu(
if (in.is_donatable()) {
out.copy_shared_buffer(in);
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
}
return in;
} else {
@ -58,7 +58,7 @@ void AllGather::eval_cpu(
assert(outputs.size() == 1);
auto [in, copied] = ensure_row_contiguous(inputs[0], stream());
outputs[0].set_data(allocator::malloc_or_wait(outputs[0].nbytes()));
outputs[0].set_data(allocator::malloc(outputs[0].nbytes()));
distributed::detail::all_gather(group(), in, outputs[0], stream());
if (copied) {
auto& enc = cpu::get_command_encoder(stream());
@ -87,7 +87,7 @@ void Recv::eval_cpu(
assert(inputs.size() == 0);
assert(outputs.size() == 1);
outputs[0].set_data(allocator::malloc_or_wait(outputs[0].nbytes()));
outputs[0].set_data(allocator::malloc(outputs[0].nbytes()));
distributed::detail::recv(group(), outputs[0], src_, stream());
}

View File

@ -55,9 +55,8 @@ void eigh_impl(
liwork = iwork;
}
auto work_buf = array::Data{allocator::malloc_or_wait(sizeof(T) * lwork)};
auto iwork_buf =
array::Data{allocator::malloc_or_wait(sizeof(int) * liwork)};
auto work_buf = array::Data{allocator::malloc(sizeof(T) * lwork)};
auto iwork_buf = array::Data{allocator::malloc(sizeof(int) * liwork)};
for (size_t i = 0; i < size / (N * N); ++i) {
syevd<T>(
&jobz,
@ -98,7 +97,7 @@ void Eigh::eval_cpu(
? outputs[1]
: array(a.shape(), a.dtype(), nullptr, {});
values.set_data(allocator::malloc_or_wait(values.nbytes()));
values.set_data(allocator::malloc(values.nbytes()));
copy(
a,

View File

@ -22,7 +22,7 @@ void FFT::eval_cpu(const std::vector<array>& inputs, array& out) {
s *= out.itemsize();
}
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
std::vector<size_t> shape;
if (out.dtype() == float32) {

View File

@ -197,7 +197,7 @@ void dispatch_gather(
}
void Gather::eval_cpu(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
auto& src = inputs[0];
std::vector<array> inds;
@ -354,7 +354,7 @@ void dispatch_gather_axis(
}
void GatherAxis::eval_cpu(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
auto& src = inputs[0];
auto& inds = inputs[1];

View File

@ -11,7 +11,7 @@ namespace mlx::core {
template <typename T>
void general_inv(T* inv, int N) {
int info;
auto ipiv = array::Data{allocator::malloc_or_wait(sizeof(int) * N)};
auto ipiv = array::Data{allocator::malloc(sizeof(int) * N)};
// Compute LU factorization.
getrf<T>(
/* m = */ &N,
@ -49,7 +49,7 @@ void general_inv(T* inv, int N) {
}
const int lwork = workspace_size;
auto scratch = array::Data{allocator::malloc_or_wait(sizeof(T) * lwork)};
auto scratch = array::Data{allocator::malloc(sizeof(T) * lwork)};
// Compute inverse.
getri<T>(

View File

@ -30,8 +30,7 @@ void luf_impl(
auto strides = lu.strides();
strides[ndim - 1] = M;
strides[ndim - 2] = 1;
lu.set_data(
allocator::malloc_or_wait(lu.nbytes()), lu.nbytes(), strides, flags);
lu.set_data(allocator::malloc(lu.nbytes()), lu.nbytes(), strides, flags);
copy_inplace(
a,
lu,
@ -44,8 +43,8 @@ void luf_impl(
stream);
auto a_ptr = lu.data<T>();
pivots.set_data(allocator::malloc_or_wait(pivots.nbytes()));
row_indices.set_data(allocator::malloc_or_wait(row_indices.nbytes()));
pivots.set_data(allocator::malloc(pivots.nbytes()));
row_indices.set_data(allocator::malloc(row_indices.nbytes()));
auto pivots_ptr = pivots.data<uint32_t>();
auto row_indices_ptr = row_indices.data<uint32_t>();
size_t num_matrices = a.size() / (M * N);

View File

@ -59,7 +59,7 @@ void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
throw std::runtime_error(
"[BlockMaskedMM::eval] Currently only supports float32.");
}
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
auto& a_pre = inputs[0];
auto& b_pre = inputs[1];
@ -318,7 +318,7 @@ void GatherMM::eval_cpu(const std::vector<array>& inputs, array& out) {
throw std::runtime_error(
"[GatherMM::eval] Currently only supports float32.");
}
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
auto& a_pre = inputs[0];
auto& b_pre = inputs[1];

View File

@ -115,7 +115,7 @@ void matmul_general(
}
void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
if (inputs[0].shape(-1) == 0) {
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_output_array(out);

View File

@ -21,7 +21,7 @@ namespace mlx::core {
void reshape(const array& in, array& out) {
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
if (copy_necessary) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
copy_inplace(in, out, CopyType::General, out.primitive().stream());
} else {
shared_buffer_reshape(in, out_strides, out);
@ -39,7 +39,7 @@ static std::pair<array, bool> compute_dynamic_offset(
if (donate) {
offset.copy_shared_buffer(indices);
} else {
offset.set_data(allocator::malloc_or_wait(offset.itemsize()));
offset.set_data(allocator::malloc(offset.itemsize()));
}
auto& encoder = cpu::get_command_encoder(stream);
@ -124,7 +124,7 @@ void Transpose::eval_cpu(const std::vector<array>& inputs, array& out) {
void Arange::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 0);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
switch (out.dtype()) {
case bool_:
throw std::runtime_error("Bool type unsupported for arange.");
@ -186,7 +186,7 @@ void Concatenate::eval_cpu(const std::vector<array>& inputs, array& out) {
}
std::partial_sum(sizes.cbegin(), sizes.cend(), sizes.begin());
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
auto strides = out.strides();
auto flags = out.flags();
@ -276,7 +276,7 @@ void RandomBits::eval_cpu(const std::vector<array>& inputs, array& out) {
size_t elems_per_key = out.size() / num_keys;
size_t bytes_per_key = out.itemsize() * elems_per_key;
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
auto kptr = inputs[0].data<uint32_t>();
auto cptr = out.data<char>();
@ -335,7 +335,7 @@ void DynamicSlice::eval_cpu(const std::vector<array>& inputs, array& out) {
return;
}
auto& in = inputs[0];
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
auto [in_offset, donated] =
compute_dynamic_offset(inputs[1], in.strides(), axes_, stream());
copy_inplace(
@ -450,7 +450,7 @@ void View::eval_cpu(const std::vector<array>& inputs, array& out) {
} else {
auto tmp = array(
in.shape(), in.dtype() == bool_ ? uint8 : in.dtype(), nullptr, {});
tmp.set_data(allocator::malloc_or_wait(tmp.nbytes()));
tmp.set_data(allocator::malloc(tmp.nbytes()));
if (in.dtype() == bool_) {
auto in_tmp = array(in.shape(), uint8, nullptr, {});
in_tmp.copy_shared_buffer(in);

View File

@ -25,12 +25,11 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) {
auto strides = in.strides();
strides[in.ndim() - 2] = 1;
strides[in.ndim() - 1] = M;
in.set_data(
allocator::malloc_or_wait(in.nbytes()), in.nbytes(), strides, flags);
in.set_data(allocator::malloc(in.nbytes()), in.nbytes(), strides, flags);
copy_inplace(a, in, CopyType::GeneralGeneral, stream);
auto& encoder = cpu::get_command_encoder(stream);
q.set_data(allocator::malloc_or_wait(q.nbytes()));
r.set_data(allocator::malloc_or_wait(r.nbytes()));
q.set_data(allocator::malloc(q.nbytes()));
r.set_data(allocator::malloc(r.nbytes()));
auto in_ptr = in.data<T>();
auto r_ptr = r.data<T>();
@ -41,8 +40,7 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) {
encoder.set_output_array(r);
encoder.dispatch([in_ptr, q_ptr, r_ptr, M, N, lda, num_matrices]() {
int num_reflectors = std::min(M, N);
auto tau =
allocator::malloc_or_wait(sizeof(T) * num_matrices * num_reflectors);
auto tau = allocator::malloc(sizeof(T) * num_matrices * num_reflectors);
T optimal_work;
int lwork = -1;
@ -53,7 +51,7 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) {
// Update workspace size
lwork = optimal_work;
auto work = allocator::malloc_or_wait(sizeof(T) * lwork);
auto work = allocator::malloc(sizeof(T) * lwork);
// Loop over matrices
for (int i = 0; i < num_matrices; ++i) {
@ -96,7 +94,7 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) {
&lwork,
&info);
lwork = optimal_work;
work = allocator::malloc_or_wait(sizeof(T) * lwork);
work = allocator::malloc(sizeof(T) * lwork);
// Loop over matrices
for (int i = 0; i < num_matrices; ++i) {

View File

@ -515,7 +515,7 @@ void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
auto scales = ensure_row_contiguous(scales_pre);
auto biases = ensure_row_contiguous(biases_pre);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
auto& encoder = cpu::get_command_encoder(stream());
encoder.add_temporaries(std::move(temps));
@ -565,7 +565,7 @@ void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) {
auto scales = ensure_row_contiguous_last_dims(scales_pre);
auto biases = ensure_row_contiguous_last_dims(biases_pre);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
auto& encoder = cpu::get_command_encoder(stream());
encoder.add_temporaries(std::move(temps));
@ -691,12 +691,12 @@ void fast::AffineQuantize::eval_cpu(
auto [w, copied] = ensure_row_contiguous(inputs[0]);
auto& out = outputs[0];
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
auto& scales = outputs[1];
auto& biases = outputs[2];
scales.set_data(allocator::malloc_or_wait(scales.nbytes()));
biases.set_data(allocator::malloc_or_wait(biases.nbytes()));
scales.set_data(allocator::malloc(scales.nbytes()));
biases.set_data(allocator::malloc(biases.nbytes()));
auto& encoder = cpu::get_command_encoder(stream());
if (copied) {
encoder.add_temporary(w);

View File

@ -433,7 +433,7 @@ void reduce_dispatch_min_max(
void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_input_array(in);
encoder.set_output_array(out);

View File

@ -244,7 +244,7 @@ void Scan::eval_cpu(const std::vector<array>& inputs, array& out) {
in = arr_copy;
encoder.add_temporary(arr_copy);
}
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
encoder.set_input_array(in);
encoder.set_output_array(out);

View File

@ -129,7 +129,7 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
out.copy_shared_buffer(x);
} else {
out.set_data(
allocator::malloc_or_wait(x.data_size() * x.itemsize()),
allocator::malloc(x.data_size() * x.itemsize()),
x.data_size(),
x.strides(),
x.flags());

View File

@ -288,7 +288,7 @@ void ArgSort::eval_cpu(const std::vector<array>& inputs, array& out) {
auto& in = inputs[0];
// Allocate output
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_input_array(in);
@ -379,7 +379,7 @@ void ArgPartition::eval_cpu(const std::vector<array>& inputs, array& out) {
auto& in = inputs[0];
// Allocate output
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_input_array(in);

View File

@ -50,9 +50,9 @@ void svd_impl(
array& s = outputs[1];
array& vt = outputs[2];
u.set_data(allocator::malloc_or_wait(u.nbytes()));
s.set_data(allocator::malloc_or_wait(s.nbytes()));
vt.set_data(allocator::malloc_or_wait(vt.nbytes()));
u.set_data(allocator::malloc(u.nbytes()));
s.set_data(allocator::malloc(s.nbytes()));
vt.set_data(allocator::malloc(vt.nbytes()));
encoder.set_output_array(u);
encoder.set_output_array(s);
@ -64,7 +64,7 @@ void svd_impl(
} else {
array& s = outputs[0];
s.set_data(allocator::malloc_or_wait(s.nbytes()));
s.set_data(allocator::malloc(s.nbytes()));
encoder.set_output_array(s);
@ -91,7 +91,7 @@ void svd_impl(
// Will contain the indices of eigenvectors that failed to converge (not
// used here but required by lapack).
auto iwork = array::Data{allocator::malloc_or_wait(sizeof(int) * 12 * K)};
auto iwork = array::Data{allocator::malloc(sizeof(int) * 12 * K)};
static const int lwork_query = -1;
@ -132,7 +132,7 @@ void svd_impl(
}
const int lwork = workspace_dimension;
auto scratch = array::Data{allocator::malloc_or_wait(sizeof(T) * lwork)};
auto scratch = array::Data{allocator::malloc(sizeof(T) * lwork)};
// Loop over matrices.
for (int i = 0; i < num_matrices; i++) {

View File

@ -18,13 +18,13 @@ void set_unary_output_data(const array& in, array& out) {
} else {
auto size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
allocator::malloc(size * out.itemsize()),
size,
in.strides(),
in.flags());
}
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
}
}

View File

@ -192,16 +192,19 @@ size_t MetalAllocator::set_cache_limit(size_t limit) {
return limit;
};
size_t MetalAllocator::set_memory_limit(size_t limit, bool relaxed) {
size_t MetalAllocator::set_memory_limit(size_t limit) {
std::unique_lock lk(mutex_);
std::swap(limit, block_limit_);
relaxed_ = relaxed;
gc_limit_ = std::min(
block_limit_,
static_cast<size_t>(0.95 * device_->recommendedMaxWorkingSetSize()));
return limit;
};
size_t MetalAllocator::get_memory_limit() {
return block_limit_;
}
size_t MetalAllocator::set_wired_limit(size_t limit) {
std::unique_lock lk(mutex_);
std::swap(limit, wired_limit_);
@ -209,7 +212,7 @@ size_t MetalAllocator::set_wired_limit(size_t limit) {
return limit;
};
Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
Buffer MetalAllocator::malloc(size_t size) {
// Metal doesn't like empty buffers
if (size == 0) {
return Buffer{nullptr};
@ -236,11 +239,6 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
if (!buf) {
size_t mem_required = get_active_memory() + get_cache_memory() + size;
// If there is too much memory pressure, fail (likely causes a wait).
if (!(allow_swap && relaxed_) && mem_required >= block_limit_) {
return Buffer{nullptr};
}
auto pool = metal::new_scoped_memory_pool();
// If we have a lot of memory pressure or are over the maximum cache size,
@ -328,8 +326,11 @@ MetalAllocator& allocator() {
size_t set_cache_limit(size_t limit) {
return allocator().set_cache_limit(limit);
}
size_t set_memory_limit(size_t limit, bool relaxed /* = true */) {
return allocator().set_memory_limit(limit, relaxed);
size_t set_memory_limit(size_t limit) {
return allocator().set_memory_limit(limit);
}
size_t get_memory_limit() {
return allocator().get_memory_limit();
}
size_t set_wired_limit(size_t limit) {
if (limit >

View File

@ -56,7 +56,7 @@ class BufferCache {
class MetalAllocator : public allocator::Allocator {
/** Allocator for Metal GPUs. */
public:
virtual Buffer malloc(size_t size, bool allow_swap = false) override;
virtual Buffer malloc(size_t size) override;
virtual void free(Buffer buffer) override;
virtual size_t size(Buffer buffer) const override;
size_t get_active_memory() {
@ -73,7 +73,8 @@ class MetalAllocator : public allocator::Allocator {
return buffer_cache_.cache_size();
};
size_t set_cache_limit(size_t limit);
size_t set_memory_limit(size_t limit, bool relaxed);
size_t set_memory_limit(size_t limit);
size_t get_memory_limit();
size_t set_wired_limit(size_t limit);
void clear_cache();

View File

@ -37,7 +37,7 @@ void explicit_gemm_conv_ND_gpu(
Shape unfolded_shape{implicit_M, implicit_K};
array in_unfolded(unfolded_shape, in.dtype(), nullptr, {});
in_unfolded.set_data(allocator::malloc_or_wait(in_unfolded.nbytes()));
in_unfolded.set_data(allocator::malloc(in_unfolded.nbytes()));
// Prepare unfolding kernel
std::ostringstream kname;
@ -115,7 +115,7 @@ void explicit_gemm_conv_group_ND_gpu(
// Prepare unfolding array
Shape unfolded_shape{implicit_M, implicit_K * groups};
array in_unfolded(unfolded_shape, in.dtype(), nullptr, {});
in_unfolded.set_data(allocator::malloc_or_wait(in_unfolded.nbytes()));
in_unfolded.set_data(allocator::malloc(in_unfolded.nbytes()));
// Prepare unfolding kernel
std::ostringstream kname;
@ -613,7 +613,7 @@ void winograd_conv_2D_gpu(
// Do filter transform
Shape filt_wg_shape = {8 * 8, conv_params.C, conv_params.O};
array filt_wg(std::move(filt_wg_shape), wt.dtype(), nullptr, {});
filt_wg.set_data(allocator::malloc_or_wait(filt_wg.nbytes()));
filt_wg.set_data(allocator::malloc(filt_wg.nbytes()));
copies_w.push_back(filt_wg);
{
int bc = 32;
@ -640,7 +640,7 @@ void winograd_conv_2D_gpu(
// Do input transform
Shape inp_wg_shape = {8 * 8, N_tiles, conv_params.C};
array inp_wg(std::move(inp_wg_shape), in.dtype(), nullptr, {});
inp_wg.set_data(allocator::malloc_or_wait(inp_wg.nbytes()));
inp_wg.set_data(allocator::malloc(inp_wg.nbytes()));
copies_w.push_back(inp_wg);
{
int bc = 32;
@ -667,7 +667,7 @@ void winograd_conv_2D_gpu(
// Do batched gemm
Shape out_wg_shape = {8 * 8, N_tiles, conv_params.O};
array out_wg(std::move(out_wg_shape), in.dtype(), nullptr, {});
out_wg.set_data(allocator::malloc_or_wait(out_wg.nbytes()));
out_wg.set_data(allocator::malloc(out_wg.nbytes()));
copies_w.push_back(out_wg);
{
std::vector<array> empty_copies;
@ -855,7 +855,7 @@ void conv_3D_gpu(
} // namespace
void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
auto& s = stream();
auto& d = metal::device(s.device);

View File

@ -202,7 +202,7 @@ void fill_gpu(const array& val, array& out, const Stream& s) {
if (out.size() == 0) {
return;
}
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
bool large = out.data_size() > UINT32_MAX;
auto& d = metal::device(s.device);
std::string kernel_name = std::string(large ? "s2" : "s") + "_copy" +

View File

@ -19,7 +19,7 @@ void CustomKernel::eval_gpu(
copies.emplace_back(init_value_.value(), out.dtype());
fill_gpu(copies.back(), out, s);
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
}
}

View File

@ -20,7 +20,7 @@ struct FenceImpl {
auto p = metal::new_scoped_memory_pool();
fence = static_cast<void*>(d->newSharedEvent());
} else {
auto buf = allocator::malloc_or_wait(sizeof(uint32_t)).ptr();
auto buf = allocator::malloc(sizeof(uint32_t)).ptr();
fence = static_cast<void*>(buf);
cpu_value()[0] = 0;
}

View File

@ -281,7 +281,7 @@ std::tuple<array, array, array> compute_raders_constants(
}
array b_q_fft({rader_n - 1}, complex64, nullptr, {});
b_q_fft.set_data(allocator::malloc_or_wait(b_q_fft.nbytes()));
b_q_fft.set_data(allocator::malloc(b_q_fft.nbytes()));
auto b_q_fft_ptr =
reinterpret_cast<std::complex<float>*>(b_q_fft.data<complex64_t>());
std::ptrdiff_t item_size = b_q_fft.itemsize();
@ -327,11 +327,11 @@ std::pair<array, array> compute_bluestein_constants(int n, int bluestein_n) {
}
array w_k({n}, complex64, nullptr, {});
w_k.set_data(allocator::malloc_or_wait(w_k.nbytes()));
w_k.set_data(allocator::malloc(w_k.nbytes()));
std::copy(w_k_vec.begin(), w_k_vec.end(), w_k.data<complex64_t>());
array w_q({bluestein_n}, complex64, nullptr, {});
w_q.set_data(allocator::malloc_or_wait(w_q.nbytes()));
w_q.set_data(allocator::malloc(w_q.nbytes()));
auto w_q_ptr =
reinterpret_cast<std::complex<float>*>(w_q.data<complex64_t>());
@ -551,8 +551,7 @@ void fft_op(
flags.row_contiguous = is_row_contiguous;
flags.contiguous = data_size == x_copy.size();
x_copy.set_data(
allocator::malloc_or_wait(x.nbytes()), data_size, strides, flags);
x_copy.set_data(allocator::malloc(x.nbytes()), data_size, strides, flags);
copy_gpu_inplace(x, x_copy, CopyType::GeneralGeneral, s);
copies.push_back(x_copy);
return x_copy;
@ -583,7 +582,7 @@ void fft_op(
// TODO: allow donation here
if (!inplace) {
out.set_data(
allocator::malloc_or_wait(out.nbytes()),
allocator::malloc(out.nbytes()),
out_data_size,
out_strides,
in_contiguous.flags());

View File

@ -84,7 +84,7 @@ void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
if (in_contiguous.is_donatable()) {
out.copy_shared_buffer(in_contiguous);
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
}
int n, m;
@ -161,7 +161,7 @@ void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
// Upload 2:
// y = h12 @ tmp
array temp(in.shape(), in.dtype(), nullptr, {});
temp.set_data(allocator::malloc_or_wait(temp.nbytes()));
temp.set_data(allocator::malloc(temp.nbytes()));
copies.push_back(temp);
launch_hadamard(in_contiguous, temp, "n" + kernel_name, 1.0);

View File

@ -43,7 +43,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
throw std::runtime_error(msg.str());
}
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
if (out.size() == 0) {
return;
}
@ -393,7 +393,7 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& src = inputs[0];
auto& idx = inputs[1];
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
if (out.size() == 0) {
return;
}

View File

@ -382,7 +382,7 @@ void steel_matmul(
int split_k_partition_size = gemm_k_iterations * bk;
array C_split({split_k_partitions, M, N}, float32, nullptr, {});
C_split.set_data(allocator::malloc_or_wait(C_split.nbytes()));
C_split.set_data(allocator::malloc(C_split.nbytes()));
copies.push_back(C_split);
bool mn_aligned = M % bm == 0 && N % bn == 0;
@ -513,7 +513,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
return;
}
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
/////////////////////////////////////////////////////////////////////////////
// Init checks and prep
@ -677,7 +677,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
throw std::runtime_error(
"[matmul] Does not yet support non-floating point types.");
}
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
auto& s = stream();
auto& d = metal::device(s.device);
@ -860,7 +860,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
int split_k_partition_size = gemm_k_iterations * bk;
array C_split({split_k_partitions, M, N}, float32, nullptr, {});
C_split.set_data(allocator::malloc_or_wait(C_split.nbytes()));
C_split.set_data(allocator::malloc(C_split.nbytes()));
copies.push_back(C_split);
bool mn_aligned = M % bm == 0 && N % bn == 0;
@ -1096,7 +1096,7 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
return;
}
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
/////////////////////////////////////////////////////////////////////////////
// Init checks and prep
@ -1484,7 +1484,7 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
return;
}
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
/////////////////////////////////////////////////////////////////////////////
// Init checks and prep

View File

@ -38,17 +38,20 @@ void reset_peak_memory();
size_t get_cache_memory();
/* Set the memory limit.
* Calls to malloc will wait on scheduled tasks if the limit is exceeded. If
* there are no more scheduled tasks an error will be raised if relaxed
* is false or memory will be allocated (including the potential for
* swap) if relaxed is true.
* The memory limit is a guideline for the maximum amount of memory to use
* during graph evaluation. If the memory limit is exceeded and there is no
* more RAM (including swap when available) allocations will result in an
* exception.
*
* The memory limit defaults to 1.5 times the maximum recommended working set
* size reported by the device.
* When metal is available the memory limit defaults to 1.5 times the maximum
* recommended working set size reported by the device.
*
* Returns the previous memory limit.
* */
size_t set_memory_limit(size_t limit, bool relaxed = true);
size_t set_memory_limit(size_t limit);
/* Get the current memory limit. */
size_t get_memory_limit();
/* Set the free cache limit.
* If using more than the given limit, free memory will be reclaimed

View File

@ -29,7 +29,7 @@ void RMSNorm::eval_gpu(
out.copy_shared_buffer(x);
} else {
out.set_data(
allocator::malloc_or_wait(x.data_size() * x.itemsize()),
allocator::malloc(x.data_size() * x.itemsize()),
x.data_size(),
x.strides(),
x.flags());
@ -129,7 +129,7 @@ void RMSNormVJP::eval_gpu(
gx.copy_shared_buffer(g);
g_in_gx = true;
} else {
gx.set_data(allocator::malloc_or_wait(gx.nbytes()));
gx.set_data(allocator::malloc(gx.nbytes()));
}
if (g_copied && !g_in_gx) {
d.add_temporary(g, s.index);
@ -146,11 +146,11 @@ void RMSNormVJP::eval_gpu(
if (!g_in_gx && donate_g) {
gw_temp.copy_shared_buffer(g);
} else {
gw_temp.set_data(allocator::malloc_or_wait(gw_temp.nbytes()));
gw_temp.set_data(allocator::malloc(gw_temp.nbytes()));
d.add_temporary(gw_temp, s.index);
}
}
gw.set_data(allocator::malloc_or_wait(gw.nbytes()));
gw.set_data(allocator::malloc(gw.nbytes()));
const int simd_size = 32;
const int n_reads = RMS_N_READS;
@ -226,7 +226,7 @@ void LayerNorm::eval_gpu(
out.copy_shared_buffer(x);
} else {
out.set_data(
allocator::malloc_or_wait(x.data_size() * x.itemsize()),
allocator::malloc(x.data_size() * x.itemsize()),
x.data_size(),
x.strides(),
x.flags());
@ -331,7 +331,7 @@ void LayerNormVJP::eval_gpu(
gx.copy_shared_buffer(g);
g_in_gx = true;
} else {
gx.set_data(allocator::malloc_or_wait(gx.nbytes()));
gx.set_data(allocator::malloc(gx.nbytes()));
}
if (g_copied && !g_in_gx) {
d.add_temporary(g, s.index);
@ -348,12 +348,12 @@ void LayerNormVJP::eval_gpu(
if (!g_in_gx && donate_g) {
gw_temp.copy_shared_buffer(g);
} else {
gw_temp.set_data(allocator::malloc_or_wait(gw_temp.nbytes()));
gw_temp.set_data(allocator::malloc(gw_temp.nbytes()));
d.add_temporary(gw_temp, s.index);
}
}
gw.set_data(allocator::malloc_or_wait(gw.nbytes()));
gb.set_data(allocator::malloc_or_wait(gb.nbytes()));
gw.set_data(allocator::malloc(gw.nbytes()));
gb.set_data(allocator::malloc(gb.nbytes()));
// Finish with the gradient for b in case we had a b
auto& compute_encoder = d.get_command_encoder(s.index);

View File

@ -28,7 +28,7 @@ void arange_set_scalars(T start, T next, metal::CommandEncoder& enc) {
void reshape(const array& in, array& out, Stream s) {
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
if (copy_necessary) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
copy_gpu_inplace(
in,
out,
@ -58,7 +58,7 @@ static array compute_dynamic_offset(
if (donate) {
offset.copy_shared_buffer(indices);
} else {
offset.set_data(allocator::malloc_or_wait(offset.itemsize()));
offset.set_data(allocator::malloc(offset.itemsize()));
}
d.add_temporary(offset, s.index);
@ -100,7 +100,7 @@ static array compute_dynamic_offset(
void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 0);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
if (out.size() == 0) {
return;
}
@ -161,7 +161,7 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
auto& s = stream();
auto& d = metal::device(s.device);
std::string op_name;
@ -333,7 +333,7 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
size_t elems_per_key = out.size() / num_keys;
size_t bytes_per_key = out.itemsize() * elems_per_key;
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
if (out.size() == 0) {
return;
}
@ -397,7 +397,7 @@ void DynamicSlice::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& in = inputs[0];
auto& start = inputs[1];
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
auto s = stream();
auto in_offset = compute_dynamic_offset(start, in.strides(), axes_, s);
copy_gpu_inplace(
@ -554,7 +554,7 @@ void View::eval_gpu(const std::vector<array>& inputs, array& out) {
in, strides, in.flags(), in.data_size() * ibytes / obytes);
} else {
auto tmp = array(in.shape(), in.dtype(), nullptr, {});
tmp.set_data(allocator::malloc_or_wait(tmp.nbytes()));
tmp.set_data(allocator::malloc(tmp.nbytes()));
copy_gpu_inplace(in, tmp, CopyType::General, stream());
auto flags = out.flags();

View File

@ -224,7 +224,7 @@ void qvm_split_k(
auto temp_shape = out.shape();
temp_shape.insert(temp_shape.end() - 2, split_k);
array intermediate(temp_shape, x.dtype(), nullptr, {});
intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
intermediate.set_data(allocator::malloc(intermediate.nbytes()));
d.add_temporary(intermediate, s.index);
std::ostringstream kname;
@ -277,7 +277,7 @@ void qmm_op(
int bits,
bool gather,
const Stream& s) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
MTL::Size group_dims;
MTL::Size grid_dims;
@ -394,7 +394,7 @@ void fast::AffineQuantize::eval_gpu(
std::vector<array>& outputs) {
auto& w_pre = inputs[0];
auto& out = outputs[0];
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
auto& s = stream();
auto& d = metal::device(s.device);
@ -425,8 +425,8 @@ void fast::AffineQuantize::eval_gpu(
} else {
auto& scales = outputs[1];
auto& biases = outputs[2];
scales.set_data(allocator::malloc_or_wait(scales.nbytes()));
biases.set_data(allocator::malloc_or_wait(biases.nbytes()));
scales.set_data(allocator::malloc(scales.nbytes()));
biases.set_data(allocator::malloc(biases.nbytes()));
compute_encoder.set_output_array(out, 1);
compute_encoder.set_output_array(scales, 2);
compute_encoder.set_output_array(biases, 3);

View File

@ -347,7 +347,7 @@ void all_reduce_dispatch(
// Allocate an intermediate tensor to hold results if needed
array intermediate({n_rows}, out_type, nullptr, {});
intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
intermediate.set_data(allocator::malloc(intermediate.nbytes()));
d.add_temporary(intermediate, s.index);
// 1st pass
@ -641,7 +641,7 @@ void strided_reduce_longcolumn(
intermediate_shape.insert(
intermediate_shape.end(), out.shape().begin(), out.shape().end());
array intermediate(std::move(intermediate_shape), out_type, nullptr, {});
intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
intermediate.set_data(allocator::malloc(intermediate.nbytes()));
d.add_temporary(intermediate, s.index);
// Prepare the arguments for the kernel
@ -812,7 +812,7 @@ void strided_reduce_2pass(
intermediate_shape.insert(
intermediate_shape.end(), out.shape().begin(), out.shape().end());
array intermediate(std::move(intermediate_shape), out_type, nullptr, {});
intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
intermediate.set_data(allocator::malloc(intermediate.nbytes()));
d.add_temporary(intermediate, s.index);
// Prepare the arguments for the kernel
@ -950,7 +950,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
// Minimum of 4 bytes since we use size 4 structs for all reduce
// and metal will complain o/w
size_t min_bytes = std::max(out.nbytes(), 4ul);
out.set_data(allocator::malloc_or_wait(min_bytes));
out.set_data(allocator::malloc(min_bytes));
std::string op_name;
switch (reduce_type_) {
case Reduce::And:

View File

@ -43,14 +43,14 @@ void RoPE::eval_gpu(
donated = true;
out.copy_shared_buffer(in);
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
}
strides[0] = mat_size;
strides[1] = in.strides()[ndim - 2];
strides[2] = in.strides()[ndim - 1];
} else if (dispatch_ndim == 3) {
// Handle non-contiguous 3D inputs
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
strides[0] = in.strides()[ndim - 3];
strides[1] = in.strides()[ndim - 2];
strides[2] = in.strides()[ndim - 1];

View File

@ -248,9 +248,9 @@ void sdpa_vector_2pass(
intermediate_shape.pop_back();
array sums(intermediate_shape, float32, nullptr, {});
array maxs(std::move(intermediate_shape), float32, nullptr, {});
intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
sums.set_data(allocator::malloc_or_wait(sums.nbytes()));
maxs.set_data(allocator::malloc_or_wait(maxs.nbytes()));
intermediate.set_data(allocator::malloc(intermediate.nbytes()));
sums.set_data(allocator::malloc(sums.nbytes()));
maxs.set_data(allocator::malloc(maxs.nbytes()));
d.add_temporary(intermediate, s.index);
d.add_temporary(sums, s.index);
d.add_temporary(maxs, s.index);
@ -383,7 +383,7 @@ void ScaledDotProductAttention::eval_gpu(
o.copy_shared_buffer(q);
} else {
if (o.shape(2) == 1) {
o.set_data(allocator::malloc_or_wait(o.nbytes()));
o.set_data(allocator::malloc(o.nbytes()));
} else {
auto strides = o.strides();
strides[2] = o.shape(1) * o.shape(3);
@ -391,10 +391,7 @@ void ScaledDotProductAttention::eval_gpu(
auto flags = q.flags();
flags.row_contiguous = q.shape(1) == 1;
o.set_data(
allocator::malloc_or_wait(o.nbytes()),
o.size(),
std::move(strides),
flags);
allocator::malloc(o.nbytes()), o.size(), std::move(strides), flags);
}
}
@ -432,7 +429,7 @@ void ScaledDotProductAttention::eval_gpu(
};
o.set_data(
allocator::malloc_or_wait(o.nbytes()),
allocator::malloc(o.nbytes()),
data_size,
{str_oB, str_oH, str_oL, str_oD},
flags);

View File

@ -24,7 +24,7 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
out.copy_shared_buffer(in);
} else {
out.set_data(
allocator::malloc_or_wait(in.data_size() * out.itemsize()),
allocator::malloc(in.data_size() * out.itemsize()),
in.data_size(),
in.strides(),
in.flags());

View File

@ -29,7 +29,7 @@ void concatenate_gpu(
}
std::partial_sum(sizes.cbegin(), sizes.cend(), sizes.begin());
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
auto strides = out.strides();
auto flags = out.flags();

View File

@ -33,7 +33,7 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
out.copy_shared_buffer(x);
} else {
out.set_data(
allocator::malloc_or_wait(x.data_size() * x.itemsize()),
allocator::malloc(x.data_size() * x.itemsize()),
x.data_size(),
x.strides(),
x.flags());

View File

@ -150,12 +150,11 @@ void multi_block_sort(
array block_partitions({n_rows, n_blocks + 1}, uint32, nullptr, {});
// Do allocations
dev_vals_0.set_data(allocator::malloc_or_wait(dev_vals_0.nbytes()));
dev_vals_1.set_data(allocator::malloc_or_wait(dev_vals_1.nbytes()));
dev_idxs_0.set_data(allocator::malloc_or_wait(dev_idxs_0.nbytes()));
dev_idxs_1.set_data(allocator::malloc_or_wait(dev_idxs_1.nbytes()));
block_partitions.set_data(
allocator::malloc_or_wait(block_partitions.nbytes()));
dev_vals_0.set_data(allocator::malloc(dev_vals_0.nbytes()));
dev_vals_1.set_data(allocator::malloc(dev_vals_1.nbytes()));
dev_idxs_0.set_data(allocator::malloc(dev_idxs_0.nbytes()));
dev_idxs_1.set_data(allocator::malloc(dev_idxs_1.nbytes()));
block_partitions.set_data(allocator::malloc(block_partitions.nbytes()));
std::vector<array> copies = {
dev_vals_0, dev_vals_1, dev_idxs_0, dev_idxs_1, block_partitions};
@ -319,7 +318,7 @@ void gpu_merge_sort(
void ArgSort::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
auto& s = stream();
auto& d = metal::device(s.device);
@ -331,7 +330,7 @@ void ArgSort::eval_gpu(const std::vector<array>& inputs, array& out) {
void Sort::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
auto& s = stream();
auto& d = metal::device(s.device);
@ -344,7 +343,7 @@ void ArgPartition::eval_gpu(const std::vector<array>& inputs, array& out) {
// We direct arg partition to sort for now
assert(inputs.size() == 1);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
auto& s = stream();
auto& d = metal::device(s.device);
@ -357,7 +356,7 @@ void Partition::eval_gpu(const std::vector<array>& inputs, array& out) {
// We direct partition to sort for now
assert(inputs.size() == 1);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
auto& s = stream();
auto& d = metal::device(s.device);

View File

@ -97,13 +97,13 @@ void unary_op_gpu(
out.copy_shared_buffer(in);
} else {
out.set_data(
allocator::malloc_or_wait(in.data_size() * out.itemsize()),
allocator::malloc(in.data_size() * out.itemsize()),
in.data_size(),
in.strides(),
in.flags());
}
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
}
unary_op_gpu_inplace(inputs, out, op, s);
}

View File

@ -42,7 +42,10 @@ void reset_peak_memory() {}
size_t get_cache_memory() {
return 0;
}
size_t set_memory_limit(size_t, bool) {
size_t set_memory_limit(size_t) {
return 0;
}
size_t get_memory_limit() {
return 0;
}
size_t set_cache_limit(size_t) {

View File

@ -218,7 +218,9 @@ array eval_impl(std::vector<array> outputs, bool async) {
cpu::eval(arr);
}
if (scheduler::n_active_tasks() > MAX_ACTIVE_TASKS) {
if (scheduler::n_active_tasks() > MAX_ACTIVE_TASKS ||
(metal::get_active_memory() > metal::get_memory_limit() &&
scheduler::n_active_tasks() > 0)) {
// Commit any open streams
for (auto& [_, e] : events) {
if (e.stream().device == Device::gpu) {
@ -226,6 +228,11 @@ array eval_impl(std::vector<array> outputs, bool async) {
}
}
scheduler::wait_for_one();
// TODO memory api should be moved out of metal
while (metal::get_active_memory() > metal::get_memory_limit() &&
scheduler::n_active_tasks() > 0) {
scheduler::wait_for_one();
}
}
auto maybe_update_fence = [&fences, &needs_fence, stream](const array& a) {

View File

@ -57,23 +57,19 @@ void init_metal(nb::module_& m) {
"set_memory_limit",
&mx::metal::set_memory_limit,
"limit"_a,
nb::kw_only(),
"relaxed"_a = true,
R"pbdoc(
Set the memory limit.
Memory allocations will wait on scheduled tasks to complete if the limit
is exceeded. If there are no more scheduled tasks an error will be raised
if ``relaxed`` is ``False``. Otherwise memory will be allocated
(including the potential for swap) if ``relaxed`` is ``True``.
The memory limit is a guideline for the maximum amount of memory to use
during graph evaluation. If the memory limit is exceeded and there is no
more RAM (including swap when available) allocations will result in an
exception.
The memory limit defaults to 1.5 times the maximum recommended working set
size reported by the device.
When metal is available the memory limit defaults to 1.5 times the
maximum recommended working set size reported by the device.
Args:
limit (int): Memory limit in bytes.
relaxed (bool, optional): If `False`` an error is raised if the limit
is exceeded. Default: ``True``
Returns:
int: The previous memory limit in bytes.

View File

@ -185,6 +185,18 @@ class TestEval(mlx_tests.MLXTestCase):
x = mx.abs(x, stream=s2)
mx.eval(x)
s1 = mx.default_stream(mx.gpu)
s2 = mx.new_stream(mx.gpu)
old_limit = mx.metal.set_memory_limit(1000)
x = mx.ones((512, 512), stream=s2)
for _ in range(80):
x = mx.abs(x, stream=s1)
y = mx.abs(x, stream=s2)
z = mx.abs(y, stream=s2)
mx.eval(z)
mx.metal.set_memory_limit(old_limit)
if __name__ == "__main__":
unittest.main()