rebase on main

This commit is contained in:
Alex Barron 2025-04-14 16:37:23 -07:00
parent d7acf59fd0
commit e3d275bc49
2 changed files with 6 additions and 6 deletions

View File

@@ -232,7 +232,7 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
break;
}
auto& in = inputs[0];
-out.set_data(allocator::malloc_or_wait(out.nbytes()));
+out.set_data(allocator::malloc(out.nbytes()));
auto& s = stream();
arg_reduce_dispatch(in, out, axis_, op_name, s);
}

View File

@@ -636,7 +636,7 @@ void fast::TrellisQuantize::eval_gpu(
std::vector<array>& outputs) {
auto& w_pre = inputs[0];
auto& out = outputs[0];
-out.set_data(allocator::malloc_or_wait(out.nbytes()));
+out.set_data(allocator::malloc(out.nbytes()));
auto& s = stream();
auto& d = metal::device(s.device);
@@ -660,19 +660,19 @@ void fast::TrellisQuantize::eval_gpu(
constexpr int num_states = 1 << 14;
array scores({B, num_states}, float16, nullptr, {});
-scores.set_data(allocator::malloc_or_wait(scores.nbytes()));
+scores.set_data(allocator::malloc(scores.nbytes()));
copies.push_back(scores);
array pointers({B, T, num_states}, uint8, nullptr, {});
-pointers.set_data(allocator::malloc_or_wait(pointers.nbytes()));
+pointers.set_data(allocator::malloc(pointers.nbytes()));
copies.push_back(pointers);
array start({B}, uint32, nullptr, {});
-start.set_data(allocator::malloc_or_wait(start.nbytes()));
+start.set_data(allocator::malloc(start.nbytes()));
copies.push_back(start);
array rolled({B, T}, uint16, nullptr, {});
-rolled.set_data(allocator::malloc_or_wait(rolled.nbytes()));
+rolled.set_data(allocator::malloc(rolled.nbytes()));
copies.push_back(rolled);
viterbi(w, scores, pointers, start, out, false, s);