mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 01:17:26 +08:00
Get trellis to run
This commit is contained in:
parent
e3d275bc49
commit
998404ada4
@ -120,4 +120,34 @@
|
||||
instantiate_quantized_groups(6) \
|
||||
instantiate_quantized_groups(8)
|
||||
|
||||
instantiate_kernel(
|
||||
"trellis_viterbi_float16_t_overlap_0",
|
||||
trellis_viterbi,
|
||||
float16_t,
|
||||
false)
|
||||
instantiate_kernel(
|
||||
"trellis_viterbi_float16_t_overlap_1",
|
||||
trellis_viterbi,
|
||||
float16_t,
|
||||
true)
|
||||
|
||||
instantiate_kernel(
|
||||
"trellis_backtrack_overlap_0",
|
||||
trellis_backtrack,
|
||||
false)
|
||||
instantiate_kernel(
|
||||
"trellis_backtrack_overlap_1",
|
||||
trellis_backtrack,
|
||||
true)
|
||||
|
||||
instantiate_kernel(
|
||||
"qmv_fast_float16_t_gs_64_b_2_batch_0_mode_1",
|
||||
qmv_fast,
|
||||
float16_t,
|
||||
64,
|
||||
2,
|
||||
false,
|
||||
true)
|
||||
|
||||
|
||||
instantiate_quantized_all() // clang-format on
|
||||
|
@ -1048,7 +1048,8 @@ trellis_quantize(const array& w_, int bits, StreamOrDevice s_) {
|
||||
int T = Tx * Ty;
|
||||
auto scale = std(astype(w_, float32, s), s);
|
||||
auto w = divide(w_, scale, s);
|
||||
w = astype(w, float16, s);
|
||||
w = astype(w, w_.dtype(), s);
|
||||
scale = astype(scale, w_.dtype(), s);
|
||||
|
||||
w = reshape(w, {M / Tx, Tx, -1, Ty}, s);
|
||||
w = transpose(w, {0, 2, 1, 3}, s);
|
||||
@ -1067,7 +1068,7 @@ trellis_quantize(const array& w_, int bits, StreamOrDevice s_) {
|
||||
{w_batch});
|
||||
q_batch = slice(q_batch, {0, 0}, q_batch.shape(), {1, L / bits}, s);
|
||||
q = slice_update(q, q_batch, {i, 0}, {i + batch_size, q.shape(-1)}, s);
|
||||
eval(q);
|
||||
// eval(q);
|
||||
}
|
||||
|
||||
q = reshape(q, {M / Tx, -1, Tx, Ty * bits / L}, s);
|
||||
|
@ -6,7 +6,6 @@ from typing import Any, Literal
|
||||
import mlx.core as mx
|
||||
from mlx.nn.layers.base import Module
|
||||
from mlx.nn.layers.quantized import QuantizedLinear
|
||||
from mlx.nn.layers.viterbi import quantize as trellis_quantize
|
||||
|
||||
|
||||
class Identity(Module):
|
||||
|
@ -5,7 +5,6 @@ from typing import Callable, Literal, Optional, Union
|
||||
|
||||
import mlx.core as mx
|
||||
from mlx.nn.layers.base import Module
|
||||
from mlx.nn.layers.viterbi import quantize as trellis_quantize
|
||||
from mlx.utils import tree_map_with_path
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user