mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
Get trellis to run
This commit is contained in:
parent
e3d275bc49
commit
998404ada4
@ -120,4 +120,34 @@
|
|||||||
instantiate_quantized_groups(6) \
|
instantiate_quantized_groups(6) \
|
||||||
instantiate_quantized_groups(8)
|
instantiate_quantized_groups(8)
|
||||||
|
|
||||||
|
instantiate_kernel(
|
||||||
|
"trellis_viterbi_float16_t_overlap_0",
|
||||||
|
trellis_viterbi,
|
||||||
|
float16_t,
|
||||||
|
false)
|
||||||
|
instantiate_kernel(
|
||||||
|
"trellis_viterbi_float16_t_overlap_1",
|
||||||
|
trellis_viterbi,
|
||||||
|
float16_t,
|
||||||
|
true)
|
||||||
|
|
||||||
|
instantiate_kernel(
|
||||||
|
"trellis_backtrack_overlap_0",
|
||||||
|
trellis_backtrack,
|
||||||
|
false)
|
||||||
|
instantiate_kernel(
|
||||||
|
"trellis_backtrack_overlap_1",
|
||||||
|
trellis_backtrack,
|
||||||
|
true)
|
||||||
|
|
||||||
|
instantiate_kernel(
|
||||||
|
"qmv_fast_float16_t_gs_64_b_2_batch_0_mode_1",
|
||||||
|
qmv_fast,
|
||||||
|
float16_t,
|
||||||
|
64,
|
||||||
|
2,
|
||||||
|
false,
|
||||||
|
true)
|
||||||
|
|
||||||
|
|
||||||
instantiate_quantized_all() // clang-format on
|
instantiate_quantized_all() // clang-format on
|
||||||
|
@ -1048,7 +1048,8 @@ trellis_quantize(const array& w_, int bits, StreamOrDevice s_) {
|
|||||||
int T = Tx * Ty;
|
int T = Tx * Ty;
|
||||||
auto scale = std(astype(w_, float32, s), s);
|
auto scale = std(astype(w_, float32, s), s);
|
||||||
auto w = divide(w_, scale, s);
|
auto w = divide(w_, scale, s);
|
||||||
w = astype(w, float16, s);
|
w = astype(w, w_.dtype(), s);
|
||||||
|
scale = astype(scale, w_.dtype(), s);
|
||||||
|
|
||||||
w = reshape(w, {M / Tx, Tx, -1, Ty}, s);
|
w = reshape(w, {M / Tx, Tx, -1, Ty}, s);
|
||||||
w = transpose(w, {0, 2, 1, 3}, s);
|
w = transpose(w, {0, 2, 1, 3}, s);
|
||||||
@ -1067,7 +1068,7 @@ trellis_quantize(const array& w_, int bits, StreamOrDevice s_) {
|
|||||||
{w_batch});
|
{w_batch});
|
||||||
q_batch = slice(q_batch, {0, 0}, q_batch.shape(), {1, L / bits}, s);
|
q_batch = slice(q_batch, {0, 0}, q_batch.shape(), {1, L / bits}, s);
|
||||||
q = slice_update(q, q_batch, {i, 0}, {i + batch_size, q.shape(-1)}, s);
|
q = slice_update(q, q_batch, {i, 0}, {i + batch_size, q.shape(-1)}, s);
|
||||||
eval(q);
|
// eval(q);
|
||||||
}
|
}
|
||||||
|
|
||||||
q = reshape(q, {M / Tx, -1, Tx, Ty * bits / L}, s);
|
q = reshape(q, {M / Tx, -1, Tx, Ty * bits / L}, s);
|
||||||
|
@ -6,7 +6,6 @@ from typing import Any, Literal
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
from mlx.nn.layers.base import Module
|
from mlx.nn.layers.base import Module
|
||||||
from mlx.nn.layers.quantized import QuantizedLinear
|
from mlx.nn.layers.quantized import QuantizedLinear
|
||||||
from mlx.nn.layers.viterbi import quantize as trellis_quantize
|
|
||||||
|
|
||||||
|
|
||||||
class Identity(Module):
|
class Identity(Module):
|
||||||
|
@ -5,7 +5,6 @@ from typing import Callable, Literal, Optional, Union
|
|||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
from mlx.nn.layers.base import Module
|
from mlx.nn.layers.base import Module
|
||||||
from mlx.nn.layers.viterbi import quantize as trellis_quantize
|
|
||||||
from mlx.utils import tree_map_with_path
|
from mlx.utils import tree_map_with_path
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user