Get trellis to run

This commit is contained in:
Awni Hannun 2025-04-26 07:02:20 -07:00
parent e3d275bc49
commit 998404ada4
4 changed files with 33 additions and 4 deletions

View File

@ -120,4 +120,34 @@
instantiate_quantized_groups(6) \
instantiate_quantized_groups(8)
instantiate_kernel(
"trellis_viterbi_float16_t_overlap_0",
trellis_viterbi,
float16_t,
false)
instantiate_kernel(
"trellis_viterbi_float16_t_overlap_1",
trellis_viterbi,
float16_t,
true)
instantiate_kernel(
"trellis_backtrack_overlap_0",
trellis_backtrack,
false)
instantiate_kernel(
"trellis_backtrack_overlap_1",
trellis_backtrack,
true)
instantiate_kernel(
"qmv_fast_float16_t_gs_64_b_2_batch_0_mode_1",
qmv_fast,
float16_t,
64,
2,
false,
true)
instantiate_quantized_all() // clang-format on

View File

@ -1048,7 +1048,8 @@ trellis_quantize(const array& w_, int bits, StreamOrDevice s_) {
int T = Tx * Ty;
auto scale = std(astype(w_, float32, s), s);
auto w = divide(w_, scale, s);
w = astype(w, float16, s);
w = astype(w, w_.dtype(), s);
scale = astype(scale, w_.dtype(), s);
w = reshape(w, {M / Tx, Tx, -1, Ty}, s);
w = transpose(w, {0, 2, 1, 3}, s);
@ -1067,7 +1068,7 @@ trellis_quantize(const array& w_, int bits, StreamOrDevice s_) {
{w_batch});
q_batch = slice(q_batch, {0, 0}, q_batch.shape(), {1, L / bits}, s);
q = slice_update(q, q_batch, {i, 0}, {i + batch_size, q.shape(-1)}, s);
eval(q);
// eval(q);
}
q = reshape(q, {M / Tx, -1, Tx, Ty * bits / L}, s);

View File

@ -6,7 +6,6 @@ from typing import Any, Literal
import mlx.core as mx
from mlx.nn.layers.base import Module
from mlx.nn.layers.quantized import QuantizedLinear
from mlx.nn.layers.viterbi import quantize as trellis_quantize
class Identity(Module):

View File

@ -5,7 +5,6 @@ from typing import Callable, Literal, Optional, Union
import mlx.core as mx
from mlx.nn.layers.base import Module
from mlx.nn.layers.viterbi import quantize as trellis_quantize
from mlx.utils import tree_map_with_path