mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
rebase on main
This commit is contained in:
parent
d7acf59fd0
commit
e3d275bc49
@@ -232,7 +232,7 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
break;
|
||||
}
|
||||
auto& in = inputs[0];
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
auto& s = stream();
|
||||
arg_reduce_dispatch(in, out, axis_, op_name, s);
|
||||
}
|
||||
|
@@ -636,7 +636,7 @@ void fast::TrellisQuantize::eval_gpu(
|
||||
std::vector<array>& outputs) {
|
||||
auto& w_pre = inputs[0];
|
||||
auto& out = outputs[0];
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
@@ -660,19 +660,19 @@ void fast::TrellisQuantize::eval_gpu(
|
||||
constexpr int num_states = 1 << 14;
|
||||
|
||||
array scores({B, num_states}, float16, nullptr, {});
|
||||
scores.set_data(allocator::malloc_or_wait(scores.nbytes()));
|
||||
scores.set_data(allocator::malloc(scores.nbytes()));
|
||||
copies.push_back(scores);
|
||||
|
||||
array pointers({B, T, num_states}, uint8, nullptr, {});
|
||||
pointers.set_data(allocator::malloc_or_wait(pointers.nbytes()));
|
||||
pointers.set_data(allocator::malloc(pointers.nbytes()));
|
||||
copies.push_back(pointers);
|
||||
|
||||
array start({B}, uint32, nullptr, {});
|
||||
start.set_data(allocator::malloc_or_wait(start.nbytes()));
|
||||
start.set_data(allocator::malloc(start.nbytes()));
|
||||
copies.push_back(start);
|
||||
|
||||
array rolled({B, T}, uint16, nullptr, {});
|
||||
rolled.set_data(allocator::malloc_or_wait(rolled.nbytes()));
|
||||
rolled.set_data(allocator::malloc(rolled.nbytes()));
|
||||
copies.push_back(rolled);
|
||||
|
||||
viterbi(w, scores, pointers, start, out, false, s);
|
||||
|
Loading…
Reference in New Issue
Block a user