mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 01:17:26 +08:00
fix malloc or wait deadlock (#1976)
This commit is contained in:
parent
1177d28395
commit
7b7e2352cd
@ -247,9 +247,7 @@ point-wise. This is captured in the templated function :meth:`axpby_impl`.
|
||||
float alpha_,
|
||||
float beta_,
|
||||
mx::Stream stream) {
|
||||
// Allocate the output with `malloc_or_wait` which synchronously allocates
|
||||
// memory, potentially waiting if the system is under memory pressure
|
||||
out.set_data(mx::allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(mx::allocator::malloc(out.nbytes()));
|
||||
|
||||
// Get the CPU command encoder and register input and output arrays
|
||||
auto& encoder = mx::cpu::get_command_encoder(stream);
|
||||
@ -393,7 +391,7 @@ below.
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
// Allocate output memory
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
// Resolve name of kernel
|
||||
std::ostringstream kname;
|
||||
|
@ -72,9 +72,7 @@ void axpby_impl(
|
||||
float alpha_,
|
||||
float beta_,
|
||||
mx::Stream stream) {
|
||||
// Allocate the output with `malloc_or_wait` which synchronously allocates
|
||||
// memory, potentially waiting if the system is under memory pressure
|
||||
out.set_data(mx::allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(mx::allocator::malloc(out.nbytes()));
|
||||
|
||||
// Get the CPU command encoder and register input and output arrays
|
||||
auto& encoder = mx::cpu::get_command_encoder(stream);
|
||||
@ -160,12 +158,12 @@ void Axpby::eval_gpu(
|
||||
// Allocate output memory with strides based on specialization
|
||||
if (contiguous_kernel) {
|
||||
out.set_data(
|
||||
mx::allocator::malloc_or_wait(x.data_size() * out.itemsize()),
|
||||
mx::allocator::malloc(x.data_size() * out.itemsize()),
|
||||
x.data_size(),
|
||||
x.strides(),
|
||||
x.flags());
|
||||
} else {
|
||||
out.set_data(mx::allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(mx::allocator::malloc(out.nbytes()));
|
||||
}
|
||||
|
||||
// Resolve name of kernel (corresponds to axpby.metal)
|
||||
|
@ -9,7 +9,7 @@
|
||||
namespace mlx::core::allocator {
|
||||
|
||||
Buffer malloc(size_t size) {
|
||||
auto buffer = allocator().malloc(size, /* allow_swap */ true);
|
||||
auto buffer = allocator().malloc(size);
|
||||
if (size && !buffer.ptr()) {
|
||||
std::ostringstream msg;
|
||||
msg << "[malloc] Unable to allocate " << size << " bytes.";
|
||||
@ -22,7 +22,7 @@ void free(Buffer buffer) {
|
||||
allocator().free(buffer);
|
||||
}
|
||||
|
||||
Buffer CommonAllocator::malloc(size_t size, bool) {
|
||||
Buffer CommonAllocator::malloc(size_t size) {
|
||||
void* ptr = std::malloc(size + sizeof(size_t));
|
||||
if (ptr != nullptr) {
|
||||
*static_cast<size_t*>(ptr) = size;
|
||||
@ -41,26 +41,4 @@ size_t CommonAllocator::size(Buffer buffer) const {
|
||||
return *static_cast<size_t*>(buffer.ptr());
|
||||
}
|
||||
|
||||
Buffer malloc_or_wait(size_t size) {
|
||||
auto buffer = allocator().malloc(size);
|
||||
|
||||
while (size && !buffer.ptr() && scheduler::n_active_tasks() > 0) {
|
||||
scheduler::wait_for_one();
|
||||
buffer = allocator().malloc(size);
|
||||
}
|
||||
|
||||
// Try swapping if needed
|
||||
if (size && !buffer.ptr()) {
|
||||
buffer = allocator().malloc(size, /* allow_swap = */ true);
|
||||
}
|
||||
|
||||
if (size && !buffer.ptr()) {
|
||||
std::ostringstream msg;
|
||||
msg << "[malloc_or_wait] Unable to allocate " << size << " bytes.";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
return buffer;
|
||||
}
|
||||
|
||||
} // namespace mlx::core::allocator
|
||||
|
@ -32,14 +32,10 @@ Buffer malloc(size_t size);
|
||||
|
||||
void free(Buffer buffer);
|
||||
|
||||
// Wait for running tasks to finish and free up memory
|
||||
// if allocation fails
|
||||
Buffer malloc_or_wait(size_t size);
|
||||
|
||||
class Allocator {
|
||||
/** Abstract base class for a memory allocator. */
|
||||
public:
|
||||
virtual Buffer malloc(size_t size, bool allow_swap = false) = 0;
|
||||
virtual Buffer malloc(size_t size) = 0;
|
||||
virtual void free(Buffer buffer) = 0;
|
||||
virtual size_t size(Buffer buffer) const = 0;
|
||||
|
||||
@ -56,7 +52,7 @@ Allocator& allocator();
|
||||
class CommonAllocator : public Allocator {
|
||||
/** A general CPU allocator. */
|
||||
public:
|
||||
virtual Buffer malloc(size_t size, bool allow_swap = false) override;
|
||||
virtual Buffer malloc(size_t size) override;
|
||||
virtual void free(Buffer buffer) override;
|
||||
virtual size_t size(Buffer buffer) const override;
|
||||
|
||||
|
@ -44,14 +44,14 @@ inline void set_binary_op_output_data(
|
||||
switch (bopt) {
|
||||
case BinaryOpType::ScalarScalar:
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(out.itemsize()), 1, a.strides(), a.flags());
|
||||
allocator::malloc(out.itemsize()), 1, a.strides(), a.flags());
|
||||
break;
|
||||
case BinaryOpType::ScalarVector:
|
||||
if (b_donatable) {
|
||||
out.copy_shared_buffer(b);
|
||||
} else {
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(b.data_size() * out.itemsize()),
|
||||
allocator::malloc(b.data_size() * out.itemsize()),
|
||||
b.data_size(),
|
||||
b.strides(),
|
||||
b.flags());
|
||||
@ -62,7 +62,7 @@ inline void set_binary_op_output_data(
|
||||
out.copy_shared_buffer(a);
|
||||
} else {
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(a.data_size() * out.itemsize()),
|
||||
allocator::malloc(a.data_size() * out.itemsize()),
|
||||
a.data_size(),
|
||||
a.strides(),
|
||||
a.flags());
|
||||
@ -75,7 +75,7 @@ inline void set_binary_op_output_data(
|
||||
out.copy_shared_buffer(b);
|
||||
} else {
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(a.data_size() * out.itemsize()),
|
||||
allocator::malloc(a.data_size() * out.itemsize()),
|
||||
a.data_size(),
|
||||
a.strides(),
|
||||
a.flags());
|
||||
@ -88,7 +88,7 @@ inline void set_binary_op_output_data(
|
||||
b_donatable && b.flags().row_contiguous && b.size() == out.size()) {
|
||||
out.copy_shared_buffer(b);
|
||||
} else {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
@ -103,7 +103,7 @@ void ExpandDims::eval(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
void NumberOfElements::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
double numel = 1;
|
||||
for (auto ax : axes_) {
|
||||
|
@ -188,7 +188,7 @@ void compiled_allocate_outputs(
|
||||
}
|
||||
for (; o < outputs.size(); ++o) {
|
||||
outputs[o].set_data(
|
||||
allocator::malloc_or_wait(data_size * outputs[o].itemsize()),
|
||||
allocator::malloc(data_size * outputs[o].itemsize()),
|
||||
data_size,
|
||||
strides,
|
||||
flags);
|
||||
@ -211,7 +211,7 @@ void compiled_allocate_outputs(
|
||||
}
|
||||
}
|
||||
for (; o < outputs.size(); ++o) {
|
||||
outputs[o].set_data(allocator::malloc_or_wait(outputs[o].nbytes()));
|
||||
outputs[o].set_data(allocator::malloc(outputs[o].nbytes()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -31,14 +31,14 @@ inline bool set_copy_output_data(const array& in, array& out, CopyType ctype) {
|
||||
return true;
|
||||
} else {
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(in.data_size() * out.itemsize()),
|
||||
allocator::malloc(in.data_size() * out.itemsize()),
|
||||
in.data_size(),
|
||||
in.strides(),
|
||||
in.flags());
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
@ -28,7 +28,7 @@ void swap_endianness(uint8_t* data_bytes, size_t N) {
|
||||
namespace mlx::core {
|
||||
|
||||
void Load::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
auto read_task = [out_ptr = out.data<char>(),
|
||||
size = out.size(),
|
||||
itemsize = out.itemsize(),
|
||||
|
@ -48,12 +48,12 @@ inline void set_ternary_op_output_data(
|
||||
switch (topt) {
|
||||
case TernaryOpType::ScalarScalarScalar:
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(out.itemsize()), 1, b.strides(), b.flags());
|
||||
allocator::malloc(out.itemsize()), 1, b.strides(), b.flags());
|
||||
break;
|
||||
case TernaryOpType::VectorVectorVector:
|
||||
if (!(maybe_donate(a) || maybe_donate(b) || maybe_donate(c))) {
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(out.itemsize() * b.data_size()),
|
||||
allocator::malloc(out.itemsize() * b.data_size()),
|
||||
b.data_size(),
|
||||
b.strides(),
|
||||
b.flags());
|
||||
@ -64,7 +64,7 @@ inline void set_ternary_op_output_data(
|
||||
if (!((a.flags().row_contiguous && maybe_donate(a)) ||
|
||||
(b.flags().row_contiguous && maybe_donate(b)) ||
|
||||
(c.flags().row_contiguous && maybe_donate(c)))) {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
@ -68,7 +68,7 @@ void arg_reduce_dispatch(
|
||||
void ArgReduce::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
|
@ -921,7 +921,7 @@ void explicit_gemm_conv_1D_cpu(
|
||||
|
||||
if (out.dtype() != float32) {
|
||||
gemm_out = array(out.shape(), float32, nullptr, {});
|
||||
gemm_out.set_data(allocator::malloc_or_wait(gemm_out.nbytes()));
|
||||
gemm_out.set_data(allocator::malloc(gemm_out.nbytes()));
|
||||
temps.push_back(gemm_out);
|
||||
}
|
||||
|
||||
@ -1048,7 +1048,7 @@ void explicit_gemm_conv_2D_cpu(
|
||||
|
||||
if (out.dtype() != float32) {
|
||||
gemm_out = array(out.shape(), float32, nullptr, {});
|
||||
gemm_out.set_data(allocator::malloc_or_wait(gemm_out.nbytes()));
|
||||
gemm_out.set_data(allocator::malloc(gemm_out.nbytes()));
|
||||
temps.push_back(gemm_out);
|
||||
}
|
||||
|
||||
@ -1214,7 +1214,7 @@ void explicit_gemm_conv_ND_cpu(
|
||||
|
||||
if (out.dtype() != float32) {
|
||||
gemm_out = array(out.shape(), float32, nullptr, {});
|
||||
gemm_out.set_data(allocator::malloc_or_wait(gemm_out.nbytes()));
|
||||
gemm_out.set_data(allocator::malloc(gemm_out.nbytes()));
|
||||
temps.push_back(gemm_out);
|
||||
}
|
||||
|
||||
@ -1327,7 +1327,7 @@ void conv_3D_cpu(
|
||||
} // namespace
|
||||
|
||||
void Convolution::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
auto& in = inputs[0];
|
||||
auto& wt = inputs[1];
|
||||
|
@ -30,7 +30,7 @@ void AllReduce::eval_cpu(
|
||||
if (in.is_donatable()) {
|
||||
out.copy_shared_buffer(in);
|
||||
} else {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
}
|
||||
return in;
|
||||
} else {
|
||||
@ -58,7 +58,7 @@ void AllGather::eval_cpu(
|
||||
assert(outputs.size() == 1);
|
||||
|
||||
auto [in, copied] = ensure_row_contiguous(inputs[0], stream());
|
||||
outputs[0].set_data(allocator::malloc_or_wait(outputs[0].nbytes()));
|
||||
outputs[0].set_data(allocator::malloc(outputs[0].nbytes()));
|
||||
distributed::detail::all_gather(group(), in, outputs[0], stream());
|
||||
if (copied) {
|
||||
auto& enc = cpu::get_command_encoder(stream());
|
||||
@ -87,7 +87,7 @@ void Recv::eval_cpu(
|
||||
assert(inputs.size() == 0);
|
||||
assert(outputs.size() == 1);
|
||||
|
||||
outputs[0].set_data(allocator::malloc_or_wait(outputs[0].nbytes()));
|
||||
outputs[0].set_data(allocator::malloc(outputs[0].nbytes()));
|
||||
distributed::detail::recv(group(), outputs[0], src_, stream());
|
||||
}
|
||||
|
||||
|
@ -55,9 +55,8 @@ void eigh_impl(
|
||||
liwork = iwork;
|
||||
}
|
||||
|
||||
auto work_buf = array::Data{allocator::malloc_or_wait(sizeof(T) * lwork)};
|
||||
auto iwork_buf =
|
||||
array::Data{allocator::malloc_or_wait(sizeof(int) * liwork)};
|
||||
auto work_buf = array::Data{allocator::malloc(sizeof(T) * lwork)};
|
||||
auto iwork_buf = array::Data{allocator::malloc(sizeof(int) * liwork)};
|
||||
for (size_t i = 0; i < size / (N * N); ++i) {
|
||||
syevd<T>(
|
||||
&jobz,
|
||||
@ -98,7 +97,7 @@ void Eigh::eval_cpu(
|
||||
? outputs[1]
|
||||
: array(a.shape(), a.dtype(), nullptr, {});
|
||||
|
||||
values.set_data(allocator::malloc_or_wait(values.nbytes()));
|
||||
values.set_data(allocator::malloc(values.nbytes()));
|
||||
|
||||
copy(
|
||||
a,
|
||||
|
@ -22,7 +22,7 @@ void FFT::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
s *= out.itemsize();
|
||||
}
|
||||
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
std::vector<size_t> shape;
|
||||
if (out.dtype() == float32) {
|
||||
|
@ -197,7 +197,7 @@ void dispatch_gather(
|
||||
}
|
||||
|
||||
void Gather::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
auto& src = inputs[0];
|
||||
std::vector<array> inds;
|
||||
@ -354,7 +354,7 @@ void dispatch_gather_axis(
|
||||
}
|
||||
|
||||
void GatherAxis::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
auto& src = inputs[0];
|
||||
auto& inds = inputs[1];
|
||||
|
@ -11,7 +11,7 @@ namespace mlx::core {
|
||||
template <typename T>
|
||||
void general_inv(T* inv, int N) {
|
||||
int info;
|
||||
auto ipiv = array::Data{allocator::malloc_or_wait(sizeof(int) * N)};
|
||||
auto ipiv = array::Data{allocator::malloc(sizeof(int) * N)};
|
||||
// Compute LU factorization.
|
||||
getrf<T>(
|
||||
/* m = */ &N,
|
||||
@ -49,7 +49,7 @@ void general_inv(T* inv, int N) {
|
||||
}
|
||||
|
||||
const int lwork = workspace_size;
|
||||
auto scratch = array::Data{allocator::malloc_or_wait(sizeof(T) * lwork)};
|
||||
auto scratch = array::Data{allocator::malloc(sizeof(T) * lwork)};
|
||||
|
||||
// Compute inverse.
|
||||
getri<T>(
|
||||
|
@ -30,8 +30,7 @@ void luf_impl(
|
||||
auto strides = lu.strides();
|
||||
strides[ndim - 1] = M;
|
||||
strides[ndim - 2] = 1;
|
||||
lu.set_data(
|
||||
allocator::malloc_or_wait(lu.nbytes()), lu.nbytes(), strides, flags);
|
||||
lu.set_data(allocator::malloc(lu.nbytes()), lu.nbytes(), strides, flags);
|
||||
copy_inplace(
|
||||
a,
|
||||
lu,
|
||||
@ -44,8 +43,8 @@ void luf_impl(
|
||||
stream);
|
||||
|
||||
auto a_ptr = lu.data<T>();
|
||||
pivots.set_data(allocator::malloc_or_wait(pivots.nbytes()));
|
||||
row_indices.set_data(allocator::malloc_or_wait(row_indices.nbytes()));
|
||||
pivots.set_data(allocator::malloc(pivots.nbytes()));
|
||||
row_indices.set_data(allocator::malloc(row_indices.nbytes()));
|
||||
auto pivots_ptr = pivots.data<uint32_t>();
|
||||
auto row_indices_ptr = row_indices.data<uint32_t>();
|
||||
size_t num_matrices = a.size() / (M * N);
|
||||
|
@ -59,7 +59,7 @@ void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
throw std::runtime_error(
|
||||
"[BlockMaskedMM::eval] Currently only supports float32.");
|
||||
}
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
auto& a_pre = inputs[0];
|
||||
auto& b_pre = inputs[1];
|
||||
@ -318,7 +318,7 @@ void GatherMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
throw std::runtime_error(
|
||||
"[GatherMM::eval] Currently only supports float32.");
|
||||
}
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
auto& a_pre = inputs[0];
|
||||
auto& b_pre = inputs[1];
|
||||
|
@ -115,7 +115,7 @@ void matmul_general(
|
||||
}
|
||||
|
||||
void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
if (inputs[0].shape(-1) == 0) {
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
encoder.set_output_array(out);
|
||||
|
@ -21,7 +21,7 @@ namespace mlx::core {
|
||||
void reshape(const array& in, array& out) {
|
||||
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
|
||||
if (copy_necessary) {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
copy_inplace(in, out, CopyType::General, out.primitive().stream());
|
||||
} else {
|
||||
shared_buffer_reshape(in, out_strides, out);
|
||||
@ -39,7 +39,7 @@ static std::pair<array, bool> compute_dynamic_offset(
|
||||
if (donate) {
|
||||
offset.copy_shared_buffer(indices);
|
||||
} else {
|
||||
offset.set_data(allocator::malloc_or_wait(offset.itemsize()));
|
||||
offset.set_data(allocator::malloc(offset.itemsize()));
|
||||
}
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
@ -124,7 +124,7 @@ void Transpose::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
void Arange::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 0);
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
switch (out.dtype()) {
|
||||
case bool_:
|
||||
throw std::runtime_error("Bool type unsupported for arange.");
|
||||
@ -186,7 +186,7 @@ void Concatenate::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
std::partial_sum(sizes.cbegin(), sizes.cend(), sizes.begin());
|
||||
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
auto strides = out.strides();
|
||||
auto flags = out.flags();
|
||||
@ -276,7 +276,7 @@ void RandomBits::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
size_t elems_per_key = out.size() / num_keys;
|
||||
size_t bytes_per_key = out.itemsize() * elems_per_key;
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
auto kptr = inputs[0].data<uint32_t>();
|
||||
auto cptr = out.data<char>();
|
||||
@ -335,7 +335,7 @@ void DynamicSlice::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
return;
|
||||
}
|
||||
auto& in = inputs[0];
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
auto [in_offset, donated] =
|
||||
compute_dynamic_offset(inputs[1], in.strides(), axes_, stream());
|
||||
copy_inplace(
|
||||
@ -450,7 +450,7 @@ void View::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
} else {
|
||||
auto tmp = array(
|
||||
in.shape(), in.dtype() == bool_ ? uint8 : in.dtype(), nullptr, {});
|
||||
tmp.set_data(allocator::malloc_or_wait(tmp.nbytes()));
|
||||
tmp.set_data(allocator::malloc(tmp.nbytes()));
|
||||
if (in.dtype() == bool_) {
|
||||
auto in_tmp = array(in.shape(), uint8, nullptr, {});
|
||||
in_tmp.copy_shared_buffer(in);
|
||||
|
@ -25,12 +25,11 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) {
|
||||
auto strides = in.strides();
|
||||
strides[in.ndim() - 2] = 1;
|
||||
strides[in.ndim() - 1] = M;
|
||||
in.set_data(
|
||||
allocator::malloc_or_wait(in.nbytes()), in.nbytes(), strides, flags);
|
||||
in.set_data(allocator::malloc(in.nbytes()), in.nbytes(), strides, flags);
|
||||
copy_inplace(a, in, CopyType::GeneralGeneral, stream);
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
q.set_data(allocator::malloc_or_wait(q.nbytes()));
|
||||
r.set_data(allocator::malloc_or_wait(r.nbytes()));
|
||||
q.set_data(allocator::malloc(q.nbytes()));
|
||||
r.set_data(allocator::malloc(r.nbytes()));
|
||||
|
||||
auto in_ptr = in.data<T>();
|
||||
auto r_ptr = r.data<T>();
|
||||
@ -41,8 +40,7 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) {
|
||||
encoder.set_output_array(r);
|
||||
encoder.dispatch([in_ptr, q_ptr, r_ptr, M, N, lda, num_matrices]() {
|
||||
int num_reflectors = std::min(M, N);
|
||||
auto tau =
|
||||
allocator::malloc_or_wait(sizeof(T) * num_matrices * num_reflectors);
|
||||
auto tau = allocator::malloc(sizeof(T) * num_matrices * num_reflectors);
|
||||
|
||||
T optimal_work;
|
||||
int lwork = -1;
|
||||
@ -53,7 +51,7 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) {
|
||||
|
||||
// Update workspace size
|
||||
lwork = optimal_work;
|
||||
auto work = allocator::malloc_or_wait(sizeof(T) * lwork);
|
||||
auto work = allocator::malloc(sizeof(T) * lwork);
|
||||
|
||||
// Loop over matrices
|
||||
for (int i = 0; i < num_matrices; ++i) {
|
||||
@ -96,7 +94,7 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) {
|
||||
&lwork,
|
||||
&info);
|
||||
lwork = optimal_work;
|
||||
work = allocator::malloc_or_wait(sizeof(T) * lwork);
|
||||
work = allocator::malloc(sizeof(T) * lwork);
|
||||
|
||||
// Loop over matrices
|
||||
for (int i = 0; i < num_matrices; ++i) {
|
||||
|
@ -515,7 +515,7 @@ void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
auto scales = ensure_row_contiguous(scales_pre);
|
||||
auto biases = ensure_row_contiguous(biases_pre);
|
||||
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
encoder.add_temporaries(std::move(temps));
|
||||
@ -565,7 +565,7 @@ void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
auto scales = ensure_row_contiguous_last_dims(scales_pre);
|
||||
auto biases = ensure_row_contiguous_last_dims(biases_pre);
|
||||
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
encoder.add_temporaries(std::move(temps));
|
||||
@ -691,12 +691,12 @@ void fast::AffineQuantize::eval_cpu(
|
||||
|
||||
auto [w, copied] = ensure_row_contiguous(inputs[0]);
|
||||
auto& out = outputs[0];
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
auto& scales = outputs[1];
|
||||
auto& biases = outputs[2];
|
||||
scales.set_data(allocator::malloc_or_wait(scales.nbytes()));
|
||||
biases.set_data(allocator::malloc_or_wait(biases.nbytes()));
|
||||
scales.set_data(allocator::malloc(scales.nbytes()));
|
||||
biases.set_data(allocator::malloc(biases.nbytes()));
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
if (copied) {
|
||||
encoder.add_temporary(w);
|
||||
|
@ -433,7 +433,7 @@ void reduce_dispatch_min_max(
|
||||
void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
|
@ -244,7 +244,7 @@ void Scan::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
in = arr_copy;
|
||||
encoder.add_temporary(arr_copy);
|
||||
}
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
|
@ -129,7 +129,7 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
out.copy_shared_buffer(x);
|
||||
} else {
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(x.data_size() * x.itemsize()),
|
||||
allocator::malloc(x.data_size() * x.itemsize()),
|
||||
x.data_size(),
|
||||
x.strides(),
|
||||
x.flags());
|
||||
|
@ -288,7 +288,7 @@ void ArgSort::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& in = inputs[0];
|
||||
|
||||
// Allocate output
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
encoder.set_input_array(in);
|
||||
@ -379,7 +379,7 @@ void ArgPartition::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& in = inputs[0];
|
||||
|
||||
// Allocate output
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
encoder.set_input_array(in);
|
||||
|
@ -50,9 +50,9 @@ void svd_impl(
|
||||
array& s = outputs[1];
|
||||
array& vt = outputs[2];
|
||||
|
||||
u.set_data(allocator::malloc_or_wait(u.nbytes()));
|
||||
s.set_data(allocator::malloc_or_wait(s.nbytes()));
|
||||
vt.set_data(allocator::malloc_or_wait(vt.nbytes()));
|
||||
u.set_data(allocator::malloc(u.nbytes()));
|
||||
s.set_data(allocator::malloc(s.nbytes()));
|
||||
vt.set_data(allocator::malloc(vt.nbytes()));
|
||||
|
||||
encoder.set_output_array(u);
|
||||
encoder.set_output_array(s);
|
||||
@ -64,7 +64,7 @@ void svd_impl(
|
||||
} else {
|
||||
array& s = outputs[0];
|
||||
|
||||
s.set_data(allocator::malloc_or_wait(s.nbytes()));
|
||||
s.set_data(allocator::malloc(s.nbytes()));
|
||||
|
||||
encoder.set_output_array(s);
|
||||
|
||||
@ -91,7 +91,7 @@ void svd_impl(
|
||||
|
||||
// Will contain the indices of eigenvectors that failed to converge (not
|
||||
// used here but required by lapack).
|
||||
auto iwork = array::Data{allocator::malloc_or_wait(sizeof(int) * 12 * K)};
|
||||
auto iwork = array::Data{allocator::malloc(sizeof(int) * 12 * K)};
|
||||
|
||||
static const int lwork_query = -1;
|
||||
|
||||
@ -132,7 +132,7 @@ void svd_impl(
|
||||
}
|
||||
|
||||
const int lwork = workspace_dimension;
|
||||
auto scratch = array::Data{allocator::malloc_or_wait(sizeof(T) * lwork)};
|
||||
auto scratch = array::Data{allocator::malloc(sizeof(T) * lwork)};
|
||||
|
||||
// Loop over matrices.
|
||||
for (int i = 0; i < num_matrices; i++) {
|
||||
|
@ -18,13 +18,13 @@ void set_unary_output_data(const array& in, array& out) {
|
||||
} else {
|
||||
auto size = in.data_size();
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(size * out.itemsize()),
|
||||
allocator::malloc(size * out.itemsize()),
|
||||
size,
|
||||
in.strides(),
|
||||
in.flags());
|
||||
}
|
||||
} else {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -192,16 +192,19 @@ size_t MetalAllocator::set_cache_limit(size_t limit) {
|
||||
return limit;
|
||||
};
|
||||
|
||||
size_t MetalAllocator::set_memory_limit(size_t limit, bool relaxed) {
|
||||
size_t MetalAllocator::set_memory_limit(size_t limit) {
|
||||
std::unique_lock lk(mutex_);
|
||||
std::swap(limit, block_limit_);
|
||||
relaxed_ = relaxed;
|
||||
gc_limit_ = std::min(
|
||||
block_limit_,
|
||||
static_cast<size_t>(0.95 * device_->recommendedMaxWorkingSetSize()));
|
||||
return limit;
|
||||
};
|
||||
|
||||
size_t MetalAllocator::get_memory_limit() {
|
||||
return block_limit_;
|
||||
}
|
||||
|
||||
size_t MetalAllocator::set_wired_limit(size_t limit) {
|
||||
std::unique_lock lk(mutex_);
|
||||
std::swap(limit, wired_limit_);
|
||||
@ -209,7 +212,7 @@ size_t MetalAllocator::set_wired_limit(size_t limit) {
|
||||
return limit;
|
||||
};
|
||||
|
||||
Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
|
||||
Buffer MetalAllocator::malloc(size_t size) {
|
||||
// Metal doesn't like empty buffers
|
||||
if (size == 0) {
|
||||
return Buffer{nullptr};
|
||||
@ -236,11 +239,6 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
|
||||
if (!buf) {
|
||||
size_t mem_required = get_active_memory() + get_cache_memory() + size;
|
||||
|
||||
// If there is too much memory pressure, fail (likely causes a wait).
|
||||
if (!(allow_swap && relaxed_) && mem_required >= block_limit_) {
|
||||
return Buffer{nullptr};
|
||||
}
|
||||
|
||||
auto pool = metal::new_scoped_memory_pool();
|
||||
|
||||
// If we have a lot of memory pressure or are over the maximum cache size,
|
||||
@ -328,8 +326,11 @@ MetalAllocator& allocator() {
|
||||
size_t set_cache_limit(size_t limit) {
|
||||
return allocator().set_cache_limit(limit);
|
||||
}
|
||||
size_t set_memory_limit(size_t limit, bool relaxed /* = true */) {
|
||||
return allocator().set_memory_limit(limit, relaxed);
|
||||
size_t set_memory_limit(size_t limit) {
|
||||
return allocator().set_memory_limit(limit);
|
||||
}
|
||||
size_t get_memory_limit() {
|
||||
return allocator().get_memory_limit();
|
||||
}
|
||||
size_t set_wired_limit(size_t limit) {
|
||||
if (limit >
|
||||
|
@ -56,7 +56,7 @@ class BufferCache {
|
||||
class MetalAllocator : public allocator::Allocator {
|
||||
/** Allocator for Metal GPUs. */
|
||||
public:
|
||||
virtual Buffer malloc(size_t size, bool allow_swap = false) override;
|
||||
virtual Buffer malloc(size_t size) override;
|
||||
virtual void free(Buffer buffer) override;
|
||||
virtual size_t size(Buffer buffer) const override;
|
||||
size_t get_active_memory() {
|
||||
@ -73,7 +73,8 @@ class MetalAllocator : public allocator::Allocator {
|
||||
return buffer_cache_.cache_size();
|
||||
};
|
||||
size_t set_cache_limit(size_t limit);
|
||||
size_t set_memory_limit(size_t limit, bool relaxed);
|
||||
size_t set_memory_limit(size_t limit);
|
||||
size_t get_memory_limit();
|
||||
size_t set_wired_limit(size_t limit);
|
||||
void clear_cache();
|
||||
|
||||
|
@ -37,7 +37,7 @@ void explicit_gemm_conv_ND_gpu(
|
||||
Shape unfolded_shape{implicit_M, implicit_K};
|
||||
array in_unfolded(unfolded_shape, in.dtype(), nullptr, {});
|
||||
|
||||
in_unfolded.set_data(allocator::malloc_or_wait(in_unfolded.nbytes()));
|
||||
in_unfolded.set_data(allocator::malloc(in_unfolded.nbytes()));
|
||||
|
||||
// Prepare unfolding kernel
|
||||
std::ostringstream kname;
|
||||
@ -115,7 +115,7 @@ void explicit_gemm_conv_group_ND_gpu(
|
||||
// Prepare unfolding array
|
||||
Shape unfolded_shape{implicit_M, implicit_K * groups};
|
||||
array in_unfolded(unfolded_shape, in.dtype(), nullptr, {});
|
||||
in_unfolded.set_data(allocator::malloc_or_wait(in_unfolded.nbytes()));
|
||||
in_unfolded.set_data(allocator::malloc(in_unfolded.nbytes()));
|
||||
|
||||
// Prepare unfolding kernel
|
||||
std::ostringstream kname;
|
||||
@ -613,7 +613,7 @@ void winograd_conv_2D_gpu(
|
||||
// Do filter transform
|
||||
Shape filt_wg_shape = {8 * 8, conv_params.C, conv_params.O};
|
||||
array filt_wg(std::move(filt_wg_shape), wt.dtype(), nullptr, {});
|
||||
filt_wg.set_data(allocator::malloc_or_wait(filt_wg.nbytes()));
|
||||
filt_wg.set_data(allocator::malloc(filt_wg.nbytes()));
|
||||
copies_w.push_back(filt_wg);
|
||||
{
|
||||
int bc = 32;
|
||||
@ -640,7 +640,7 @@ void winograd_conv_2D_gpu(
|
||||
// Do input transform
|
||||
Shape inp_wg_shape = {8 * 8, N_tiles, conv_params.C};
|
||||
array inp_wg(std::move(inp_wg_shape), in.dtype(), nullptr, {});
|
||||
inp_wg.set_data(allocator::malloc_or_wait(inp_wg.nbytes()));
|
||||
inp_wg.set_data(allocator::malloc(inp_wg.nbytes()));
|
||||
copies_w.push_back(inp_wg);
|
||||
{
|
||||
int bc = 32;
|
||||
@ -667,7 +667,7 @@ void winograd_conv_2D_gpu(
|
||||
// Do batched gemm
|
||||
Shape out_wg_shape = {8 * 8, N_tiles, conv_params.O};
|
||||
array out_wg(std::move(out_wg_shape), in.dtype(), nullptr, {});
|
||||
out_wg.set_data(allocator::malloc_or_wait(out_wg.nbytes()));
|
||||
out_wg.set_data(allocator::malloc(out_wg.nbytes()));
|
||||
copies_w.push_back(out_wg);
|
||||
{
|
||||
std::vector<array> empty_copies;
|
||||
@ -855,7 +855,7 @@ void conv_3D_gpu(
|
||||
} // namespace
|
||||
|
||||
void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
|
@ -202,7 +202,7 @@ void fill_gpu(const array& val, array& out, const Stream& s) {
|
||||
if (out.size() == 0) {
|
||||
return;
|
||||
}
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
bool large = out.data_size() > UINT32_MAX;
|
||||
auto& d = metal::device(s.device);
|
||||
std::string kernel_name = std::string(large ? "s2" : "s") + "_copy" +
|
||||
|
@ -19,7 +19,7 @@ void CustomKernel::eval_gpu(
|
||||
copies.emplace_back(init_value_.value(), out.dtype());
|
||||
fill_gpu(copies.back(), out, s);
|
||||
} else {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -20,7 +20,7 @@ struct FenceImpl {
|
||||
auto p = metal::new_scoped_memory_pool();
|
||||
fence = static_cast<void*>(d->newSharedEvent());
|
||||
} else {
|
||||
auto buf = allocator::malloc_or_wait(sizeof(uint32_t)).ptr();
|
||||
auto buf = allocator::malloc(sizeof(uint32_t)).ptr();
|
||||
fence = static_cast<void*>(buf);
|
||||
cpu_value()[0] = 0;
|
||||
}
|
||||
|
@ -281,7 +281,7 @@ std::tuple<array, array, array> compute_raders_constants(
|
||||
}
|
||||
|
||||
array b_q_fft({rader_n - 1}, complex64, nullptr, {});
|
||||
b_q_fft.set_data(allocator::malloc_or_wait(b_q_fft.nbytes()));
|
||||
b_q_fft.set_data(allocator::malloc(b_q_fft.nbytes()));
|
||||
auto b_q_fft_ptr =
|
||||
reinterpret_cast<std::complex<float>*>(b_q_fft.data<complex64_t>());
|
||||
std::ptrdiff_t item_size = b_q_fft.itemsize();
|
||||
@ -327,11 +327,11 @@ std::pair<array, array> compute_bluestein_constants(int n, int bluestein_n) {
|
||||
}
|
||||
|
||||
array w_k({n}, complex64, nullptr, {});
|
||||
w_k.set_data(allocator::malloc_or_wait(w_k.nbytes()));
|
||||
w_k.set_data(allocator::malloc(w_k.nbytes()));
|
||||
std::copy(w_k_vec.begin(), w_k_vec.end(), w_k.data<complex64_t>());
|
||||
|
||||
array w_q({bluestein_n}, complex64, nullptr, {});
|
||||
w_q.set_data(allocator::malloc_or_wait(w_q.nbytes()));
|
||||
w_q.set_data(allocator::malloc(w_q.nbytes()));
|
||||
auto w_q_ptr =
|
||||
reinterpret_cast<std::complex<float>*>(w_q.data<complex64_t>());
|
||||
|
||||
@ -551,8 +551,7 @@ void fft_op(
|
||||
flags.row_contiguous = is_row_contiguous;
|
||||
flags.contiguous = data_size == x_copy.size();
|
||||
|
||||
x_copy.set_data(
|
||||
allocator::malloc_or_wait(x.nbytes()), data_size, strides, flags);
|
||||
x_copy.set_data(allocator::malloc(x.nbytes()), data_size, strides, flags);
|
||||
copy_gpu_inplace(x, x_copy, CopyType::GeneralGeneral, s);
|
||||
copies.push_back(x_copy);
|
||||
return x_copy;
|
||||
@ -583,7 +582,7 @@ void fft_op(
|
||||
// TODO: allow donation here
|
||||
if (!inplace) {
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(out.nbytes()),
|
||||
allocator::malloc(out.nbytes()),
|
||||
out_data_size,
|
||||
out_strides,
|
||||
in_contiguous.flags());
|
||||
|
@ -84,7 +84,7 @@ void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
if (in_contiguous.is_donatable()) {
|
||||
out.copy_shared_buffer(in_contiguous);
|
||||
} else {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
}
|
||||
|
||||
int n, m;
|
||||
@ -161,7 +161,7 @@ void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
// Upload 2:
|
||||
// y = h12 @ tmp
|
||||
array temp(in.shape(), in.dtype(), nullptr, {});
|
||||
temp.set_data(allocator::malloc_or_wait(temp.nbytes()));
|
||||
temp.set_data(allocator::malloc(temp.nbytes()));
|
||||
copies.push_back(temp);
|
||||
|
||||
launch_hadamard(in_contiguous, temp, "n" + kernel_name, 1.0);
|
||||
|
@ -43,7 +43,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
if (out.size() == 0) {
|
||||
return;
|
||||
}
|
||||
@ -393,7 +393,7 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& src = inputs[0];
|
||||
auto& idx = inputs[1];
|
||||
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
if (out.size() == 0) {
|
||||
return;
|
||||
}
|
||||
|
@ -382,7 +382,7 @@ void steel_matmul(
|
||||
int split_k_partition_size = gemm_k_iterations * bk;
|
||||
|
||||
array C_split({split_k_partitions, M, N}, float32, nullptr, {});
|
||||
C_split.set_data(allocator::malloc_or_wait(C_split.nbytes()));
|
||||
C_split.set_data(allocator::malloc(C_split.nbytes()));
|
||||
copies.push_back(C_split);
|
||||
|
||||
bool mn_aligned = M % bm == 0 && N % bn == 0;
|
||||
@ -513,7 +513,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
return;
|
||||
}
|
||||
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Init checks and prep
|
||||
@ -677,7 +677,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
throw std::runtime_error(
|
||||
"[matmul] Does not yet support non-floating point types.");
|
||||
}
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
@ -860,7 +860,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
int split_k_partition_size = gemm_k_iterations * bk;
|
||||
|
||||
array C_split({split_k_partitions, M, N}, float32, nullptr, {});
|
||||
C_split.set_data(allocator::malloc_or_wait(C_split.nbytes()));
|
||||
C_split.set_data(allocator::malloc(C_split.nbytes()));
|
||||
copies.push_back(C_split);
|
||||
|
||||
bool mn_aligned = M % bm == 0 && N % bn == 0;
|
||||
@ -1096,7 +1096,7 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
return;
|
||||
}
|
||||
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Init checks and prep
|
||||
@ -1484,7 +1484,7 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
return;
|
||||
}
|
||||
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Init checks and prep
|
||||
|
@ -38,17 +38,20 @@ void reset_peak_memory();
|
||||
size_t get_cache_memory();
|
||||
|
||||
/* Set the memory limit.
|
||||
* Calls to malloc will wait on scheduled tasks if the limit is exceeded. If
|
||||
* there are no more scheduled tasks an error will be raised if relaxed
|
||||
* is false or memory will be allocated (including the potential for
|
||||
* swap) if relaxed is true.
|
||||
* The memory limit is a guideline for the maximum amount of memory to use
|
||||
* during graph evaluation. If the memory limit is exceeded and there is no
|
||||
* more RAM (including swap when available) allocations will result in an
|
||||
* exception.
|
||||
*
|
||||
* The memory limit defaults to 1.5 times the maximum recommended working set
|
||||
* size reported by the device.
|
||||
* When metal is available the memory limit defaults to 1.5 times the maximum
|
||||
* recommended working set size reported by the device.
|
||||
*
|
||||
* Returns the previous memory limit.
|
||||
* */
|
||||
size_t set_memory_limit(size_t limit, bool relaxed = true);
|
||||
size_t set_memory_limit(size_t limit);
|
||||
|
||||
/* Get the current memory limit. */
|
||||
size_t get_memory_limit();
|
||||
|
||||
/* Set the free cache limit.
|
||||
* If using more than the given limit, free memory will be reclaimed
|
||||
|
@ -29,7 +29,7 @@ void RMSNorm::eval_gpu(
|
||||
out.copy_shared_buffer(x);
|
||||
} else {
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(x.data_size() * x.itemsize()),
|
||||
allocator::malloc(x.data_size() * x.itemsize()),
|
||||
x.data_size(),
|
||||
x.strides(),
|
||||
x.flags());
|
||||
@ -129,7 +129,7 @@ void RMSNormVJP::eval_gpu(
|
||||
gx.copy_shared_buffer(g);
|
||||
g_in_gx = true;
|
||||
} else {
|
||||
gx.set_data(allocator::malloc_or_wait(gx.nbytes()));
|
||||
gx.set_data(allocator::malloc(gx.nbytes()));
|
||||
}
|
||||
if (g_copied && !g_in_gx) {
|
||||
d.add_temporary(g, s.index);
|
||||
@ -146,11 +146,11 @@ void RMSNormVJP::eval_gpu(
|
||||
if (!g_in_gx && donate_g) {
|
||||
gw_temp.copy_shared_buffer(g);
|
||||
} else {
|
||||
gw_temp.set_data(allocator::malloc_or_wait(gw_temp.nbytes()));
|
||||
gw_temp.set_data(allocator::malloc(gw_temp.nbytes()));
|
||||
d.add_temporary(gw_temp, s.index);
|
||||
}
|
||||
}
|
||||
gw.set_data(allocator::malloc_or_wait(gw.nbytes()));
|
||||
gw.set_data(allocator::malloc(gw.nbytes()));
|
||||
|
||||
const int simd_size = 32;
|
||||
const int n_reads = RMS_N_READS;
|
||||
@ -226,7 +226,7 @@ void LayerNorm::eval_gpu(
|
||||
out.copy_shared_buffer(x);
|
||||
} else {
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(x.data_size() * x.itemsize()),
|
||||
allocator::malloc(x.data_size() * x.itemsize()),
|
||||
x.data_size(),
|
||||
x.strides(),
|
||||
x.flags());
|
||||
@ -331,7 +331,7 @@ void LayerNormVJP::eval_gpu(
|
||||
gx.copy_shared_buffer(g);
|
||||
g_in_gx = true;
|
||||
} else {
|
||||
gx.set_data(allocator::malloc_or_wait(gx.nbytes()));
|
||||
gx.set_data(allocator::malloc(gx.nbytes()));
|
||||
}
|
||||
if (g_copied && !g_in_gx) {
|
||||
d.add_temporary(g, s.index);
|
||||
@ -348,12 +348,12 @@ void LayerNormVJP::eval_gpu(
|
||||
if (!g_in_gx && donate_g) {
|
||||
gw_temp.copy_shared_buffer(g);
|
||||
} else {
|
||||
gw_temp.set_data(allocator::malloc_or_wait(gw_temp.nbytes()));
|
||||
gw_temp.set_data(allocator::malloc(gw_temp.nbytes()));
|
||||
d.add_temporary(gw_temp, s.index);
|
||||
}
|
||||
}
|
||||
gw.set_data(allocator::malloc_or_wait(gw.nbytes()));
|
||||
gb.set_data(allocator::malloc_or_wait(gb.nbytes()));
|
||||
gw.set_data(allocator::malloc(gw.nbytes()));
|
||||
gb.set_data(allocator::malloc(gb.nbytes()));
|
||||
|
||||
// Finish with the gradient for b in case we had a b
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
|
@ -28,7 +28,7 @@ void arange_set_scalars(T start, T next, metal::CommandEncoder& enc) {
|
||||
void reshape(const array& in, array& out, Stream s) {
|
||||
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
|
||||
if (copy_necessary) {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
copy_gpu_inplace(
|
||||
in,
|
||||
out,
|
||||
@ -58,7 +58,7 @@ static array compute_dynamic_offset(
|
||||
if (donate) {
|
||||
offset.copy_shared_buffer(indices);
|
||||
} else {
|
||||
offset.set_data(allocator::malloc_or_wait(offset.itemsize()));
|
||||
offset.set_data(allocator::malloc(offset.itemsize()));
|
||||
}
|
||||
d.add_temporary(offset, s.index);
|
||||
|
||||
@ -100,7 +100,7 @@ static array compute_dynamic_offset(
|
||||
|
||||
void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 0);
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
if (out.size() == 0) {
|
||||
return;
|
||||
}
|
||||
@ -161,7 +161,7 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
std::string op_name;
|
||||
@ -333,7 +333,7 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
size_t elems_per_key = out.size() / num_keys;
|
||||
size_t bytes_per_key = out.itemsize() * elems_per_key;
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
if (out.size() == 0) {
|
||||
return;
|
||||
}
|
||||
@ -397,7 +397,7 @@ void DynamicSlice::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
auto& in = inputs[0];
|
||||
auto& start = inputs[1];
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
auto s = stream();
|
||||
auto in_offset = compute_dynamic_offset(start, in.strides(), axes_, s);
|
||||
copy_gpu_inplace(
|
||||
@ -554,7 +554,7 @@ void View::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
in, strides, in.flags(), in.data_size() * ibytes / obytes);
|
||||
} else {
|
||||
auto tmp = array(in.shape(), in.dtype(), nullptr, {});
|
||||
tmp.set_data(allocator::malloc_or_wait(tmp.nbytes()));
|
||||
tmp.set_data(allocator::malloc(tmp.nbytes()));
|
||||
copy_gpu_inplace(in, tmp, CopyType::General, stream());
|
||||
|
||||
auto flags = out.flags();
|
||||
|
@ -224,7 +224,7 @@ void qvm_split_k(
|
||||
auto temp_shape = out.shape();
|
||||
temp_shape.insert(temp_shape.end() - 2, split_k);
|
||||
array intermediate(temp_shape, x.dtype(), nullptr, {});
|
||||
intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
|
||||
intermediate.set_data(allocator::malloc(intermediate.nbytes()));
|
||||
d.add_temporary(intermediate, s.index);
|
||||
|
||||
std::ostringstream kname;
|
||||
@ -277,7 +277,7 @@ void qmm_op(
|
||||
int bits,
|
||||
bool gather,
|
||||
const Stream& s) {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
MTL::Size group_dims;
|
||||
MTL::Size grid_dims;
|
||||
@ -394,7 +394,7 @@ void fast::AffineQuantize::eval_gpu(
|
||||
std::vector<array>& outputs) {
|
||||
auto& w_pre = inputs[0];
|
||||
auto& out = outputs[0];
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
@ -425,8 +425,8 @@ void fast::AffineQuantize::eval_gpu(
|
||||
} else {
|
||||
auto& scales = outputs[1];
|
||||
auto& biases = outputs[2];
|
||||
scales.set_data(allocator::malloc_or_wait(scales.nbytes()));
|
||||
biases.set_data(allocator::malloc_or_wait(biases.nbytes()));
|
||||
scales.set_data(allocator::malloc(scales.nbytes()));
|
||||
biases.set_data(allocator::malloc(biases.nbytes()));
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
compute_encoder.set_output_array(scales, 2);
|
||||
compute_encoder.set_output_array(biases, 3);
|
||||
|
@ -347,7 +347,7 @@ void all_reduce_dispatch(
|
||||
|
||||
// Allocate an intermediate tensor to hold results if needed
|
||||
array intermediate({n_rows}, out_type, nullptr, {});
|
||||
intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
|
||||
intermediate.set_data(allocator::malloc(intermediate.nbytes()));
|
||||
d.add_temporary(intermediate, s.index);
|
||||
|
||||
// 1st pass
|
||||
@ -641,7 +641,7 @@ void strided_reduce_longcolumn(
|
||||
intermediate_shape.insert(
|
||||
intermediate_shape.end(), out.shape().begin(), out.shape().end());
|
||||
array intermediate(std::move(intermediate_shape), out_type, nullptr, {});
|
||||
intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
|
||||
intermediate.set_data(allocator::malloc(intermediate.nbytes()));
|
||||
d.add_temporary(intermediate, s.index);
|
||||
|
||||
// Prepare the arguments for the kernel
|
||||
@ -812,7 +812,7 @@ void strided_reduce_2pass(
|
||||
intermediate_shape.insert(
|
||||
intermediate_shape.end(), out.shape().begin(), out.shape().end());
|
||||
array intermediate(std::move(intermediate_shape), out_type, nullptr, {});
|
||||
intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
|
||||
intermediate.set_data(allocator::malloc(intermediate.nbytes()));
|
||||
d.add_temporary(intermediate, s.index);
|
||||
|
||||
// Prepare the arguments for the kernel
|
||||
@ -950,7 +950,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
// Minimum of 4 bytes since we use size 4 structs for all reduce
|
||||
// and metal will complain o/w
|
||||
size_t min_bytes = std::max(out.nbytes(), 4ul);
|
||||
out.set_data(allocator::malloc_or_wait(min_bytes));
|
||||
out.set_data(allocator::malloc(min_bytes));
|
||||
std::string op_name;
|
||||
switch (reduce_type_) {
|
||||
case Reduce::And:
|
||||
|
@ -43,14 +43,14 @@ void RoPE::eval_gpu(
|
||||
donated = true;
|
||||
out.copy_shared_buffer(in);
|
||||
} else {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
}
|
||||
strides[0] = mat_size;
|
||||
strides[1] = in.strides()[ndim - 2];
|
||||
strides[2] = in.strides()[ndim - 1];
|
||||
} else if (dispatch_ndim == 3) {
|
||||
// Handle non-contiguous 3D inputs
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
strides[0] = in.strides()[ndim - 3];
|
||||
strides[1] = in.strides()[ndim - 2];
|
||||
strides[2] = in.strides()[ndim - 1];
|
||||
|
@ -248,9 +248,9 @@ void sdpa_vector_2pass(
|
||||
intermediate_shape.pop_back();
|
||||
array sums(intermediate_shape, float32, nullptr, {});
|
||||
array maxs(std::move(intermediate_shape), float32, nullptr, {});
|
||||
intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
|
||||
sums.set_data(allocator::malloc_or_wait(sums.nbytes()));
|
||||
maxs.set_data(allocator::malloc_or_wait(maxs.nbytes()));
|
||||
intermediate.set_data(allocator::malloc(intermediate.nbytes()));
|
||||
sums.set_data(allocator::malloc(sums.nbytes()));
|
||||
maxs.set_data(allocator::malloc(maxs.nbytes()));
|
||||
d.add_temporary(intermediate, s.index);
|
||||
d.add_temporary(sums, s.index);
|
||||
d.add_temporary(maxs, s.index);
|
||||
@ -383,7 +383,7 @@ void ScaledDotProductAttention::eval_gpu(
|
||||
o.copy_shared_buffer(q);
|
||||
} else {
|
||||
if (o.shape(2) == 1) {
|
||||
o.set_data(allocator::malloc_or_wait(o.nbytes()));
|
||||
o.set_data(allocator::malloc(o.nbytes()));
|
||||
} else {
|
||||
auto strides = o.strides();
|
||||
strides[2] = o.shape(1) * o.shape(3);
|
||||
@ -391,10 +391,7 @@ void ScaledDotProductAttention::eval_gpu(
|
||||
auto flags = q.flags();
|
||||
flags.row_contiguous = q.shape(1) == 1;
|
||||
o.set_data(
|
||||
allocator::malloc_or_wait(o.nbytes()),
|
||||
o.size(),
|
||||
std::move(strides),
|
||||
flags);
|
||||
allocator::malloc(o.nbytes()), o.size(), std::move(strides), flags);
|
||||
}
|
||||
}
|
||||
|
||||
@ -432,7 +429,7 @@ void ScaledDotProductAttention::eval_gpu(
|
||||
};
|
||||
|
||||
o.set_data(
|
||||
allocator::malloc_or_wait(o.nbytes()),
|
||||
allocator::malloc(o.nbytes()),
|
||||
data_size,
|
||||
{str_oB, str_oH, str_oL, str_oD},
|
||||
flags);
|
||||
|
@ -24,7 +24,7 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
out.copy_shared_buffer(in);
|
||||
} else {
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(in.data_size() * out.itemsize()),
|
||||
allocator::malloc(in.data_size() * out.itemsize()),
|
||||
in.data_size(),
|
||||
in.strides(),
|
||||
in.flags());
|
||||
|
@ -29,7 +29,7 @@ void concatenate_gpu(
|
||||
}
|
||||
std::partial_sum(sizes.cbegin(), sizes.cend(), sizes.begin());
|
||||
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
auto strides = out.strides();
|
||||
auto flags = out.flags();
|
||||
|
@ -33,7 +33,7 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
out.copy_shared_buffer(x);
|
||||
} else {
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(x.data_size() * x.itemsize()),
|
||||
allocator::malloc(x.data_size() * x.itemsize()),
|
||||
x.data_size(),
|
||||
x.strides(),
|
||||
x.flags());
|
||||
|
@ -150,12 +150,11 @@ void multi_block_sort(
|
||||
array block_partitions({n_rows, n_blocks + 1}, uint32, nullptr, {});
|
||||
|
||||
// Do allocations
|
||||
dev_vals_0.set_data(allocator::malloc_or_wait(dev_vals_0.nbytes()));
|
||||
dev_vals_1.set_data(allocator::malloc_or_wait(dev_vals_1.nbytes()));
|
||||
dev_idxs_0.set_data(allocator::malloc_or_wait(dev_idxs_0.nbytes()));
|
||||
dev_idxs_1.set_data(allocator::malloc_or_wait(dev_idxs_1.nbytes()));
|
||||
block_partitions.set_data(
|
||||
allocator::malloc_or_wait(block_partitions.nbytes()));
|
||||
dev_vals_0.set_data(allocator::malloc(dev_vals_0.nbytes()));
|
||||
dev_vals_1.set_data(allocator::malloc(dev_vals_1.nbytes()));
|
||||
dev_idxs_0.set_data(allocator::malloc(dev_idxs_0.nbytes()));
|
||||
dev_idxs_1.set_data(allocator::malloc(dev_idxs_1.nbytes()));
|
||||
block_partitions.set_data(allocator::malloc(block_partitions.nbytes()));
|
||||
|
||||
std::vector<array> copies = {
|
||||
dev_vals_0, dev_vals_1, dev_idxs_0, dev_idxs_1, block_partitions};
|
||||
@ -319,7 +318,7 @@ void gpu_merge_sort(
|
||||
void ArgSort::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
@ -331,7 +330,7 @@ void ArgSort::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
void Sort::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
@ -344,7 +343,7 @@ void ArgPartition::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
// We direct arg partition to sort for now
|
||||
assert(inputs.size() == 1);
|
||||
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
@ -357,7 +356,7 @@ void Partition::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
// We direct partition to sort for now
|
||||
assert(inputs.size() == 1);
|
||||
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
|
@ -97,13 +97,13 @@ void unary_op_gpu(
|
||||
out.copy_shared_buffer(in);
|
||||
} else {
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(in.data_size() * out.itemsize()),
|
||||
allocator::malloc(in.data_size() * out.itemsize()),
|
||||
in.data_size(),
|
||||
in.strides(),
|
||||
in.flags());
|
||||
}
|
||||
} else {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
}
|
||||
unary_op_gpu_inplace(inputs, out, op, s);
|
||||
}
|
||||
|
@ -42,7 +42,10 @@ void reset_peak_memory() {}
|
||||
size_t get_cache_memory() {
|
||||
return 0;
|
||||
}
|
||||
size_t set_memory_limit(size_t, bool) {
|
||||
size_t set_memory_limit(size_t) {
|
||||
return 0;
|
||||
}
|
||||
size_t get_memory_limit() {
|
||||
return 0;
|
||||
}
|
||||
size_t set_cache_limit(size_t) {
|
||||
|
@ -218,7 +218,9 @@ array eval_impl(std::vector<array> outputs, bool async) {
|
||||
cpu::eval(arr);
|
||||
}
|
||||
|
||||
if (scheduler::n_active_tasks() > MAX_ACTIVE_TASKS) {
|
||||
if (scheduler::n_active_tasks() > MAX_ACTIVE_TASKS ||
|
||||
(metal::get_active_memory() > metal::get_memory_limit() &&
|
||||
scheduler::n_active_tasks() > 0)) {
|
||||
// Commit any open streams
|
||||
for (auto& [_, e] : events) {
|
||||
if (e.stream().device == Device::gpu) {
|
||||
@ -226,6 +228,11 @@ array eval_impl(std::vector<array> outputs, bool async) {
|
||||
}
|
||||
}
|
||||
scheduler::wait_for_one();
|
||||
// TODO memory api should be moved out of metal
|
||||
while (metal::get_active_memory() > metal::get_memory_limit() &&
|
||||
scheduler::n_active_tasks() > 0) {
|
||||
scheduler::wait_for_one();
|
||||
}
|
||||
}
|
||||
|
||||
auto maybe_update_fence = [&fences, &needs_fence, stream](const array& a) {
|
||||
|
@ -57,23 +57,19 @@ void init_metal(nb::module_& m) {
|
||||
"set_memory_limit",
|
||||
&mx::metal::set_memory_limit,
|
||||
"limit"_a,
|
||||
nb::kw_only(),
|
||||
"relaxed"_a = true,
|
||||
R"pbdoc(
|
||||
Set the memory limit.
|
||||
|
||||
Memory allocations will wait on scheduled tasks to complete if the limit
|
||||
is exceeded. If there are no more scheduled tasks an error will be raised
|
||||
if ``relaxed`` is ``False``. Otherwise memory will be allocated
|
||||
(including the potential for swap) if ``relaxed`` is ``True``.
|
||||
The memory limit is a guideline for the maximum amount of memory to use
|
||||
during graph evaluation. If the memory limit is exceeded and there is no
|
||||
more RAM (including swap when available) allocations will result in an
|
||||
exception.
|
||||
|
||||
The memory limit defaults to 1.5 times the maximum recommended working set
|
||||
size reported by the device.
|
||||
When metal is available the memory limit defaults to 1.5 times the
|
||||
maximum recommended working set size reported by the device.
|
||||
|
||||
Args:
|
||||
limit (int): Memory limit in bytes.
|
||||
relaxed (bool, optional): If `False`` an error is raised if the limit
|
||||
is exceeded. Default: ``True``
|
||||
|
||||
Returns:
|
||||
int: The previous memory limit in bytes.
|
||||
|
@ -185,6 +185,18 @@ class TestEval(mlx_tests.MLXTestCase):
|
||||
x = mx.abs(x, stream=s2)
|
||||
mx.eval(x)
|
||||
|
||||
s1 = mx.default_stream(mx.gpu)
|
||||
s2 = mx.new_stream(mx.gpu)
|
||||
old_limit = mx.metal.set_memory_limit(1000)
|
||||
|
||||
x = mx.ones((512, 512), stream=s2)
|
||||
for _ in range(80):
|
||||
x = mx.abs(x, stream=s1)
|
||||
y = mx.abs(x, stream=s2)
|
||||
z = mx.abs(y, stream=s2)
|
||||
mx.eval(z)
|
||||
mx.metal.set_memory_limit(old_limit)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Loading…
Reference in New Issue
Block a user