Get trellis to run

This commit is contained in:
Awni Hannun 2025-04-26 07:02:20 -07:00
parent e3d275bc49
commit 998404ada4
4 changed files with 33 additions and 4 deletions

View File

@ -120,4 +120,34 @@
instantiate_quantized_groups(6) \ instantiate_quantized_groups(6) \
instantiate_quantized_groups(8) instantiate_quantized_groups(8)
instantiate_kernel(
"trellis_viterbi_float16_t_overlap_0",
trellis_viterbi,
float16_t,
false)
instantiate_kernel(
"trellis_viterbi_float16_t_overlap_1",
trellis_viterbi,
float16_t,
true)
instantiate_kernel(
"trellis_backtrack_overlap_0",
trellis_backtrack,
false)
instantiate_kernel(
"trellis_backtrack_overlap_1",
trellis_backtrack,
true)
instantiate_kernel(
"qmv_fast_float16_t_gs_64_b_2_batch_0_mode_1",
qmv_fast,
float16_t,
64,
2,
false,
true)
instantiate_quantized_all() // clang-format on instantiate_quantized_all() // clang-format on

View File

@ -1048,7 +1048,8 @@ trellis_quantize(const array& w_, int bits, StreamOrDevice s_) {
int T = Tx * Ty; int T = Tx * Ty;
auto scale = std(astype(w_, float32, s), s); auto scale = std(astype(w_, float32, s), s);
auto w = divide(w_, scale, s); auto w = divide(w_, scale, s);
w = astype(w, float16, s); w = astype(w, w_.dtype(), s);
scale = astype(scale, w_.dtype(), s);
w = reshape(w, {M / Tx, Tx, -1, Ty}, s); w = reshape(w, {M / Tx, Tx, -1, Ty}, s);
w = transpose(w, {0, 2, 1, 3}, s); w = transpose(w, {0, 2, 1, 3}, s);
@ -1067,7 +1068,7 @@ trellis_quantize(const array& w_, int bits, StreamOrDevice s_) {
{w_batch}); {w_batch});
q_batch = slice(q_batch, {0, 0}, q_batch.shape(), {1, L / bits}, s); q_batch = slice(q_batch, {0, 0}, q_batch.shape(), {1, L / bits}, s);
q = slice_update(q, q_batch, {i, 0}, {i + batch_size, q.shape(-1)}, s); q = slice_update(q, q_batch, {i, 0}, {i + batch_size, q.shape(-1)}, s);
eval(q); // eval(q);
} }
q = reshape(q, {M / Tx, -1, Tx, Ty * bits / L}, s); q = reshape(q, {M / Tx, -1, Tx, Ty * bits / L}, s);

View File

@ -6,7 +6,6 @@ from typing import Any, Literal
import mlx.core as mx import mlx.core as mx
from mlx.nn.layers.base import Module from mlx.nn.layers.base import Module
from mlx.nn.layers.quantized import QuantizedLinear from mlx.nn.layers.quantized import QuantizedLinear
from mlx.nn.layers.viterbi import quantize as trellis_quantize
class Identity(Module): class Identity(Module):

View File

@ -5,7 +5,6 @@ from typing import Callable, Literal, Optional, Union
import mlx.core as mx import mlx.core as mx
from mlx.nn.layers.base import Module from mlx.nn.layers.base import Module
from mlx.nn.layers.viterbi import quantize as trellis_quantize
from mlx.utils import tree_map_with_path from mlx.utils import tree_map_with_path