This commit is contained in:
Awni Hannun 2025-08-21 08:58:36 -07:00
parent 8da1c64fe9
commit 3ce23755ea
4 changed files with 4 additions and 10 deletions

View File

@ -1,7 +1,5 @@
// Copyright © 2023 Apple Inc.
#include <cassert>
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/simd/simd.h"
@ -838,8 +836,6 @@ void mxfp4_bs_qmm_dispatch(
} // namespace
void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 4);
auto& x_pre = inputs[0];
auto& w_pre = inputs[1];
auto& scales_pre = inputs[2];
@ -892,8 +888,6 @@ void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
}
void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 6);
auto& x_pre = inputs[0];
auto& w_pre = inputs[1];
auto& scales_pre = inputs[2];

View File

@ -1,7 +1,5 @@
// Copyright © 2023-2024 Apple Inc.
#include <cassert>
#include "mlx/backend/common/broadcasting.h"
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/gpu/copy.h"

View File

@ -154,7 +154,7 @@ class QuantizedEmbedding(Module):
):
"""Create a :obj:`QuantizedEmbedding` layer from an :obj:`Embedding` layer."""
embedding_dims, dims = embedding_layer.weight.shape
ql = cls(embedding_dims, dims, group_size, bits)
ql = cls(embedding_dims, dims, group_size, bits, mode=mode)
ql.weight, *scales_biases = mx.quantize(
embedding_layer.weight,
group_size,
@ -260,7 +260,7 @@ class QuantizedLinear(Module):
):
"""Create a :obj:`QuantizedLinear` layer from a :obj:`Linear` layer."""
output_dims, input_dims = linear_layer.weight.shape
ql = cls(input_dims, output_dims, False, group_size, bits)
ql = cls(input_dims, output_dims, False, group_size, bits, mode=mode)
ql.weight, *scales_biases = mx.quantize(
linear_layer.weight,
group_size,

View File

@ -49,6 +49,8 @@ cuda_skip = {
"TestQuantized.test_qmm_shapes",
"TestQuantized.test_qmm_vjp",
"TestQuantized.test_qmv",
"TestQuantized.test_mxfp4_qmv",
"TestQuantized.test_mxfp4_qvm",
"TestQuantized.test_qvm",
"TestQuantized.test_qvm_splitk",
"TestQuantized.test_small_matrix",