mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
rebase on main
This commit is contained in:
parent
d7acf59fd0
commit
e3d275bc49
@ -232,7 +232,7 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
auto& in = inputs[0];
|
auto& in = inputs[0];
|
||||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
out.set_data(allocator::malloc(out.nbytes()));
|
||||||
auto& s = stream();
|
auto& s = stream();
|
||||||
arg_reduce_dispatch(in, out, axis_, op_name, s);
|
arg_reduce_dispatch(in, out, axis_, op_name, s);
|
||||||
}
|
}
|
||||||
|
@ -636,7 +636,7 @@ void fast::TrellisQuantize::eval_gpu(
|
|||||||
std::vector<array>& outputs) {
|
std::vector<array>& outputs) {
|
||||||
auto& w_pre = inputs[0];
|
auto& w_pre = inputs[0];
|
||||||
auto& out = outputs[0];
|
auto& out = outputs[0];
|
||||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
out.set_data(allocator::malloc(out.nbytes()));
|
||||||
|
|
||||||
auto& s = stream();
|
auto& s = stream();
|
||||||
auto& d = metal::device(s.device);
|
auto& d = metal::device(s.device);
|
||||||
@ -660,19 +660,19 @@ void fast::TrellisQuantize::eval_gpu(
|
|||||||
constexpr int num_states = 1 << 14;
|
constexpr int num_states = 1 << 14;
|
||||||
|
|
||||||
array scores({B, num_states}, float16, nullptr, {});
|
array scores({B, num_states}, float16, nullptr, {});
|
||||||
scores.set_data(allocator::malloc_or_wait(scores.nbytes()));
|
scores.set_data(allocator::malloc(scores.nbytes()));
|
||||||
copies.push_back(scores);
|
copies.push_back(scores);
|
||||||
|
|
||||||
array pointers({B, T, num_states}, uint8, nullptr, {});
|
array pointers({B, T, num_states}, uint8, nullptr, {});
|
||||||
pointers.set_data(allocator::malloc_or_wait(pointers.nbytes()));
|
pointers.set_data(allocator::malloc(pointers.nbytes()));
|
||||||
copies.push_back(pointers);
|
copies.push_back(pointers);
|
||||||
|
|
||||||
array start({B}, uint32, nullptr, {});
|
array start({B}, uint32, nullptr, {});
|
||||||
start.set_data(allocator::malloc_or_wait(start.nbytes()));
|
start.set_data(allocator::malloc(start.nbytes()));
|
||||||
copies.push_back(start);
|
copies.push_back(start);
|
||||||
|
|
||||||
array rolled({B, T}, uint16, nullptr, {});
|
array rolled({B, T}, uint16, nullptr, {});
|
||||||
rolled.set_data(allocator::malloc_or_wait(rolled.nbytes()));
|
rolled.set_data(allocator::malloc(rolled.nbytes()));
|
||||||
copies.push_back(rolled);
|
copies.push_back(rolled);
|
||||||
|
|
||||||
viterbi(w, scores, pointers, start, out, false, s);
|
viterbi(w, scores, pointers, start, out, false, s);
|
||||||
|
Loading…
Reference in New Issue
Block a user