rebase on main

This commit is contained in:
Alex Barron 2025-04-14 16:37:23 -07:00
parent d7acf59fd0
commit e3d275bc49
2 changed files with 6 additions and 6 deletions

View File

@ -232,7 +232,7 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
break; break;
} }
auto& in = inputs[0]; auto& in = inputs[0];
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
auto& s = stream(); auto& s = stream();
arg_reduce_dispatch(in, out, axis_, op_name, s); arg_reduce_dispatch(in, out, axis_, op_name, s);
} }

View File

@ -636,7 +636,7 @@ void fast::TrellisQuantize::eval_gpu(
std::vector<array>& outputs) { std::vector<array>& outputs) {
auto& w_pre = inputs[0]; auto& w_pre = inputs[0];
auto& out = outputs[0]; auto& out = outputs[0];
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
auto& s = stream(); auto& s = stream();
auto& d = metal::device(s.device); auto& d = metal::device(s.device);
@ -660,19 +660,19 @@ void fast::TrellisQuantize::eval_gpu(
constexpr int num_states = 1 << 14; constexpr int num_states = 1 << 14;
array scores({B, num_states}, float16, nullptr, {}); array scores({B, num_states}, float16, nullptr, {});
scores.set_data(allocator::malloc_or_wait(scores.nbytes())); scores.set_data(allocator::malloc(scores.nbytes()));
copies.push_back(scores); copies.push_back(scores);
array pointers({B, T, num_states}, uint8, nullptr, {}); array pointers({B, T, num_states}, uint8, nullptr, {});
pointers.set_data(allocator::malloc_or_wait(pointers.nbytes())); pointers.set_data(allocator::malloc(pointers.nbytes()));
copies.push_back(pointers); copies.push_back(pointers);
array start({B}, uint32, nullptr, {}); array start({B}, uint32, nullptr, {});
start.set_data(allocator::malloc_or_wait(start.nbytes())); start.set_data(allocator::malloc(start.nbytes()));
copies.push_back(start); copies.push_back(start);
array rolled({B, T}, uint16, nullptr, {}); array rolled({B, T}, uint16, nullptr, {});
rolled.set_data(allocator::malloc_or_wait(rolled.nbytes())); rolled.set_data(allocator::malloc(rolled.nbytes()));
copies.push_back(rolled); copies.push_back(rolled);
viterbi(w, scores, pointers, start, out, false, s); viterbi(w, scores, pointers, start, out, false, s);