mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-03 18:18:15 +08:00 
			
		
		
		
	Get trellis to run
This commit is contained in:
		@@ -120,4 +120,34 @@
 | 
			
		||||
  instantiate_quantized_groups(6) \
 | 
			
		||||
  instantiate_quantized_groups(8)
 | 
			
		||||
 | 
			
		||||
instantiate_kernel(
 | 
			
		||||
  "trellis_viterbi_float16_t_overlap_0",
 | 
			
		||||
  trellis_viterbi,
 | 
			
		||||
  float16_t,
 | 
			
		||||
  false)
 | 
			
		||||
instantiate_kernel(
 | 
			
		||||
  "trellis_viterbi_float16_t_overlap_1",
 | 
			
		||||
  trellis_viterbi,
 | 
			
		||||
  float16_t,
 | 
			
		||||
  true)
 | 
			
		||||
 | 
			
		||||
instantiate_kernel(
 | 
			
		||||
  "trellis_backtrack_overlap_0",
 | 
			
		||||
  trellis_backtrack,
 | 
			
		||||
  false)
 | 
			
		||||
instantiate_kernel(
 | 
			
		||||
  "trellis_backtrack_overlap_1",
 | 
			
		||||
  trellis_backtrack,
 | 
			
		||||
  true)
 | 
			
		||||
 | 
			
		||||
instantiate_kernel(
 | 
			
		||||
  "qmv_fast_float16_t_gs_64_b_2_batch_0_mode_1",
 | 
			
		||||
  qmv_fast,
 | 
			
		||||
  float16_t,
 | 
			
		||||
  64,
 | 
			
		||||
  2,
 | 
			
		||||
  false,
 | 
			
		||||
  true)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
instantiate_quantized_all() // clang-format on
 | 
			
		||||
 
 | 
			
		||||
@@ -1048,7 +1048,8 @@ trellis_quantize(const array& w_, int bits, StreamOrDevice s_) {
 | 
			
		||||
  int T = Tx * Ty;
 | 
			
		||||
  auto scale = std(astype(w_, float32, s), s);
 | 
			
		||||
  auto w = divide(w_, scale, s);
 | 
			
		||||
  w = astype(w, float16, s);
 | 
			
		||||
  w = astype(w, w_.dtype(), s);
 | 
			
		||||
  scale = astype(scale, w_.dtype(), s);
 | 
			
		||||
 | 
			
		||||
  w = reshape(w, {M / Tx, Tx, -1, Ty}, s);
 | 
			
		||||
  w = transpose(w, {0, 2, 1, 3}, s);
 | 
			
		||||
@@ -1067,7 +1068,7 @@ trellis_quantize(const array& w_, int bits, StreamOrDevice s_) {
 | 
			
		||||
        {w_batch});
 | 
			
		||||
    q_batch = slice(q_batch, {0, 0}, q_batch.shape(), {1, L / bits}, s);
 | 
			
		||||
    q = slice_update(q, q_batch, {i, 0}, {i + batch_size, q.shape(-1)}, s);
 | 
			
		||||
    eval(q);
 | 
			
		||||
    //    eval(q);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  q = reshape(q, {M / Tx, -1, Tx, Ty * bits / L}, s);
 | 
			
		||||
 
 | 
			
		||||
@@ -6,7 +6,6 @@ from typing import Any, Literal
 | 
			
		||||
import mlx.core as mx
 | 
			
		||||
from mlx.nn.layers.base import Module
 | 
			
		||||
from mlx.nn.layers.quantized import QuantizedLinear
 | 
			
		||||
from mlx.nn.layers.viterbi import quantize as trellis_quantize
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Identity(Module):
 | 
			
		||||
 
 | 
			
		||||
@@ -5,7 +5,6 @@ from typing import Callable, Literal, Optional, Union
 | 
			
		||||
 | 
			
		||||
import mlx.core as mx
 | 
			
		||||
from mlx.nn.layers.base import Module
 | 
			
		||||
from mlx.nn.layers.viterbi import quantize as trellis_quantize
 | 
			
		||||
from mlx.utils import tree_map_with_path
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user