mirror of https://github.com/ml-explore/mlx.git
fix malloc or wait deadlock (#1976)
commit 7b7e2352cd (parent 1177d28395)
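In short: `malloc_or_wait` parked callers in `scheduler::wait_for_one()` whenever an allocation came back empty, which is the deadlock named in the title. This commit deletes it, drops the `allow_swap`/`relaxed` plumbing from the allocators, and makes `allocator::malloc` either succeed or throw; a `get_memory_limit()` query is added next to `set_memory_limit()`. Every call site changes in the same mechanical way:

    // Call-site pattern repeated throughout the diff below.
    out.set_data(allocator::malloc_or_wait(out.nbytes())); // before
    out.set_data(allocator::malloc(out.nbytes()));         // after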
@@ -247,9 +247,7 @@ point-wise. This is captured in the templated function :meth:`axpby_impl`.
     float alpha_,
     float beta_,
     mx::Stream stream) {
-  // Allocate the output with `malloc_or_wait` which synchronously allocates
-  // memory, potentially waiting if the system is under memory pressure
-  out.set_data(mx::allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(mx::allocator::malloc(out.nbytes()));
 
   // Get the CPU command encoder and register input and output arrays
   auto& encoder = mx::cpu::get_command_encoder(stream);
@@ -393,7 +391,7 @@ below.
   auto& d = metal::device(s.device);
 
   // Allocate output memory
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
 
   // Resolve name of kernel
   std::ostringstream kname;
@@ -72,9 +72,7 @@ void axpby_impl(
     float alpha_,
     float beta_,
     mx::Stream stream) {
-  // Allocate the output with `malloc_or_wait` which synchronously allocates
-  // memory, potentially waiting if the system is under memory pressure
-  out.set_data(mx::allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(mx::allocator::malloc(out.nbytes()));
 
   // Get the CPU command encoder and register input and output arrays
   auto& encoder = mx::cpu::get_command_encoder(stream);
@@ -160,12 +158,12 @@ void Axpby::eval_gpu(
   // Allocate output memory with strides based on specialization
   if (contiguous_kernel) {
     out.set_data(
-        mx::allocator::malloc_or_wait(x.data_size() * out.itemsize()),
+        mx::allocator::malloc(x.data_size() * out.itemsize()),
         x.data_size(),
         x.strides(),
         x.flags());
   } else {
-    out.set_data(mx::allocator::malloc_or_wait(out.nbytes()));
+    out.set_data(mx::allocator::malloc(out.nbytes()));
   }
 
   // Resolve name of kernel (corresponds to axpby.metal)
@@ -9,7 +9,7 @@
 namespace mlx::core::allocator {
 
 Buffer malloc(size_t size) {
-  auto buffer = allocator().malloc(size, /* allow_swap */ true);
+  auto buffer = allocator().malloc(size);
   if (size && !buffer.ptr()) {
     std::ostringstream msg;
     msg << "[malloc] Unable to allocate " << size << " bytes.";
@@ -22,7 +22,7 @@ void free(Buffer buffer) {
   allocator().free(buffer);
 }
 
-Buffer CommonAllocator::malloc(size_t size, bool) {
+Buffer CommonAllocator::malloc(size_t size) {
   void* ptr = std::malloc(size + sizeof(size_t));
   if (ptr != nullptr) {
     *static_cast<size_t*>(ptr) = size;
@@ -41,26 +41,4 @@ size_t CommonAllocator::size(Buffer buffer) const {
   return *static_cast<size_t*>(buffer.ptr());
 }
 
-Buffer malloc_or_wait(size_t size) {
-  auto buffer = allocator().malloc(size);
-
-  while (size && !buffer.ptr() && scheduler::n_active_tasks() > 0) {
-    scheduler::wait_for_one();
-    buffer = allocator().malloc(size);
-  }
-
-  // Try swapping if needed
-  if (size && !buffer.ptr()) {
-    buffer = allocator().malloc(size, /* allow_swap = */ true);
-  }
-
-  if (size && !buffer.ptr()) {
-    std::ostringstream msg;
-    msg << "[malloc_or_wait] Unable to allocate " << size << " bytes.";
-    throw std::runtime_error(msg.str());
-  }
-
-  return buffer;
-}
-
 } // namespace mlx::core::allocator
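For reference, the top-level allocator entry point after these hunks reads roughly as follows. The throw and the return sit outside the context lines shown above, so those two lines are inferred from the removed `malloc_or_wait` error path rather than quoted:

    Buffer malloc(size_t size) {
      // No allow_swap flag and no waiting: the concrete allocator decides,
      // and a null buffer for a nonzero size is now a hard error.
      auto buffer = allocator().malloc(size);
      if (size && !buffer.ptr()) {
        std::ostringstream msg;
        msg << "[malloc] Unable to allocate " << size << " bytes.";
        throw std::runtime_error(msg.str()); // inferred
      }
      return buffer; // inferred
    }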
@@ -32,14 +32,10 @@ Buffer malloc(size_t size);
 
 void free(Buffer buffer);
 
-// Wait for running tasks to finish and free up memory
-// if allocation fails
-Buffer malloc_or_wait(size_t size);
-
 class Allocator {
   /** Abstract base class for a memory allocator. */
  public:
-  virtual Buffer malloc(size_t size, bool allow_swap = false) = 0;
+  virtual Buffer malloc(size_t size) = 0;
   virtual void free(Buffer buffer) = 0;
   virtual size_t size(Buffer buffer) const = 0;
 
@@ -56,7 +52,7 @@ Allocator& allocator();
 class CommonAllocator : public Allocator {
   /** A general CPU allocator. */
  public:
-  virtual Buffer malloc(size_t size, bool allow_swap = false) override;
+  virtual Buffer malloc(size_t size) override;
   virtual void free(Buffer buffer) override;
   virtual size_t size(Buffer buffer) const override;
 
@@ -44,14 +44,14 @@ inline void set_binary_op_output_data(
   switch (bopt) {
     case BinaryOpType::ScalarScalar:
       out.set_data(
-          allocator::malloc_or_wait(out.itemsize()), 1, a.strides(), a.flags());
+          allocator::malloc(out.itemsize()), 1, a.strides(), a.flags());
       break;
     case BinaryOpType::ScalarVector:
       if (b_donatable) {
         out.copy_shared_buffer(b);
       } else {
         out.set_data(
-            allocator::malloc_or_wait(b.data_size() * out.itemsize()),
+            allocator::malloc(b.data_size() * out.itemsize()),
             b.data_size(),
             b.strides(),
             b.flags());
@@ -62,7 +62,7 @@ inline void set_binary_op_output_data(
         out.copy_shared_buffer(a);
       } else {
         out.set_data(
-            allocator::malloc_or_wait(a.data_size() * out.itemsize()),
+            allocator::malloc(a.data_size() * out.itemsize()),
             a.data_size(),
             a.strides(),
             a.flags());
@@ -75,7 +75,7 @@ inline void set_binary_op_output_data(
         out.copy_shared_buffer(b);
       } else {
         out.set_data(
-            allocator::malloc_or_wait(a.data_size() * out.itemsize()),
+            allocator::malloc(a.data_size() * out.itemsize()),
             a.data_size(),
             a.strides(),
             a.flags());
@@ -88,7 +88,7 @@ inline void set_binary_op_output_data(
           b_donatable && b.flags().row_contiguous && b.size() == out.size()) {
         out.copy_shared_buffer(b);
       } else {
-        out.set_data(allocator::malloc_or_wait(out.nbytes()));
+        out.set_data(allocator::malloc(out.nbytes()));
       }
       break;
   }
@@ -103,7 +103,7 @@ void ExpandDims::eval(const std::vector<array>& inputs, array& out) {
 
 void NumberOfElements::eval(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
 
   double numel = 1;
   for (auto ax : axes_) {
@@ -188,7 +188,7 @@ void compiled_allocate_outputs(
     }
     for (; o < outputs.size(); ++o) {
       outputs[o].set_data(
-          allocator::malloc_or_wait(data_size * outputs[o].itemsize()),
+          allocator::malloc(data_size * outputs[o].itemsize()),
           data_size,
           strides,
           flags);
@@ -211,7 +211,7 @@ void compiled_allocate_outputs(
       }
     }
     for (; o < outputs.size(); ++o) {
-      outputs[o].set_data(allocator::malloc_or_wait(outputs[o].nbytes()));
+      outputs[o].set_data(allocator::malloc(outputs[o].nbytes()));
     }
   }
 }
@@ -31,14 +31,14 @@ inline bool set_copy_output_data(const array& in, array& out, CopyType ctype) {
       return true;
     } else {
       out.set_data(
-          allocator::malloc_or_wait(in.data_size() * out.itemsize()),
+          allocator::malloc(in.data_size() * out.itemsize()),
           in.data_size(),
           in.strides(),
           in.flags());
       return false;
     }
   } else {
-    out.set_data(allocator::malloc_or_wait(out.nbytes()));
+    out.set_data(allocator::malloc(out.nbytes()));
     return false;
   }
 }
@@ -28,7 +28,7 @@ void swap_endianness(uint8_t* data_bytes, size_t N) {
 namespace mlx::core {
 
 void Load::eval_cpu(const std::vector<array>& inputs, array& out) {
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
   auto read_task = [out_ptr = out.data<char>(),
                     size = out.size(),
                     itemsize = out.itemsize(),
@@ -48,12 +48,12 @@ inline void set_ternary_op_output_data(
   switch (topt) {
     case TernaryOpType::ScalarScalarScalar:
       out.set_data(
-          allocator::malloc_or_wait(out.itemsize()), 1, b.strides(), b.flags());
+          allocator::malloc(out.itemsize()), 1, b.strides(), b.flags());
       break;
     case TernaryOpType::VectorVectorVector:
       if (!(maybe_donate(a) || maybe_donate(b) || maybe_donate(c))) {
         out.set_data(
-            allocator::malloc_or_wait(out.itemsize() * b.data_size()),
+            allocator::malloc(out.itemsize() * b.data_size()),
             b.data_size(),
             b.strides(),
             b.flags());
@@ -64,7 +64,7 @@ inline void set_ternary_op_output_data(
       if (!((a.flags().row_contiguous && maybe_donate(a)) ||
             (b.flags().row_contiguous && maybe_donate(b)) ||
             (c.flags().row_contiguous && maybe_donate(c)))) {
-        out.set_data(allocator::malloc_or_wait(out.nbytes()));
+        out.set_data(allocator::malloc(out.nbytes()));
       }
       break;
   }
@@ -68,7 +68,7 @@ void arg_reduce_dispatch(
 void ArgReduce::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   auto& in = inputs[0];
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
   auto& encoder = cpu::get_command_encoder(stream());
   encoder.set_input_array(in);
   encoder.set_output_array(out);
@@ -921,7 +921,7 @@ void explicit_gemm_conv_1D_cpu(
 
   if (out.dtype() != float32) {
     gemm_out = array(out.shape(), float32, nullptr, {});
-    gemm_out.set_data(allocator::malloc_or_wait(gemm_out.nbytes()));
+    gemm_out.set_data(allocator::malloc(gemm_out.nbytes()));
     temps.push_back(gemm_out);
   }
 
@@ -1048,7 +1048,7 @@ void explicit_gemm_conv_2D_cpu(
 
   if (out.dtype() != float32) {
     gemm_out = array(out.shape(), float32, nullptr, {});
-    gemm_out.set_data(allocator::malloc_or_wait(gemm_out.nbytes()));
+    gemm_out.set_data(allocator::malloc(gemm_out.nbytes()));
     temps.push_back(gemm_out);
   }
 
@@ -1214,7 +1214,7 @@ void explicit_gemm_conv_ND_cpu(
 
   if (out.dtype() != float32) {
     gemm_out = array(out.shape(), float32, nullptr, {});
-    gemm_out.set_data(allocator::malloc_or_wait(gemm_out.nbytes()));
+    gemm_out.set_data(allocator::malloc(gemm_out.nbytes()));
     temps.push_back(gemm_out);
   }
 
@@ -1327,7 +1327,7 @@ void conv_3D_cpu(
 } // namespace
 
 void Convolution::eval_cpu(const std::vector<array>& inputs, array& out) {
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
 
   auto& in = inputs[0];
   auto& wt = inputs[1];
@@ -30,7 +30,7 @@ void AllReduce::eval_cpu(
     if (in.is_donatable()) {
       out.copy_shared_buffer(in);
     } else {
-      out.set_data(allocator::malloc_or_wait(out.nbytes()));
+      out.set_data(allocator::malloc(out.nbytes()));
     }
     return in;
   } else {
@@ -58,7 +58,7 @@ void AllGather::eval_cpu(
   assert(outputs.size() == 1);
 
   auto [in, copied] = ensure_row_contiguous(inputs[0], stream());
-  outputs[0].set_data(allocator::malloc_or_wait(outputs[0].nbytes()));
+  outputs[0].set_data(allocator::malloc(outputs[0].nbytes()));
   distributed::detail::all_gather(group(), in, outputs[0], stream());
   if (copied) {
     auto& enc = cpu::get_command_encoder(stream());
@@ -87,7 +87,7 @@ void Recv::eval_cpu(
   assert(inputs.size() == 0);
   assert(outputs.size() == 1);
 
-  outputs[0].set_data(allocator::malloc_or_wait(outputs[0].nbytes()));
+  outputs[0].set_data(allocator::malloc(outputs[0].nbytes()));
   distributed::detail::recv(group(), outputs[0], src_, stream());
 }
 
@@ -55,9 +55,8 @@ void eigh_impl(
     liwork = iwork;
   }
 
-  auto work_buf = array::Data{allocator::malloc_or_wait(sizeof(T) * lwork)};
-  auto iwork_buf =
-      array::Data{allocator::malloc_or_wait(sizeof(int) * liwork)};
+  auto work_buf = array::Data{allocator::malloc(sizeof(T) * lwork)};
+  auto iwork_buf = array::Data{allocator::malloc(sizeof(int) * liwork)};
   for (size_t i = 0; i < size / (N * N); ++i) {
     syevd<T>(
         &jobz,
@@ -98,7 +97,7 @@ void Eigh::eval_cpu(
       ? outputs[1]
      : array(a.shape(), a.dtype(), nullptr, {});
 
-  values.set_data(allocator::malloc_or_wait(values.nbytes()));
+  values.set_data(allocator::malloc(values.nbytes()));
 
   copy(
       a,
@@ -22,7 +22,7 @@ void FFT::eval_cpu(const std::vector<array>& inputs, array& out) {
     s *= out.itemsize();
   }
 
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
 
   std::vector<size_t> shape;
   if (out.dtype() == float32) {
@@ -197,7 +197,7 @@ void dispatch_gather(
 }
 
 void Gather::eval_cpu(const std::vector<array>& inputs, array& out) {
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
 
   auto& src = inputs[0];
   std::vector<array> inds;
@@ -354,7 +354,7 @@ void dispatch_gather_axis(
 }
 
 void GatherAxis::eval_cpu(const std::vector<array>& inputs, array& out) {
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
 
   auto& src = inputs[0];
   auto& inds = inputs[1];
@@ -11,7 +11,7 @@ namespace mlx::core {
 template <typename T>
 void general_inv(T* inv, int N) {
   int info;
-  auto ipiv = array::Data{allocator::malloc_or_wait(sizeof(int) * N)};
+  auto ipiv = array::Data{allocator::malloc(sizeof(int) * N)};
   // Compute LU factorization.
   getrf<T>(
       /* m = */ &N,
@@ -49,7 +49,7 @@ void general_inv(T* inv, int N) {
   }
 
   const int lwork = workspace_size;
-  auto scratch = array::Data{allocator::malloc_or_wait(sizeof(T) * lwork)};
+  auto scratch = array::Data{allocator::malloc(sizeof(T) * lwork)};
 
   // Compute inverse.
   getri<T>(
@@ -30,8 +30,7 @@ void luf_impl(
   auto strides = lu.strides();
   strides[ndim - 1] = M;
   strides[ndim - 2] = 1;
-  lu.set_data(
-      allocator::malloc_or_wait(lu.nbytes()), lu.nbytes(), strides, flags);
+  lu.set_data(allocator::malloc(lu.nbytes()), lu.nbytes(), strides, flags);
   copy_inplace(
       a,
       lu,
@@ -44,8 +43,8 @@ void luf_impl(
       stream);
 
   auto a_ptr = lu.data<T>();
-  pivots.set_data(allocator::malloc_or_wait(pivots.nbytes()));
-  row_indices.set_data(allocator::malloc_or_wait(row_indices.nbytes()));
+  pivots.set_data(allocator::malloc(pivots.nbytes()));
+  row_indices.set_data(allocator::malloc(row_indices.nbytes()));
   auto pivots_ptr = pivots.data<uint32_t>();
   auto row_indices_ptr = row_indices.data<uint32_t>();
   size_t num_matrices = a.size() / (M * N);
@@ -59,7 +59,7 @@ void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
     throw std::runtime_error(
         "[BlockMaskedMM::eval] Currently only supports float32.");
   }
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
 
   auto& a_pre = inputs[0];
   auto& b_pre = inputs[1];
@@ -318,7 +318,7 @@ void GatherMM::eval_cpu(const std::vector<array>& inputs, array& out) {
     throw std::runtime_error(
         "[GatherMM::eval] Currently only supports float32.");
   }
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
 
   auto& a_pre = inputs[0];
   auto& b_pre = inputs[1];
@@ -115,7 +115,7 @@ void matmul_general(
 }
 
 void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
   if (inputs[0].shape(-1) == 0) {
     auto& encoder = cpu::get_command_encoder(stream());
     encoder.set_output_array(out);
@@ -21,7 +21,7 @@ namespace mlx::core {
 void reshape(const array& in, array& out) {
   auto [copy_necessary, out_strides] = prepare_reshape(in, out);
   if (copy_necessary) {
-    out.set_data(allocator::malloc_or_wait(out.nbytes()));
+    out.set_data(allocator::malloc(out.nbytes()));
     copy_inplace(in, out, CopyType::General, out.primitive().stream());
   } else {
     shared_buffer_reshape(in, out_strides, out);
@@ -39,7 +39,7 @@ static std::pair<array, bool> compute_dynamic_offset(
   if (donate) {
     offset.copy_shared_buffer(indices);
   } else {
-    offset.set_data(allocator::malloc_or_wait(offset.itemsize()));
+    offset.set_data(allocator::malloc(offset.itemsize()));
   }
 
   auto& encoder = cpu::get_command_encoder(stream);
@@ -124,7 +124,7 @@ void Transpose::eval_cpu(const std::vector<array>& inputs, array& out) {
 
 void Arange::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 0);
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
   switch (out.dtype()) {
     case bool_:
       throw std::runtime_error("Bool type unsupported for arange.");
@@ -186,7 +186,7 @@ void Concatenate::eval_cpu(const std::vector<array>& inputs, array& out) {
   }
   std::partial_sum(sizes.cbegin(), sizes.cend(), sizes.begin());
 
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
 
   auto strides = out.strides();
   auto flags = out.flags();
@@ -276,7 +276,7 @@ void RandomBits::eval_cpu(const std::vector<array>& inputs, array& out) {
 
   size_t elems_per_key = out.size() / num_keys;
   size_t bytes_per_key = out.itemsize() * elems_per_key;
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
 
   auto kptr = inputs[0].data<uint32_t>();
   auto cptr = out.data<char>();
@@ -335,7 +335,7 @@ void DynamicSlice::eval_cpu(const std::vector<array>& inputs, array& out) {
     return;
   }
   auto& in = inputs[0];
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
   auto [in_offset, donated] =
       compute_dynamic_offset(inputs[1], in.strides(), axes_, stream());
   copy_inplace(
@@ -450,7 +450,7 @@ void View::eval_cpu(const std::vector<array>& inputs, array& out) {
   } else {
     auto tmp = array(
         in.shape(), in.dtype() == bool_ ? uint8 : in.dtype(), nullptr, {});
-    tmp.set_data(allocator::malloc_or_wait(tmp.nbytes()));
+    tmp.set_data(allocator::malloc(tmp.nbytes()));
     if (in.dtype() == bool_) {
       auto in_tmp = array(in.shape(), uint8, nullptr, {});
       in_tmp.copy_shared_buffer(in);
@@ -25,12 +25,11 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) {
   auto strides = in.strides();
   strides[in.ndim() - 2] = 1;
   strides[in.ndim() - 1] = M;
-  in.set_data(
-      allocator::malloc_or_wait(in.nbytes()), in.nbytes(), strides, flags);
+  in.set_data(allocator::malloc(in.nbytes()), in.nbytes(), strides, flags);
   copy_inplace(a, in, CopyType::GeneralGeneral, stream);
   auto& encoder = cpu::get_command_encoder(stream);
-  q.set_data(allocator::malloc_or_wait(q.nbytes()));
-  r.set_data(allocator::malloc_or_wait(r.nbytes()));
+  q.set_data(allocator::malloc(q.nbytes()));
+  r.set_data(allocator::malloc(r.nbytes()));
 
   auto in_ptr = in.data<T>();
   auto r_ptr = r.data<T>();
@@ -41,8 +40,7 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) {
   encoder.set_output_array(r);
   encoder.dispatch([in_ptr, q_ptr, r_ptr, M, N, lda, num_matrices]() {
     int num_reflectors = std::min(M, N);
-    auto tau =
-        allocator::malloc_or_wait(sizeof(T) * num_matrices * num_reflectors);
+    auto tau = allocator::malloc(sizeof(T) * num_matrices * num_reflectors);
 
     T optimal_work;
     int lwork = -1;
@@ -53,7 +51,7 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) {
 
     // Update workspace size
     lwork = optimal_work;
-    auto work = allocator::malloc_or_wait(sizeof(T) * lwork);
+    auto work = allocator::malloc(sizeof(T) * lwork);
 
     // Loop over matrices
     for (int i = 0; i < num_matrices; ++i) {
@@ -96,7 +94,7 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) {
         &lwork,
         &info);
     lwork = optimal_work;
-    work = allocator::malloc_or_wait(sizeof(T) * lwork);
+    work = allocator::malloc(sizeof(T) * lwork);
 
     // Loop over matrices
     for (int i = 0; i < num_matrices; ++i) {
@@ -515,7 +515,7 @@ void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
   auto scales = ensure_row_contiguous(scales_pre);
   auto biases = ensure_row_contiguous(biases_pre);
 
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
 
   auto& encoder = cpu::get_command_encoder(stream());
   encoder.add_temporaries(std::move(temps));
@@ -565,7 +565,7 @@ void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) {
   auto scales = ensure_row_contiguous_last_dims(scales_pre);
   auto biases = ensure_row_contiguous_last_dims(biases_pre);
 
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
 
   auto& encoder = cpu::get_command_encoder(stream());
   encoder.add_temporaries(std::move(temps));
@@ -691,12 +691,12 @@ void fast::AffineQuantize::eval_cpu(
 
   auto [w, copied] = ensure_row_contiguous(inputs[0]);
   auto& out = outputs[0];
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
 
   auto& scales = outputs[1];
   auto& biases = outputs[2];
-  scales.set_data(allocator::malloc_or_wait(scales.nbytes()));
-  biases.set_data(allocator::malloc_or_wait(biases.nbytes()));
+  scales.set_data(allocator::malloc(scales.nbytes()));
+  biases.set_data(allocator::malloc(biases.nbytes()));
   auto& encoder = cpu::get_command_encoder(stream());
   if (copied) {
     encoder.add_temporary(w);
@@ -433,7 +433,7 @@ void reduce_dispatch_min_max(
 void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   auto& in = inputs[0];
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
   auto& encoder = cpu::get_command_encoder(stream());
   encoder.set_input_array(in);
   encoder.set_output_array(out);
@@ -244,7 +244,7 @@ void Scan::eval_cpu(const std::vector<array>& inputs, array& out) {
     in = arr_copy;
     encoder.add_temporary(arr_copy);
   }
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
 
   encoder.set_input_array(in);
   encoder.set_output_array(out);
@@ -129,7 +129,7 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
     out.copy_shared_buffer(x);
   } else {
     out.set_data(
-        allocator::malloc_or_wait(x.data_size() * x.itemsize()),
+        allocator::malloc(x.data_size() * x.itemsize()),
         x.data_size(),
         x.strides(),
         x.flags());
@@ -288,7 +288,7 @@ void ArgSort::eval_cpu(const std::vector<array>& inputs, array& out) {
   auto& in = inputs[0];
 
   // Allocate output
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
 
   auto& encoder = cpu::get_command_encoder(stream());
   encoder.set_input_array(in);
@@ -379,7 +379,7 @@ void ArgPartition::eval_cpu(const std::vector<array>& inputs, array& out) {
   auto& in = inputs[0];
 
   // Allocate output
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
 
   auto& encoder = cpu::get_command_encoder(stream());
   encoder.set_input_array(in);
@@ -50,9 +50,9 @@ void svd_impl(
   array& s = outputs[1];
   array& vt = outputs[2];
 
-  u.set_data(allocator::malloc_or_wait(u.nbytes()));
-  s.set_data(allocator::malloc_or_wait(s.nbytes()));
-  vt.set_data(allocator::malloc_or_wait(vt.nbytes()));
+  u.set_data(allocator::malloc(u.nbytes()));
+  s.set_data(allocator::malloc(s.nbytes()));
+  vt.set_data(allocator::malloc(vt.nbytes()));
 
   encoder.set_output_array(u);
   encoder.set_output_array(s);
@@ -64,7 +64,7 @@ void svd_impl(
   } else {
     array& s = outputs[0];
 
-    s.set_data(allocator::malloc_or_wait(s.nbytes()));
+    s.set_data(allocator::malloc(s.nbytes()));
 
     encoder.set_output_array(s);
 
@@ -91,7 +91,7 @@ void svd_impl(
 
   // Will contain the indices of eigenvectors that failed to converge (not
   // used here but required by lapack).
-  auto iwork = array::Data{allocator::malloc_or_wait(sizeof(int) * 12 * K)};
+  auto iwork = array::Data{allocator::malloc(sizeof(int) * 12 * K)};
 
   static const int lwork_query = -1;
 
@@ -132,7 +132,7 @@ void svd_impl(
   }
 
   const int lwork = workspace_dimension;
-  auto scratch = array::Data{allocator::malloc_or_wait(sizeof(T) * lwork)};
+  auto scratch = array::Data{allocator::malloc(sizeof(T) * lwork)};
 
   // Loop over matrices.
   for (int i = 0; i < num_matrices; i++) {
@@ -18,13 +18,13 @@ void set_unary_output_data(const array& in, array& out) {
     } else {
       auto size = in.data_size();
       out.set_data(
-          allocator::malloc_or_wait(size * out.itemsize()),
+          allocator::malloc(size * out.itemsize()),
           size,
           in.strides(),
           in.flags());
     }
   } else {
-    out.set_data(allocator::malloc_or_wait(out.nbytes()));
+    out.set_data(allocator::malloc(out.nbytes()));
   }
 }
 
@@ -192,16 +192,19 @@ size_t MetalAllocator::set_cache_limit(size_t limit) {
   return limit;
 };
 
-size_t MetalAllocator::set_memory_limit(size_t limit, bool relaxed) {
+size_t MetalAllocator::set_memory_limit(size_t limit) {
   std::unique_lock lk(mutex_);
   std::swap(limit, block_limit_);
-  relaxed_ = relaxed;
   gc_limit_ = std::min(
       block_limit_,
       static_cast<size_t>(0.95 * device_->recommendedMaxWorkingSetSize()));
   return limit;
 };
 
+size_t MetalAllocator::get_memory_limit() {
+  return block_limit_;
+}
+
 size_t MetalAllocator::set_wired_limit(size_t limit) {
   std::unique_lock lk(mutex_);
   std::swap(limit, wired_limit_);
@@ -209,7 +212,7 @@ size_t MetalAllocator::set_wired_limit(size_t limit) {
   return limit;
 };
 
-Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
+Buffer MetalAllocator::malloc(size_t size) {
   // Metal doesn't like empty buffers
   if (size == 0) {
     return Buffer{nullptr};
@@ -236,11 +239,6 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
   if (!buf) {
     size_t mem_required = get_active_memory() + get_cache_memory() + size;
 
-    // If there is too much memory pressure, fail (likely causes a wait).
-    if (!(allow_swap && relaxed_) && mem_required >= block_limit_) {
-      return Buffer{nullptr};
-    }
-
     auto pool = metal::new_scoped_memory_pool();
 
     // If we have a lot of memory pressure or are over the maximum cache size,
@@ -328,8 +326,11 @@ MetalAllocator& allocator() {
 size_t set_cache_limit(size_t limit) {
   return allocator().set_cache_limit(limit);
 }
-size_t set_memory_limit(size_t limit, bool relaxed /* = true */) {
-  return allocator().set_memory_limit(limit, relaxed);
+size_t set_memory_limit(size_t limit) {
+  return allocator().set_memory_limit(limit);
+}
+size_t get_memory_limit() {
+  return allocator().get_memory_limit();
 }
 size_t set_wired_limit(size_t limit) {
   if (limit >
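One reading of why these removals fix the deadlock (an interpretation of the diff, not text from the commit): under memory pressure the old `MetalAllocator::malloc` gave up early,

    // Pressure gate removed above; allow_swap defaulted to false.
    if (!(allow_swap && relaxed_) && mem_required >= block_limit_) {
      return Buffer{nullptr};
    }

and the null buffer sent the caller into `malloc_or_wait`, which spun on `scheduler::n_active_tasks()`. If the remaining scheduled tasks were themselves blocked on the same allocator, nothing ever freed memory and the wait never returned. Now the allocator presses on (reclaiming cache under pressure, per the surviving comment) and a failed allocation surfaces as an exception instead of a wait.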
@@ -56,7 +56,7 @@ class BufferCache {
 class MetalAllocator : public allocator::Allocator {
   /** Allocator for Metal GPUs. */
  public:
-  virtual Buffer malloc(size_t size, bool allow_swap = false) override;
+  virtual Buffer malloc(size_t size) override;
   virtual void free(Buffer buffer) override;
   virtual size_t size(Buffer buffer) const override;
   size_t get_active_memory() {
@@ -73,7 +73,8 @@ class MetalAllocator : public allocator::Allocator {
     return buffer_cache_.cache_size();
   };
   size_t set_cache_limit(size_t limit);
-  size_t set_memory_limit(size_t limit, bool relaxed);
+  size_t set_memory_limit(size_t limit);
+  size_t get_memory_limit();
   size_t set_wired_limit(size_t limit);
   void clear_cache();
 
@@ -37,7 +37,7 @@ void explicit_gemm_conv_ND_gpu(
   Shape unfolded_shape{implicit_M, implicit_K};
   array in_unfolded(unfolded_shape, in.dtype(), nullptr, {});
 
-  in_unfolded.set_data(allocator::malloc_or_wait(in_unfolded.nbytes()));
+  in_unfolded.set_data(allocator::malloc(in_unfolded.nbytes()));
 
   // Prepare unfolding kernel
   std::ostringstream kname;
@@ -115,7 +115,7 @@ void explicit_gemm_conv_group_ND_gpu(
   // Prepare unfolding array
   Shape unfolded_shape{implicit_M, implicit_K * groups};
   array in_unfolded(unfolded_shape, in.dtype(), nullptr, {});
-  in_unfolded.set_data(allocator::malloc_or_wait(in_unfolded.nbytes()));
+  in_unfolded.set_data(allocator::malloc(in_unfolded.nbytes()));
 
   // Prepare unfolding kernel
   std::ostringstream kname;
@@ -613,7 +613,7 @@ void winograd_conv_2D_gpu(
   // Do filter transform
   Shape filt_wg_shape = {8 * 8, conv_params.C, conv_params.O};
   array filt_wg(std::move(filt_wg_shape), wt.dtype(), nullptr, {});
-  filt_wg.set_data(allocator::malloc_or_wait(filt_wg.nbytes()));
+  filt_wg.set_data(allocator::malloc(filt_wg.nbytes()));
   copies_w.push_back(filt_wg);
   {
     int bc = 32;
@@ -640,7 +640,7 @@ void winograd_conv_2D_gpu(
   // Do input transform
   Shape inp_wg_shape = {8 * 8, N_tiles, conv_params.C};
   array inp_wg(std::move(inp_wg_shape), in.dtype(), nullptr, {});
-  inp_wg.set_data(allocator::malloc_or_wait(inp_wg.nbytes()));
+  inp_wg.set_data(allocator::malloc(inp_wg.nbytes()));
   copies_w.push_back(inp_wg);
   {
     int bc = 32;
@@ -667,7 +667,7 @@ void winograd_conv_2D_gpu(
   // Do batched gemm
   Shape out_wg_shape = {8 * 8, N_tiles, conv_params.O};
   array out_wg(std::move(out_wg_shape), in.dtype(), nullptr, {});
-  out_wg.set_data(allocator::malloc_or_wait(out_wg.nbytes()));
+  out_wg.set_data(allocator::malloc(out_wg.nbytes()));
   copies_w.push_back(out_wg);
   {
     std::vector<array> empty_copies;
@@ -855,7 +855,7 @@ void conv_3D_gpu(
 } // namespace
 
 void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
   auto& s = stream();
   auto& d = metal::device(s.device);
 
@@ -202,7 +202,7 @@ void fill_gpu(const array& val, array& out, const Stream& s) {
   if (out.size() == 0) {
     return;
   }
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
   bool large = out.data_size() > UINT32_MAX;
   auto& d = metal::device(s.device);
   std::string kernel_name = std::string(large ? "s2" : "s") + "_copy" +
@@ -19,7 +19,7 @@ void CustomKernel::eval_gpu(
     copies.emplace_back(init_value_.value(), out.dtype());
     fill_gpu(copies.back(), out, s);
   } else {
-    out.set_data(allocator::malloc_or_wait(out.nbytes()));
+    out.set_data(allocator::malloc(out.nbytes()));
   }
 }
 
@@ -20,7 +20,7 @@ struct FenceImpl {
     auto p = metal::new_scoped_memory_pool();
     fence = static_cast<void*>(d->newSharedEvent());
   } else {
-    auto buf = allocator::malloc_or_wait(sizeof(uint32_t)).ptr();
+    auto buf = allocator::malloc(sizeof(uint32_t)).ptr();
     fence = static_cast<void*>(buf);
     cpu_value()[0] = 0;
   }
@@ -281,7 +281,7 @@ std::tuple<array, array, array> compute_raders_constants(
   }
 
   array b_q_fft({rader_n - 1}, complex64, nullptr, {});
-  b_q_fft.set_data(allocator::malloc_or_wait(b_q_fft.nbytes()));
+  b_q_fft.set_data(allocator::malloc(b_q_fft.nbytes()));
   auto b_q_fft_ptr =
       reinterpret_cast<std::complex<float>*>(b_q_fft.data<complex64_t>());
   std::ptrdiff_t item_size = b_q_fft.itemsize();
@@ -327,11 +327,11 @@ std::pair<array, array> compute_bluestein_constants(int n, int bluestein_n) {
   }
 
   array w_k({n}, complex64, nullptr, {});
-  w_k.set_data(allocator::malloc_or_wait(w_k.nbytes()));
+  w_k.set_data(allocator::malloc(w_k.nbytes()));
   std::copy(w_k_vec.begin(), w_k_vec.end(), w_k.data<complex64_t>());
 
   array w_q({bluestein_n}, complex64, nullptr, {});
-  w_q.set_data(allocator::malloc_or_wait(w_q.nbytes()));
+  w_q.set_data(allocator::malloc(w_q.nbytes()));
   auto w_q_ptr =
       reinterpret_cast<std::complex<float>*>(w_q.data<complex64_t>());
 
@@ -551,8 +551,7 @@ void fft_op(
     flags.row_contiguous = is_row_contiguous;
     flags.contiguous = data_size == x_copy.size();
 
-    x_copy.set_data(
-        allocator::malloc_or_wait(x.nbytes()), data_size, strides, flags);
+    x_copy.set_data(allocator::malloc(x.nbytes()), data_size, strides, flags);
     copy_gpu_inplace(x, x_copy, CopyType::GeneralGeneral, s);
     copies.push_back(x_copy);
     return x_copy;
@@ -583,7 +582,7 @@ void fft_op(
   // TODO: allow donation here
   if (!inplace) {
     out.set_data(
-        allocator::malloc_or_wait(out.nbytes()),
+        allocator::malloc(out.nbytes()),
         out_data_size,
         out_strides,
         in_contiguous.flags());
@@ -84,7 +84,7 @@ void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
   if (in_contiguous.is_donatable()) {
     out.copy_shared_buffer(in_contiguous);
   } else {
-    out.set_data(allocator::malloc_or_wait(out.nbytes()));
+    out.set_data(allocator::malloc(out.nbytes()));
   }
 
   int n, m;
@@ -161,7 +161,7 @@ void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
     // Upload 2:
     // y = h12 @ tmp
     array temp(in.shape(), in.dtype(), nullptr, {});
-    temp.set_data(allocator::malloc_or_wait(temp.nbytes()));
+    temp.set_data(allocator::malloc(temp.nbytes()));
     copies.push_back(temp);
 
     launch_hadamard(in_contiguous, temp, "n" + kernel_name, 1.0);
@@ -43,7 +43,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
     throw std::runtime_error(msg.str());
   }
 
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
   if (out.size() == 0) {
     return;
   }
@@ -393,7 +393,7 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
   auto& src = inputs[0];
   auto& idx = inputs[1];
 
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
   if (out.size() == 0) {
     return;
   }
@@ -382,7 +382,7 @@ void steel_matmul(
     int split_k_partition_size = gemm_k_iterations * bk;
 
     array C_split({split_k_partitions, M, N}, float32, nullptr, {});
-    C_split.set_data(allocator::malloc_or_wait(C_split.nbytes()));
+    C_split.set_data(allocator::malloc(C_split.nbytes()));
     copies.push_back(C_split);
 
     bool mn_aligned = M % bm == 0 && N % bn == 0;
@@ -513,7 +513,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
     return;
   }
 
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
 
   /////////////////////////////////////////////////////////////////////////////
   // Init checks and prep
@@ -677,7 +677,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
     throw std::runtime_error(
         "[matmul] Does not yet support non-floating point types.");
   }
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
   auto& s = stream();
   auto& d = metal::device(s.device);
 
@@ -860,7 +860,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
     int split_k_partition_size = gemm_k_iterations * bk;
 
     array C_split({split_k_partitions, M, N}, float32, nullptr, {});
-    C_split.set_data(allocator::malloc_or_wait(C_split.nbytes()));
+    C_split.set_data(allocator::malloc(C_split.nbytes()));
     copies.push_back(C_split);
 
     bool mn_aligned = M % bm == 0 && N % bn == 0;
@@ -1096,7 +1096,7 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
     return;
   }
 
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
 
   /////////////////////////////////////////////////////////////////////////////
   // Init checks and prep
@@ -1484,7 +1484,7 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
     return;
   }
 
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
 
   /////////////////////////////////////////////////////////////////////////////
   // Init checks and prep
@ -38,17 +38,20 @@ void reset_peak_memory();
|
|||||||
size_t get_cache_memory();
|
size_t get_cache_memory();
|
||||||
|
|
||||||
/* Set the memory limit.
|
/* Set the memory limit.
|
||||||
* Calls to malloc will wait on scheduled tasks if the limit is exceeded. If
|
* The memory limit is a guideline for the maximum amount of memory to use
|
||||||
* there are no more scheduled tasks an error will be raised if relaxed
|
* during graph evaluation. If the memory limit is exceeded and there is no
|
||||||
* is false or memory will be allocated (including the potential for
|
* more RAM (including swap when available) allocations will result in an
|
||||||
* swap) if relaxed is true.
|
* exception.
|
||||||
*
|
*
|
||||||
* The memory limit defaults to 1.5 times the maximum recommended working set
|
* When metal is available the memory limit defaults to 1.5 times the maximum
|
||||||
* size reported by the device.
|
* recommended working set size reported by the device.
|
||||||
*
|
*
|
||||||
* Returns the previous memory limit.
|
* Returns the previous memory limit.
|
||||||
* */
|
* */
|
||||||
size_t set_memory_limit(size_t limit, bool relaxed = true);
|
size_t set_memory_limit(size_t limit);
|
||||||
|
|
||||||
|
/* Get the current memory limit. */
|
||||||
|
size_t get_memory_limit();
|
||||||
|
|
||||||
/* Set the free cache limit.
|
/* Set the free cache limit.
|
||||||
* If using more than the given limit, free memory will be reclaimed
|
* If using more than the given limit, free memory will be reclaimed
|
||||||
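
Taken together, this header hunk replaces the old blocking `relaxed` semantics with a plain guideline and adds a `get_memory_limit` accessor. A minimal sketch of the resulting API from Python, assuming a Metal build and assuming the new accessor is also exposed as `mx.metal.get_memory_limit` (only the C++ declaration is visible in this hunk):

    import mlx.core as mx

    # With Metal available the default limit is 1.5x the device's maximum
    # recommended working set size.
    default_limit = mx.metal.get_memory_limit()

    # set_memory_limit returns the previous limit so callers can restore it.
    # There is no `relaxed` flag anymore: the limit is only a guideline, and
    # allocation fails only once RAM (and swap, when available) runs out.
    old_limit = mx.metal.set_memory_limit(default_limit // 2)
    assert old_limit == default_limit

    mx.metal.set_memory_limit(old_limit)  # restore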

@@ -29,7 +29,7 @@ void RMSNorm::eval_gpu(
     out.copy_shared_buffer(x);
   } else {
     out.set_data(
-        allocator::malloc_or_wait(x.data_size() * x.itemsize()),
+        allocator::malloc(x.data_size() * x.itemsize()),
         x.data_size(),
         x.strides(),
         x.flags());
@@ -129,7 +129,7 @@ void RMSNormVJP::eval_gpu(
     gx.copy_shared_buffer(g);
     g_in_gx = true;
   } else {
-    gx.set_data(allocator::malloc_or_wait(gx.nbytes()));
+    gx.set_data(allocator::malloc(gx.nbytes()));
   }
   if (g_copied && !g_in_gx) {
     d.add_temporary(g, s.index);
@@ -146,11 +146,11 @@ void RMSNormVJP::eval_gpu(
     if (!g_in_gx && donate_g) {
       gw_temp.copy_shared_buffer(g);
     } else {
-      gw_temp.set_data(allocator::malloc_or_wait(gw_temp.nbytes()));
+      gw_temp.set_data(allocator::malloc(gw_temp.nbytes()));
       d.add_temporary(gw_temp, s.index);
     }
   }
-  gw.set_data(allocator::malloc_or_wait(gw.nbytes()));
+  gw.set_data(allocator::malloc(gw.nbytes()));

   const int simd_size = 32;
   const int n_reads = RMS_N_READS;
@@ -226,7 +226,7 @@ void LayerNorm::eval_gpu(
     out.copy_shared_buffer(x);
   } else {
     out.set_data(
-        allocator::malloc_or_wait(x.data_size() * x.itemsize()),
+        allocator::malloc(x.data_size() * x.itemsize()),
         x.data_size(),
         x.strides(),
         x.flags());
@@ -331,7 +331,7 @@ void LayerNormVJP::eval_gpu(
     gx.copy_shared_buffer(g);
     g_in_gx = true;
   } else {
-    gx.set_data(allocator::malloc_or_wait(gx.nbytes()));
+    gx.set_data(allocator::malloc(gx.nbytes()));
   }
   if (g_copied && !g_in_gx) {
     d.add_temporary(g, s.index);
@@ -348,12 +348,12 @@ void LayerNormVJP::eval_gpu(
     if (!g_in_gx && donate_g) {
       gw_temp.copy_shared_buffer(g);
     } else {
-      gw_temp.set_data(allocator::malloc_or_wait(gw_temp.nbytes()));
+      gw_temp.set_data(allocator::malloc(gw_temp.nbytes()));
       d.add_temporary(gw_temp, s.index);
     }
   }
-  gw.set_data(allocator::malloc_or_wait(gw.nbytes()));
-  gb.set_data(allocator::malloc_or_wait(gb.nbytes()));
+  gw.set_data(allocator::malloc(gw.nbytes()));
+  gb.set_data(allocator::malloc(gb.nbytes()));

   // Finish with the gradient for b in case we had a b
   auto& compute_encoder = d.get_command_encoder(s.index);

@@ -28,7 +28,7 @@ void arange_set_scalars(T start, T next, metal::CommandEncoder& enc) {
 void reshape(const array& in, array& out, Stream s) {
   auto [copy_necessary, out_strides] = prepare_reshape(in, out);
   if (copy_necessary) {
-    out.set_data(allocator::malloc_or_wait(out.nbytes()));
+    out.set_data(allocator::malloc(out.nbytes()));
     copy_gpu_inplace(
         in,
         out,
@@ -58,7 +58,7 @@ static array compute_dynamic_offset(
   if (donate) {
     offset.copy_shared_buffer(indices);
   } else {
-    offset.set_data(allocator::malloc_or_wait(offset.itemsize()));
+    offset.set_data(allocator::malloc(offset.itemsize()));
   }
   d.add_temporary(offset, s.index);

@@ -100,7 +100,7 @@ static array compute_dynamic_offset(

 void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 0);
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
   if (out.size() == 0) {
     return;
   }
@@ -161,7 +161,7 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
 void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   auto& in = inputs[0];
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
   auto& s = stream();
   auto& d = metal::device(s.device);
   std::string op_name;
@@ -333,7 +333,7 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {

   size_t elems_per_key = out.size() / num_keys;
   size_t bytes_per_key = out.itemsize() * elems_per_key;
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
   if (out.size() == 0) {
     return;
   }
@@ -397,7 +397,7 @@ void DynamicSlice::eval_gpu(const std::vector<array>& inputs, array& out) {

   auto& in = inputs[0];
   auto& start = inputs[1];
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
   auto s = stream();
   auto in_offset = compute_dynamic_offset(start, in.strides(), axes_, s);
   copy_gpu_inplace(
@@ -554,7 +554,7 @@ void View::eval_gpu(const std::vector<array>& inputs, array& out) {
         in, strides, in.flags(), in.data_size() * ibytes / obytes);
   } else {
     auto tmp = array(in.shape(), in.dtype(), nullptr, {});
-    tmp.set_data(allocator::malloc_or_wait(tmp.nbytes()));
+    tmp.set_data(allocator::malloc(tmp.nbytes()));
     copy_gpu_inplace(in, tmp, CopyType::General, stream());

     auto flags = out.flags();

@@ -224,7 +224,7 @@ void qvm_split_k(
   auto temp_shape = out.shape();
   temp_shape.insert(temp_shape.end() - 2, split_k);
   array intermediate(temp_shape, x.dtype(), nullptr, {});
-  intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
+  intermediate.set_data(allocator::malloc(intermediate.nbytes()));
   d.add_temporary(intermediate, s.index);

   std::ostringstream kname;
@@ -277,7 +277,7 @@ void qmm_op(
     int bits,
     bool gather,
     const Stream& s) {
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));

   MTL::Size group_dims;
   MTL::Size grid_dims;
@@ -394,7 +394,7 @@ void fast::AffineQuantize::eval_gpu(
     std::vector<array>& outputs) {
   auto& w_pre = inputs[0];
   auto& out = outputs[0];
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));

   auto& s = stream();
   auto& d = metal::device(s.device);
@@ -425,8 +425,8 @@ void fast::AffineQuantize::eval_gpu(
   } else {
     auto& scales = outputs[1];
     auto& biases = outputs[2];
-    scales.set_data(allocator::malloc_or_wait(scales.nbytes()));
-    biases.set_data(allocator::malloc_or_wait(biases.nbytes()));
+    scales.set_data(allocator::malloc(scales.nbytes()));
+    biases.set_data(allocator::malloc(biases.nbytes()));
     compute_encoder.set_output_array(out, 1);
     compute_encoder.set_output_array(scales, 2);
     compute_encoder.set_output_array(biases, 3);

@@ -347,7 +347,7 @@ void all_reduce_dispatch(

   // Allocate an intermediate tensor to hold results if needed
   array intermediate({n_rows}, out_type, nullptr, {});
-  intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
+  intermediate.set_data(allocator::malloc(intermediate.nbytes()));
   d.add_temporary(intermediate, s.index);

   // 1st pass
@@ -641,7 +641,7 @@ void strided_reduce_longcolumn(
   intermediate_shape.insert(
       intermediate_shape.end(), out.shape().begin(), out.shape().end());
   array intermediate(std::move(intermediate_shape), out_type, nullptr, {});
-  intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
+  intermediate.set_data(allocator::malloc(intermediate.nbytes()));
   d.add_temporary(intermediate, s.index);

   // Prepare the arguments for the kernel
@@ -812,7 +812,7 @@ void strided_reduce_2pass(
   intermediate_shape.insert(
       intermediate_shape.end(), out.shape().begin(), out.shape().end());
   array intermediate(std::move(intermediate_shape), out_type, nullptr, {});
-  intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
+  intermediate.set_data(allocator::malloc(intermediate.nbytes()));
   d.add_temporary(intermediate, s.index);

   // Prepare the arguments for the kernel
@@ -950,7 +950,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
   // Minimum of 4 bytes since we use size 4 structs for all reduce
   // and metal will complain o/w
   size_t min_bytes = std::max(out.nbytes(), 4ul);
-  out.set_data(allocator::malloc_or_wait(min_bytes));
+  out.set_data(allocator::malloc(min_bytes));
   std::string op_name;
   switch (reduce_type_) {
     case Reduce::And:

@@ -43,14 +43,14 @@ void RoPE::eval_gpu(
       donated = true;
       out.copy_shared_buffer(in);
     } else {
-      out.set_data(allocator::malloc_or_wait(out.nbytes()));
+      out.set_data(allocator::malloc(out.nbytes()));
     }
     strides[0] = mat_size;
     strides[1] = in.strides()[ndim - 2];
     strides[2] = in.strides()[ndim - 1];
   } else if (dispatch_ndim == 3) {
     // Handle non-contiguous 3D inputs
-    out.set_data(allocator::malloc_or_wait(out.nbytes()));
+    out.set_data(allocator::malloc(out.nbytes()));
     strides[0] = in.strides()[ndim - 3];
     strides[1] = in.strides()[ndim - 2];
     strides[2] = in.strides()[ndim - 1];

@@ -248,9 +248,9 @@ void sdpa_vector_2pass(
   intermediate_shape.pop_back();
   array sums(intermediate_shape, float32, nullptr, {});
   array maxs(std::move(intermediate_shape), float32, nullptr, {});
-  intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
-  sums.set_data(allocator::malloc_or_wait(sums.nbytes()));
-  maxs.set_data(allocator::malloc_or_wait(maxs.nbytes()));
+  intermediate.set_data(allocator::malloc(intermediate.nbytes()));
+  sums.set_data(allocator::malloc(sums.nbytes()));
+  maxs.set_data(allocator::malloc(maxs.nbytes()));
   d.add_temporary(intermediate, s.index);
   d.add_temporary(sums, s.index);
   d.add_temporary(maxs, s.index);
@@ -383,7 +383,7 @@ void ScaledDotProductAttention::eval_gpu(
     o.copy_shared_buffer(q);
   } else {
     if (o.shape(2) == 1) {
-      o.set_data(allocator::malloc_or_wait(o.nbytes()));
+      o.set_data(allocator::malloc(o.nbytes()));
     } else {
       auto strides = o.strides();
       strides[2] = o.shape(1) * o.shape(3);
@@ -391,10 +391,7 @@ void ScaledDotProductAttention::eval_gpu(
       auto flags = q.flags();
       flags.row_contiguous = q.shape(1) == 1;
       o.set_data(
-          allocator::malloc_or_wait(o.nbytes()),
-          o.size(),
-          std::move(strides),
-          flags);
+          allocator::malloc(o.nbytes()), o.size(), std::move(strides), flags);
     }
   }

@@ -432,7 +429,7 @@ void ScaledDotProductAttention::eval_gpu(
   };

   o.set_data(
-      allocator::malloc_or_wait(o.nbytes()),
+      allocator::malloc(o.nbytes()),
       data_size,
       {str_oB, str_oH, str_oL, str_oD},
       flags);

@@ -24,7 +24,7 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
     out.copy_shared_buffer(in);
   } else {
     out.set_data(
-        allocator::malloc_or_wait(in.data_size() * out.itemsize()),
+        allocator::malloc(in.data_size() * out.itemsize()),
         in.data_size(),
         in.strides(),
         in.flags());

@@ -29,7 +29,7 @@ void concatenate_gpu(
   }
   std::partial_sum(sizes.cbegin(), sizes.cend(), sizes.begin());

-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));

   auto strides = out.strides();
   auto flags = out.flags();

@@ -33,7 +33,7 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
     out.copy_shared_buffer(x);
   } else {
     out.set_data(
-        allocator::malloc_or_wait(x.data_size() * x.itemsize()),
+        allocator::malloc(x.data_size() * x.itemsize()),
         x.data_size(),
         x.strides(),
         x.flags());

@@ -150,12 +150,11 @@ void multi_block_sort(
   array block_partitions({n_rows, n_blocks + 1}, uint32, nullptr, {});

   // Do allocations
-  dev_vals_0.set_data(allocator::malloc_or_wait(dev_vals_0.nbytes()));
-  dev_vals_1.set_data(allocator::malloc_or_wait(dev_vals_1.nbytes()));
-  dev_idxs_0.set_data(allocator::malloc_or_wait(dev_idxs_0.nbytes()));
-  dev_idxs_1.set_data(allocator::malloc_or_wait(dev_idxs_1.nbytes()));
-  block_partitions.set_data(
-      allocator::malloc_or_wait(block_partitions.nbytes()));
+  dev_vals_0.set_data(allocator::malloc(dev_vals_0.nbytes()));
+  dev_vals_1.set_data(allocator::malloc(dev_vals_1.nbytes()));
+  dev_idxs_0.set_data(allocator::malloc(dev_idxs_0.nbytes()));
+  dev_idxs_1.set_data(allocator::malloc(dev_idxs_1.nbytes()));
+  block_partitions.set_data(allocator::malloc(block_partitions.nbytes()));

   std::vector<array> copies = {
       dev_vals_0, dev_vals_1, dev_idxs_0, dev_idxs_1, block_partitions};
@@ -319,7 +318,7 @@ void gpu_merge_sort(
 void ArgSort::eval_gpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);

-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));

   auto& s = stream();
   auto& d = metal::device(s.device);
@@ -331,7 +330,7 @@ void ArgSort::eval_gpu(const std::vector<array>& inputs, array& out) {
 void Sort::eval_gpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);

-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));

   auto& s = stream();
   auto& d = metal::device(s.device);
@@ -344,7 +343,7 @@ void ArgPartition::eval_gpu(const std::vector<array>& inputs, array& out) {
   // We direct arg partition to sort for now
   assert(inputs.size() == 1);

-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));

   auto& s = stream();
   auto& d = metal::device(s.device);
@@ -357,7 +356,7 @@ void Partition::eval_gpu(const std::vector<array>& inputs, array& out) {
   // We direct partition to sort for now
   assert(inputs.size() == 1);

-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));

   auto& s = stream();
   auto& d = metal::device(s.device);

@@ -97,13 +97,13 @@ void unary_op_gpu(
       out.copy_shared_buffer(in);
     } else {
       out.set_data(
-          allocator::malloc_or_wait(in.data_size() * out.itemsize()),
+          allocator::malloc(in.data_size() * out.itemsize()),
           in.data_size(),
           in.strides(),
           in.flags());
     }
   } else {
-    out.set_data(allocator::malloc_or_wait(out.nbytes()));
+    out.set_data(allocator::malloc(out.nbytes()));
   }
   unary_op_gpu_inplace(inputs, out, op, s);
 }

@@ -42,7 +42,10 @@ void reset_peak_memory() {}
 size_t get_cache_memory() {
   return 0;
 }
-size_t set_memory_limit(size_t, bool) {
+size_t set_memory_limit(size_t) {
+  return 0;
+}
+size_t get_memory_limit() {
   return 0;
 }
 size_t set_cache_limit(size_t) {

@@ -218,7 +218,9 @@ array eval_impl(std::vector<array> outputs, bool async) {
       cpu::eval(arr);
     }

-    if (scheduler::n_active_tasks() > MAX_ACTIVE_TASKS) {
+    if (scheduler::n_active_tasks() > MAX_ACTIVE_TASKS ||
+        (metal::get_active_memory() > metal::get_memory_limit() &&
+         scheduler::n_active_tasks() > 0)) {
       // Commit any open streams
       for (auto& [_, e] : events) {
         if (e.stream().device == Device::gpu) {
@@ -226,6 +228,11 @@ array eval_impl(std::vector<array> outputs, bool async) {
         }
       }
       scheduler::wait_for_one();
+      // TODO memory api should be moved out of metal
+      while (metal::get_active_memory() > metal::get_memory_limit() &&
+             scheduler::n_active_tasks() > 0) {
+        scheduler::wait_for_one();
+      }
     }

     auto maybe_update_fence = [&fences, &needs_fence, stream](const array& a) {
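
These two hunks move back-pressure from the allocator into the evaluation loop: instead of `malloc_or_wait` blocking inside an allocation (where it could deadlock against the very tasks that would release memory), `eval_impl` now commits open streams and drains in-flight tasks whenever active memory exceeds the limit. A toy, single-threaded Python sketch of that policy; `schedule`, `wait_for_one`, and the byte counts are illustrative stand-ins, not MLX APIs:

    from collections import deque

    MAX_ACTIVE_TASKS = 10
    MEMORY_LIMIT = 1000  # bytes, toy value

    active_tasks = deque()  # in-flight work, oldest first
    active_memory = 0       # bytes held by buffers of in-flight tasks

    def wait_for_one():
        # Completing the oldest task releases the memory it held.
        global active_memory
        active_memory -= active_tasks.popleft()

    def throttle():
        # Mirrors the new eval_impl check: throttle between scheduling
        # tasks rather than blocking inside malloc.
        if len(active_tasks) > MAX_ACTIVE_TASKS or (
            active_memory > MEMORY_LIMIT and active_tasks
        ):
            wait_for_one()
            # Keep draining while still over the limit and tasks remain.
            while active_memory > MEMORY_LIMIT and active_tasks:
                wait_for_one()

    def schedule(nbytes):
        global active_memory
        active_memory += nbytes
        active_tasks.append(nbytes)
        throttle()

    for _ in range(100):
        schedule(64)
    assert active_memory <= MEMORY_LIMIT  # memory use stays bounded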

@@ -57,23 +57,19 @@ void init_metal(nb::module_& m) {
       "set_memory_limit",
       &mx::metal::set_memory_limit,
       "limit"_a,
-      nb::kw_only(),
-      "relaxed"_a = true,
       R"pbdoc(
       Set the memory limit.

-      Memory allocations will wait on scheduled tasks to complete if the limit
-      is exceeded. If there are no more scheduled tasks an error will be raised
-      if ``relaxed`` is ``False``. Otherwise memory will be allocated
-      (including the potential for swap) if ``relaxed`` is ``True``.
+      The memory limit is a guideline for the maximum amount of memory to use
+      during graph evaluation. If the memory limit is exceeded and there is no
+      more RAM (including swap when available) allocations will result in an
+      exception.

-      The memory limit defaults to 1.5 times the maximum recommended working set
-      size reported by the device.
+      When metal is available the memory limit defaults to 1.5 times the
+      maximum recommended working set size reported by the device.

       Args:
         limit (int): Memory limit in bytes.
-        relaxed (bool, optional): If ``False`` an error is raised if the limit
-          is exceeded. Default: ``True``

       Returns:
         int: The previous memory limit in bytes.

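One user-visible consequence of dropping `nb::kw_only()` and `"relaxed"_a`: the Python binding now accepts only the limit. A short sketch of the new calling convention (the 4 GiB figure is an arbitrary example):

    import mlx.core as mx

    old_limit = mx.metal.set_memory_limit(4 << 30)  # 4 GiB guideline

    try:
        # The `relaxed` keyword was removed along with the blocking
        # behavior; passing it is now a binding error.
        mx.metal.set_memory_limit(4 << 30, relaxed=True)
    except TypeError:
        pass

    mx.metal.set_memory_limit(old_limit)  # restore the previous limit
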
@@ -185,6 +185,18 @@ class TestEval(mlx_tests.MLXTestCase):
             x = mx.abs(x, stream=s2)
         mx.eval(x)

+        s1 = mx.default_stream(mx.gpu)
+        s2 = mx.new_stream(mx.gpu)
+        old_limit = mx.metal.set_memory_limit(1000)
+
+        x = mx.ones((512, 512), stream=s2)
+        for _ in range(80):
+            x = mx.abs(x, stream=s1)
+            y = mx.abs(x, stream=s2)
+            z = mx.abs(y, stream=s2)
+        mx.eval(z)
+        mx.metal.set_memory_limit(old_limit)
+

 if __name__ == "__main__":
     unittest.main()