fix malloc or wait deadlock (#1976)
@@ -192,16 +192,19 @@ size_t MetalAllocator::set_cache_limit(size_t limit) {
   return limit;
 };
 
-size_t MetalAllocator::set_memory_limit(size_t limit, bool relaxed) {
+size_t MetalAllocator::set_memory_limit(size_t limit) {
   std::unique_lock lk(mutex_);
   std::swap(limit, block_limit_);
-  relaxed_ = relaxed;
   gc_limit_ = std::min(
       block_limit_,
       static_cast<size_t>(0.95 * device_->recommendedMaxWorkingSetSize()));
   return limit;
 };
 
+size_t MetalAllocator::get_memory_limit() {
+  return block_limit_;
+}
+
 size_t MetalAllocator::set_wired_limit(size_t limit) {
   std::unique_lock lk(mutex_);
   std::swap(limit, wired_limit_);
@@ -209,7 +212,7 @@ size_t MetalAllocator::set_wired_limit(size_t limit) {
   return limit;
 };
 
-Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
+Buffer MetalAllocator::malloc(size_t size) {
   // Metal doesn't like empty buffers
   if (size == 0) {
     return Buffer{nullptr};
@@ -236,11 +239,6 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
   if (!buf) {
     size_t mem_required = get_active_memory() + get_cache_memory() + size;
 
-    // If there is too much memory pressure, fail (likely causes a wait).
-    if (!(allow_swap && relaxed_) && mem_required >= block_limit_) {
-      return Buffer{nullptr};
-    }
-
     auto pool = metal::new_scoped_memory_pool();
 
     // If we have a lot of memory pressure or are over the maximum cache size,
@@ -328,8 +326,11 @@ MetalAllocator& allocator() {
 size_t set_cache_limit(size_t limit) {
   return allocator().set_cache_limit(limit);
 }
-size_t set_memory_limit(size_t limit, bool relaxed /* = true */) {
-  return allocator().set_memory_limit(limit, relaxed);
+size_t set_memory_limit(size_t limit) {
+  return allocator().set_memory_limit(limit);
 }
+size_t get_memory_limit() {
+  return allocator().get_memory_limit();
+}
 size_t set_wired_limit(size_t limit) {
   if (limit >
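With the `relaxed` escape hatch removed, the public limit API reduces to a setter that returns the previous value plus a new getter. A minimal usage sketch, not part of the diff (the header path and namespace follow this repo's layout; the 4 GB value is purely illustrative):

    // Sketch only: exercises the reworked limit API from the hunks above.
    #include "mlx/backend/metal/metal.h"

    void adjust_limit_example() {
      size_t old_limit = mlx::core::metal::set_memory_limit(4UL << 30);
      size_t current = mlx::core::metal::get_memory_limit(); // reads block_limit_
      (void)current; // e.g. log or assert on it
      mlx::core::metal::set_memory_limit(old_limit); // restore the previous limit
    }
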
@@ -56,7 +56,7 @@ class BufferCache {
 class MetalAllocator : public allocator::Allocator {
   /** Allocator for Metal GPUs. */
  public:
-  virtual Buffer malloc(size_t size, bool allow_swap = false) override;
+  virtual Buffer malloc(size_t size) override;
   virtual void free(Buffer buffer) override;
   virtual size_t size(Buffer buffer) const override;
   size_t get_active_memory() {
@@ -73,7 +73,8 @@ class MetalAllocator : public allocator::Allocator {
     return buffer_cache_.cache_size();
   };
   size_t set_cache_limit(size_t limit);
-  size_t set_memory_limit(size_t limit, bool relaxed);
+  size_t set_memory_limit(size_t limit);
+  size_t get_memory_limit();
   size_t set_wired_limit(size_t limit);
   void clear_cache();
@@ -37,7 +37,7 @@ void explicit_gemm_conv_ND_gpu(
   Shape unfolded_shape{implicit_M, implicit_K};
   array in_unfolded(unfolded_shape, in.dtype(), nullptr, {});
 
-  in_unfolded.set_data(allocator::malloc_or_wait(in_unfolded.nbytes()));
+  in_unfolded.set_data(allocator::malloc(in_unfolded.nbytes()));
 
   // Prepare unfolding kernel
   std::ostringstream kname;
@@ -115,7 +115,7 @@ void explicit_gemm_conv_group_ND_gpu(
   // Prepare unfolding array
   Shape unfolded_shape{implicit_M, implicit_K * groups};
   array in_unfolded(unfolded_shape, in.dtype(), nullptr, {});
-  in_unfolded.set_data(allocator::malloc_or_wait(in_unfolded.nbytes()));
+  in_unfolded.set_data(allocator::malloc(in_unfolded.nbytes()));
 
   // Prepare unfolding kernel
   std::ostringstream kname;
@@ -613,7 +613,7 @@ void winograd_conv_2D_gpu(
   // Do filter transform
   Shape filt_wg_shape = {8 * 8, conv_params.C, conv_params.O};
   array filt_wg(std::move(filt_wg_shape), wt.dtype(), nullptr, {});
-  filt_wg.set_data(allocator::malloc_or_wait(filt_wg.nbytes()));
+  filt_wg.set_data(allocator::malloc(filt_wg.nbytes()));
   copies_w.push_back(filt_wg);
   {
     int bc = 32;
@@ -640,7 +640,7 @@ void winograd_conv_2D_gpu(
   // Do input transform
   Shape inp_wg_shape = {8 * 8, N_tiles, conv_params.C};
   array inp_wg(std::move(inp_wg_shape), in.dtype(), nullptr, {});
-  inp_wg.set_data(allocator::malloc_or_wait(inp_wg.nbytes()));
+  inp_wg.set_data(allocator::malloc(inp_wg.nbytes()));
   copies_w.push_back(inp_wg);
   {
     int bc = 32;
@@ -667,7 +667,7 @@ void winograd_conv_2D_gpu(
   // Do batched gemm
   Shape out_wg_shape = {8 * 8, N_tiles, conv_params.O};
   array out_wg(std::move(out_wg_shape), in.dtype(), nullptr, {});
-  out_wg.set_data(allocator::malloc_or_wait(out_wg.nbytes()));
+  out_wg.set_data(allocator::malloc(out_wg.nbytes()));
   copies_w.push_back(out_wg);
   {
     std::vector<array> empty_copies;
@@ -855,7 +855,7 @@ void conv_3D_gpu(
 } // namespace
 
 void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
   auto& s = stream();
   auto& d = metal::device(s.device);
@@ -202,7 +202,7 @@ void fill_gpu(const array& val, array& out, const Stream& s) {
   if (out.size() == 0) {
     return;
   }
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
   bool large = out.data_size() > UINT32_MAX;
   auto& d = metal::device(s.device);
   std::string kernel_name = std::string(large ? "s2" : "s") + "_copy" +
@@ -19,7 +19,7 @@ void CustomKernel::eval_gpu(
     copies.emplace_back(init_value_.value(), out.dtype());
     fill_gpu(copies.back(), out, s);
   } else {
-    out.set_data(allocator::malloc_or_wait(out.nbytes()));
+    out.set_data(allocator::malloc(out.nbytes()));
   }
 }
@@ -20,7 +20,7 @@ struct FenceImpl {
     auto p = metal::new_scoped_memory_pool();
     fence = static_cast<void*>(d->newSharedEvent());
   } else {
-    auto buf = allocator::malloc_or_wait(sizeof(uint32_t)).ptr();
+    auto buf = allocator::malloc(sizeof(uint32_t)).ptr();
     fence = static_cast<void*>(buf);
     cpu_value()[0] = 0;
   }
@@ -281,7 +281,7 @@ std::tuple<array, array, array> compute_raders_constants(
   }
 
   array b_q_fft({rader_n - 1}, complex64, nullptr, {});
-  b_q_fft.set_data(allocator::malloc_or_wait(b_q_fft.nbytes()));
+  b_q_fft.set_data(allocator::malloc(b_q_fft.nbytes()));
   auto b_q_fft_ptr =
       reinterpret_cast<std::complex<float>*>(b_q_fft.data<complex64_t>());
   std::ptrdiff_t item_size = b_q_fft.itemsize();
@@ -327,11 +327,11 @@ std::pair<array, array> compute_bluestein_constants(int n, int bluestein_n) {
   }
 
   array w_k({n}, complex64, nullptr, {});
-  w_k.set_data(allocator::malloc_or_wait(w_k.nbytes()));
+  w_k.set_data(allocator::malloc(w_k.nbytes()));
   std::copy(w_k_vec.begin(), w_k_vec.end(), w_k.data<complex64_t>());
 
   array w_q({bluestein_n}, complex64, nullptr, {});
-  w_q.set_data(allocator::malloc_or_wait(w_q.nbytes()));
+  w_q.set_data(allocator::malloc(w_q.nbytes()));
   auto w_q_ptr =
       reinterpret_cast<std::complex<float>*>(w_q.data<complex64_t>());
 
@@ -551,8 +551,7 @@ void fft_op(
     flags.row_contiguous = is_row_contiguous;
     flags.contiguous = data_size == x_copy.size();
 
-    x_copy.set_data(
-        allocator::malloc_or_wait(x.nbytes()), data_size, strides, flags);
+    x_copy.set_data(allocator::malloc(x.nbytes()), data_size, strides, flags);
     copy_gpu_inplace(x, x_copy, CopyType::GeneralGeneral, s);
     copies.push_back(x_copy);
     return x_copy;
@@ -583,7 +582,7 @@ void fft_op(
   // TODO: allow donation here
   if (!inplace) {
     out.set_data(
-        allocator::malloc_or_wait(out.nbytes()),
+        allocator::malloc(out.nbytes()),
         out_data_size,
         out_strides,
         in_contiguous.flags());
@@ -84,7 +84,7 @@ void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
   if (in_contiguous.is_donatable()) {
     out.copy_shared_buffer(in_contiguous);
   } else {
-    out.set_data(allocator::malloc_or_wait(out.nbytes()));
+    out.set_data(allocator::malloc(out.nbytes()));
   }
 
   int n, m;
@@ -161,7 +161,7 @@ void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
     // Upload 2:
     // y = h12 @ tmp
     array temp(in.shape(), in.dtype(), nullptr, {});
-    temp.set_data(allocator::malloc_or_wait(temp.nbytes()));
+    temp.set_data(allocator::malloc(temp.nbytes()));
     copies.push_back(temp);
 
     launch_hadamard(in_contiguous, temp, "n" + kernel_name, 1.0);
@@ -43,7 +43,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
     throw std::runtime_error(msg.str());
   }
 
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
   if (out.size() == 0) {
     return;
   }
@@ -393,7 +393,7 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
   auto& src = inputs[0];
   auto& idx = inputs[1];
 
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
   if (out.size() == 0) {
     return;
   }
@@ -382,7 +382,7 @@ void steel_matmul(
     int split_k_partition_size = gemm_k_iterations * bk;
 
     array C_split({split_k_partitions, M, N}, float32, nullptr, {});
-    C_split.set_data(allocator::malloc_or_wait(C_split.nbytes()));
+    C_split.set_data(allocator::malloc(C_split.nbytes()));
     copies.push_back(C_split);
 
     bool mn_aligned = M % bm == 0 && N % bn == 0;
@@ -513,7 +513,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
     return;
   }
 
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
 
   /////////////////////////////////////////////////////////////////////////////
   // Init checks and prep
@@ -677,7 +677,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
     throw std::runtime_error(
         "[matmul] Does not yet support non-floating point types.");
   }
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
   auto& s = stream();
   auto& d = metal::device(s.device);
 
@@ -860,7 +860,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
     int split_k_partition_size = gemm_k_iterations * bk;
 
     array C_split({split_k_partitions, M, N}, float32, nullptr, {});
-    C_split.set_data(allocator::malloc_or_wait(C_split.nbytes()));
+    C_split.set_data(allocator::malloc(C_split.nbytes()));
     copies.push_back(C_split);
 
     bool mn_aligned = M % bm == 0 && N % bn == 0;
@@ -1096,7 +1096,7 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
     return;
   }
 
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
 
   /////////////////////////////////////////////////////////////////////////////
   // Init checks and prep
@@ -1484,7 +1484,7 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
     return;
   }
 
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
 
   /////////////////////////////////////////////////////////////////////////////
   // Init checks and prep
@@ -38,17 +38,20 @@ void reset_peak_memory();
 size_t get_cache_memory();
 
 /* Set the memory limit.
- * Calls to malloc will wait on scheduled tasks if the limit is exceeded. If
- * there are no more scheduled tasks an error will be raised if relaxed
- * is false or memory will be allocated (including the potential for
- * swap) if relaxed is true.
+ * The memory limit is a guideline for the maximum amount of memory to use
+ * during graph evaluation. If the memory limit is exceeded and there is no
+ * more RAM (including swap when available) allocations will result in an
+ * exception.
  *
- * The memory limit defaults to 1.5 times the maximum recommended working set
- * size reported by the device.
+ * When metal is available the memory limit defaults to 1.5 times the maximum
+ * recommended working set size reported by the device.
 *
 * Returns the previous memory limit.
 * */
-size_t set_memory_limit(size_t limit, bool relaxed = true);
+size_t set_memory_limit(size_t limit);
+
+/* Get the current memory limit. */
+size_t get_memory_limit();
 
 /* Set the free cache limit.
  * If using more than the given limit, free memory will be reclaimed
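The rewritten comment changes the contract: the limit is a guideline rather than a hard gate that malloc waits on, and exhaustion surfaces as an exception instead of a stalled allocation. A hedged sketch of what a caller can now assume (illustrative only; the 2 GB value is arbitrary):

    // Illustrative only: guard a large evaluation against the documented
    // failure mode (an exception once the limit is exceeded and RAM,
    // including swap, is exhausted).
    #include <exception>

    void guarded_eval_example() {
      mlx::core::metal::set_memory_limit(2UL << 30); // arbitrary value
      try {
        // ... build and evaluate a graph ...
      } catch (const std::exception& e) {
        // An allocation beyond available memory surfaces here.
      }
    }
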
@@ -29,7 +29,7 @@ void RMSNorm::eval_gpu(
     out.copy_shared_buffer(x);
   } else {
     out.set_data(
-        allocator::malloc_or_wait(x.data_size() * x.itemsize()),
+        allocator::malloc(x.data_size() * x.itemsize()),
         x.data_size(),
         x.strides(),
         x.flags());
@@ -129,7 +129,7 @@ void RMSNormVJP::eval_gpu(
     gx.copy_shared_buffer(g);
     g_in_gx = true;
   } else {
-    gx.set_data(allocator::malloc_or_wait(gx.nbytes()));
+    gx.set_data(allocator::malloc(gx.nbytes()));
   }
   if (g_copied && !g_in_gx) {
     d.add_temporary(g, s.index);
@@ -146,11 +146,11 @@ void RMSNormVJP::eval_gpu(
     if (!g_in_gx && donate_g) {
       gw_temp.copy_shared_buffer(g);
     } else {
-      gw_temp.set_data(allocator::malloc_or_wait(gw_temp.nbytes()));
+      gw_temp.set_data(allocator::malloc(gw_temp.nbytes()));
       d.add_temporary(gw_temp, s.index);
     }
   }
-  gw.set_data(allocator::malloc_or_wait(gw.nbytes()));
+  gw.set_data(allocator::malloc(gw.nbytes()));
 
   const int simd_size = 32;
   const int n_reads = RMS_N_READS;
@@ -226,7 +226,7 @@ void LayerNorm::eval_gpu(
     out.copy_shared_buffer(x);
   } else {
     out.set_data(
-        allocator::malloc_or_wait(x.data_size() * x.itemsize()),
+        allocator::malloc(x.data_size() * x.itemsize()),
         x.data_size(),
         x.strides(),
         x.flags());
@@ -331,7 +331,7 @@ void LayerNormVJP::eval_gpu(
     gx.copy_shared_buffer(g);
     g_in_gx = true;
   } else {
-    gx.set_data(allocator::malloc_or_wait(gx.nbytes()));
+    gx.set_data(allocator::malloc(gx.nbytes()));
   }
   if (g_copied && !g_in_gx) {
     d.add_temporary(g, s.index);
@@ -348,12 +348,12 @@ void LayerNormVJP::eval_gpu(
     if (!g_in_gx && donate_g) {
       gw_temp.copy_shared_buffer(g);
     } else {
-      gw_temp.set_data(allocator::malloc_or_wait(gw_temp.nbytes()));
+      gw_temp.set_data(allocator::malloc(gw_temp.nbytes()));
       d.add_temporary(gw_temp, s.index);
     }
   }
-  gw.set_data(allocator::malloc_or_wait(gw.nbytes()));
-  gb.set_data(allocator::malloc_or_wait(gb.nbytes()));
+  gw.set_data(allocator::malloc(gw.nbytes()));
+  gb.set_data(allocator::malloc(gb.nbytes()));
 
   // Finish with the gradient for b in case we had a b
   auto& compute_encoder = d.get_command_encoder(s.index);
@@ -28,7 +28,7 @@ void arange_set_scalars(T start, T next, metal::CommandEncoder& enc) {
 void reshape(const array& in, array& out, Stream s) {
   auto [copy_necessary, out_strides] = prepare_reshape(in, out);
   if (copy_necessary) {
-    out.set_data(allocator::malloc_or_wait(out.nbytes()));
+    out.set_data(allocator::malloc(out.nbytes()));
     copy_gpu_inplace(
         in,
         out,
@@ -58,7 +58,7 @@ static array compute_dynamic_offset(
   if (donate) {
     offset.copy_shared_buffer(indices);
   } else {
-    offset.set_data(allocator::malloc_or_wait(offset.itemsize()));
+    offset.set_data(allocator::malloc(offset.itemsize()));
   }
   d.add_temporary(offset, s.index);
 
@@ -100,7 +100,7 @@ static array compute_dynamic_offset(
 
 void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 0);
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
   if (out.size() == 0) {
     return;
   }
@@ -161,7 +161,7 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
 void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   auto& in = inputs[0];
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
   auto& s = stream();
   auto& d = metal::device(s.device);
   std::string op_name;
@@ -333,7 +333,7 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
 
   size_t elems_per_key = out.size() / num_keys;
   size_t bytes_per_key = out.itemsize() * elems_per_key;
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
   if (out.size() == 0) {
     return;
   }
@@ -397,7 +397,7 @@ void DynamicSlice::eval_gpu(const std::vector<array>& inputs, array& out) {
 
   auto& in = inputs[0];
   auto& start = inputs[1];
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
   auto s = stream();
   auto in_offset = compute_dynamic_offset(start, in.strides(), axes_, s);
   copy_gpu_inplace(
@@ -554,7 +554,7 @@ void View::eval_gpu(const std::vector<array>& inputs, array& out) {
         in, strides, in.flags(), in.data_size() * ibytes / obytes);
   } else {
     auto tmp = array(in.shape(), in.dtype(), nullptr, {});
-    tmp.set_data(allocator::malloc_or_wait(tmp.nbytes()));
+    tmp.set_data(allocator::malloc(tmp.nbytes()));
     copy_gpu_inplace(in, tmp, CopyType::General, stream());
 
     auto flags = out.flags();
@@ -224,7 +224,7 @@ void qvm_split_k(
   auto temp_shape = out.shape();
   temp_shape.insert(temp_shape.end() - 2, split_k);
   array intermediate(temp_shape, x.dtype(), nullptr, {});
-  intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
+  intermediate.set_data(allocator::malloc(intermediate.nbytes()));
   d.add_temporary(intermediate, s.index);
 
   std::ostringstream kname;
@@ -277,7 +277,7 @@ void qmm_op(
     int bits,
     bool gather,
     const Stream& s) {
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
 
   MTL::Size group_dims;
   MTL::Size grid_dims;
@@ -394,7 +394,7 @@ void fast::AffineQuantize::eval_gpu(
     std::vector<array>& outputs) {
   auto& w_pre = inputs[0];
   auto& out = outputs[0];
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
 
   auto& s = stream();
   auto& d = metal::device(s.device);
@@ -425,8 +425,8 @@ void fast::AffineQuantize::eval_gpu(
   } else {
     auto& scales = outputs[1];
     auto& biases = outputs[2];
-    scales.set_data(allocator::malloc_or_wait(scales.nbytes()));
-    biases.set_data(allocator::malloc_or_wait(biases.nbytes()));
+    scales.set_data(allocator::malloc(scales.nbytes()));
+    biases.set_data(allocator::malloc(biases.nbytes()));
     compute_encoder.set_output_array(out, 1);
     compute_encoder.set_output_array(scales, 2);
     compute_encoder.set_output_array(biases, 3);
@@ -347,7 +347,7 @@ void all_reduce_dispatch(
 
   // Allocate an intermediate tensor to hold results if needed
   array intermediate({n_rows}, out_type, nullptr, {});
-  intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
+  intermediate.set_data(allocator::malloc(intermediate.nbytes()));
   d.add_temporary(intermediate, s.index);
 
   // 1st pass
@@ -641,7 +641,7 @@ void strided_reduce_longcolumn(
   intermediate_shape.insert(
       intermediate_shape.end(), out.shape().begin(), out.shape().end());
   array intermediate(std::move(intermediate_shape), out_type, nullptr, {});
-  intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
+  intermediate.set_data(allocator::malloc(intermediate.nbytes()));
   d.add_temporary(intermediate, s.index);
 
   // Prepare the arguments for the kernel
@@ -812,7 +812,7 @@ void strided_reduce_2pass(
   intermediate_shape.insert(
       intermediate_shape.end(), out.shape().begin(), out.shape().end());
   array intermediate(std::move(intermediate_shape), out_type, nullptr, {});
-  intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
+  intermediate.set_data(allocator::malloc(intermediate.nbytes()));
   d.add_temporary(intermediate, s.index);
 
   // Prepare the arguments for the kernel
@@ -950,7 +950,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
   // Minimum of 4 bytes since we use size 4 structs for all reduce
   // and metal will complain o/w
   size_t min_bytes = std::max(out.nbytes(), 4ul);
-  out.set_data(allocator::malloc_or_wait(min_bytes));
+  out.set_data(allocator::malloc(min_bytes));
   std::string op_name;
   switch (reduce_type_) {
     case Reduce::And:
@@ -43,14 +43,14 @@ void RoPE::eval_gpu(
       donated = true;
       out.copy_shared_buffer(in);
     } else {
-      out.set_data(allocator::malloc_or_wait(out.nbytes()));
+      out.set_data(allocator::malloc(out.nbytes()));
     }
     strides[0] = mat_size;
     strides[1] = in.strides()[ndim - 2];
     strides[2] = in.strides()[ndim - 1];
   } else if (dispatch_ndim == 3) {
     // Handle non-contiguous 3D inputs
-    out.set_data(allocator::malloc_or_wait(out.nbytes()));
+    out.set_data(allocator::malloc(out.nbytes()));
     strides[0] = in.strides()[ndim - 3];
     strides[1] = in.strides()[ndim - 2];
     strides[2] = in.strides()[ndim - 1];
@@ -248,9 +248,9 @@ void sdpa_vector_2pass(
   intermediate_shape.pop_back();
   array sums(intermediate_shape, float32, nullptr, {});
   array maxs(std::move(intermediate_shape), float32, nullptr, {});
-  intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
-  sums.set_data(allocator::malloc_or_wait(sums.nbytes()));
-  maxs.set_data(allocator::malloc_or_wait(maxs.nbytes()));
+  intermediate.set_data(allocator::malloc(intermediate.nbytes()));
+  sums.set_data(allocator::malloc(sums.nbytes()));
+  maxs.set_data(allocator::malloc(maxs.nbytes()));
   d.add_temporary(intermediate, s.index);
   d.add_temporary(sums, s.index);
   d.add_temporary(maxs, s.index);
@@ -383,7 +383,7 @@ void ScaledDotProductAttention::eval_gpu(
     o.copy_shared_buffer(q);
   } else {
     if (o.shape(2) == 1) {
-      o.set_data(allocator::malloc_or_wait(o.nbytes()));
+      o.set_data(allocator::malloc(o.nbytes()));
     } else {
       auto strides = o.strides();
       strides[2] = o.shape(1) * o.shape(3);
@@ -391,10 +391,7 @@ void ScaledDotProductAttention::eval_gpu(
       auto flags = q.flags();
       flags.row_contiguous = q.shape(1) == 1;
       o.set_data(
-          allocator::malloc_or_wait(o.nbytes()),
-          o.size(),
-          std::move(strides),
-          flags);
+          allocator::malloc(o.nbytes()), o.size(), std::move(strides), flags);
     }
   }
 
@@ -432,7 +429,7 @@ void ScaledDotProductAttention::eval_gpu(
   };
 
   o.set_data(
-      allocator::malloc_or_wait(o.nbytes()),
+      allocator::malloc(o.nbytes()),
       data_size,
       {str_oB, str_oH, str_oL, str_oD},
       flags);
@@ -24,7 +24,7 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
     out.copy_shared_buffer(in);
   } else {
     out.set_data(
-        allocator::malloc_or_wait(in.data_size() * out.itemsize()),
+        allocator::malloc(in.data_size() * out.itemsize()),
         in.data_size(),
         in.strides(),
         in.flags());
@@ -29,7 +29,7 @@ void concatenate_gpu(
   }
   std::partial_sum(sizes.cbegin(), sizes.cend(), sizes.begin());
 
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
 
   auto strides = out.strides();
   auto flags = out.flags();
@@ -33,7 +33,7 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
     out.copy_shared_buffer(x);
   } else {
     out.set_data(
-        allocator::malloc_or_wait(x.data_size() * x.itemsize()),
+        allocator::malloc(x.data_size() * x.itemsize()),
         x.data_size(),
         x.strides(),
         x.flags());
@@ -150,12 +150,11 @@ void multi_block_sort(
   array block_partitions({n_rows, n_blocks + 1}, uint32, nullptr, {});
 
   // Do allocations
-  dev_vals_0.set_data(allocator::malloc_or_wait(dev_vals_0.nbytes()));
-  dev_vals_1.set_data(allocator::malloc_or_wait(dev_vals_1.nbytes()));
-  dev_idxs_0.set_data(allocator::malloc_or_wait(dev_idxs_0.nbytes()));
-  dev_idxs_1.set_data(allocator::malloc_or_wait(dev_idxs_1.nbytes()));
-  block_partitions.set_data(
-      allocator::malloc_or_wait(block_partitions.nbytes()));
+  dev_vals_0.set_data(allocator::malloc(dev_vals_0.nbytes()));
+  dev_vals_1.set_data(allocator::malloc(dev_vals_1.nbytes()));
+  dev_idxs_0.set_data(allocator::malloc(dev_idxs_0.nbytes()));
+  dev_idxs_1.set_data(allocator::malloc(dev_idxs_1.nbytes()));
+  block_partitions.set_data(allocator::malloc(block_partitions.nbytes()));
 
   std::vector<array> copies = {
       dev_vals_0, dev_vals_1, dev_idxs_0, dev_idxs_1, block_partitions};
@@ -319,7 +318,7 @@ void gpu_merge_sort(
 void ArgSort::eval_gpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
 
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
 
   auto& s = stream();
   auto& d = metal::device(s.device);
@@ -331,7 +330,7 @@ void ArgSort::eval_gpu(const std::vector<array>& inputs, array& out) {
 void Sort::eval_gpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
 
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
 
   auto& s = stream();
   auto& d = metal::device(s.device);
@@ -344,7 +343,7 @@ void ArgPartition::eval_gpu(const std::vector<array>& inputs, array& out) {
   // We direct arg partition to sort for now
   assert(inputs.size() == 1);
 
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
 
   auto& s = stream();
   auto& d = metal::device(s.device);
@@ -357,7 +356,7 @@ void Partition::eval_gpu(const std::vector<array>& inputs, array& out) {
   // We direct partition to sort for now
   assert(inputs.size() == 1);
 
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
 
   auto& s = stream();
   auto& d = metal::device(s.device);
@@ -97,13 +97,13 @@ void unary_op_gpu(
       out.copy_shared_buffer(in);
     } else {
       out.set_data(
-          allocator::malloc_or_wait(in.data_size() * out.itemsize()),
+          allocator::malloc(in.data_size() * out.itemsize()),
          in.data_size(),
           in.strides(),
           in.flags());
     }
   } else {
-    out.set_data(allocator::malloc_or_wait(out.nbytes()));
+    out.set_data(allocator::malloc(out.nbytes()));
   }
   unary_op_gpu_inplace(inputs, out, op, s);
 }
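Every remaining hunk applies the same mechanical rewrite at each call site. Schematically, for a hypothetical primitive (MyPrimitive is a placeholder, not code from this diff):

    // Before: out.set_data(allocator::malloc_or_wait(out.nbytes()));
    // After:
    void MyPrimitive::eval_gpu(const std::vector<array>& inputs, array& out) {
      out.set_data(allocator::malloc(out.nbytes())); // no longer waits
      // ... encode and dispatch the kernel exactly as before ...
    }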