From 998404ada4a92f33839fa4ef9ecabf055cdcbd02 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Sat, 26 Apr 2025 07:02:20 -0700 Subject: [PATCH] Get trellis to run --- mlx/backend/metal/kernels/quantized.metal | 30 +++++++++++++++++++++++ mlx/fast.cpp | 5 ++-- python/mlx/nn/layers/linear.py | 1 - python/mlx/nn/layers/quantized.py | 1 - 4 files changed, 33 insertions(+), 4 deletions(-) diff --git a/mlx/backend/metal/kernels/quantized.metal b/mlx/backend/metal/kernels/quantized.metal index 7af554437..e3efc3bef 100644 --- a/mlx/backend/metal/kernels/quantized.metal +++ b/mlx/backend/metal/kernels/quantized.metal @@ -120,4 +120,34 @@ instantiate_quantized_groups(6) \ instantiate_quantized_groups(8) +instantiate_kernel( + "trellis_viterbi_float16_t_overlap_0", + trellis_viterbi, + float16_t, + false) +instantiate_kernel( + "trellis_viterbi_float16_t_overlap_1", + trellis_viterbi, + float16_t, + true) + +instantiate_kernel( + "trellis_backtrack_overlap_0", + trellis_backtrack, + false) +instantiate_kernel( + "trellis_backtrack_overlap_1", + trellis_backtrack, + true) + +instantiate_kernel( + "qmv_fast_float16_t_gs_64_b_2_batch_0_mode_1", + qmv_fast, + float16_t, + 64, + 2, + false, + true) + + instantiate_quantized_all() // clang-format on diff --git a/mlx/fast.cpp b/mlx/fast.cpp index 402c4999f..b59313a5d 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -1048,7 +1048,8 @@ trellis_quantize(const array& w_, int bits, StreamOrDevice s_) { int T = Tx * Ty; auto scale = std(astype(w_, float32, s), s); auto w = divide(w_, scale, s); - w = astype(w, float16, s); + w = astype(w, w_.dtype(), s); + scale = astype(scale, w_.dtype(), s); w = reshape(w, {M / Tx, Tx, -1, Ty}, s); w = transpose(w, {0, 2, 1, 3}, s); @@ -1067,7 +1068,7 @@ trellis_quantize(const array& w_, int bits, StreamOrDevice s_) { {w_batch}); q_batch = slice(q_batch, {0, 0}, q_batch.shape(), {1, L / bits}, s); q = slice_update(q, q_batch, {i, 0}, {i + batch_size, q.shape(-1)}, s); - eval(q); + // eval(q); } q = reshape(q, {M / Tx, -1, Tx, Ty * bits / L}, s); diff --git a/python/mlx/nn/layers/linear.py b/python/mlx/nn/layers/linear.py index 038dc3c58..877aab6c2 100644 --- a/python/mlx/nn/layers/linear.py +++ b/python/mlx/nn/layers/linear.py @@ -6,7 +6,6 @@ from typing import Any, Literal import mlx.core as mx from mlx.nn.layers.base import Module from mlx.nn.layers.quantized import QuantizedLinear -from mlx.nn.layers.viterbi import quantize as trellis_quantize class Identity(Module): diff --git a/python/mlx/nn/layers/quantized.py b/python/mlx/nn/layers/quantized.py index 2dca6cc8b..7d867b187 100644 --- a/python/mlx/nn/layers/quantized.py +++ b/python/mlx/nn/layers/quantized.py @@ -5,7 +5,6 @@ from typing import Callable, Literal, Optional, Union import mlx.core as mx from mlx.nn.layers.base import Module -from mlx.nn.layers.viterbi import quantize as trellis_quantize from mlx.utils import tree_map_with_path