diff --git a/mlx/backend/cpu/quantized.cpp b/mlx/backend/cpu/quantized.cpp
index 75ff35a2e..ce51cac3a 100644
--- a/mlx/backend/cpu/quantized.cpp
+++ b/mlx/backend/cpu/quantized.cpp
@@ -1,7 +1,5 @@
 // Copyright © 2023 Apple Inc.
 
-#include <cassert>
-
 #include "mlx/backend/cpu/copy.h"
 #include "mlx/backend/cpu/encoder.h"
 #include "mlx/backend/cpu/simd/simd.h"
@@ -838,8 +836,6 @@ void mxfp4_bs_qmm_dispatch(
 } // namespace
 
 void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
-  assert(inputs.size() == 4);
-
   auto& x_pre = inputs[0];
   auto& w_pre = inputs[1];
   auto& scales_pre = inputs[2];
@@ -892,8 +888,6 @@ void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
 }
 
 void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) {
-  assert(inputs.size() == 6);
-
   auto& x_pre = inputs[0];
   auto& w_pre = inputs[1];
   auto& scales_pre = inputs[2];
diff --git a/mlx/backend/metal/quantized.cpp b/mlx/backend/metal/quantized.cpp
index 903c650bc..3886f5b59 100644
--- a/mlx/backend/metal/quantized.cpp
+++ b/mlx/backend/metal/quantized.cpp
@@ -1,7 +1,5 @@
 // Copyright © 2023-2024 Apple Inc.
 
-#include <cassert>
-
 #include "mlx/backend/common/broadcasting.h"
 #include "mlx/backend/common/compiled.h"
 #include "mlx/backend/gpu/copy.h"
diff --git a/python/mlx/nn/layers/quantized.py b/python/mlx/nn/layers/quantized.py
index 5d9c7ae5c..669162e68 100644
--- a/python/mlx/nn/layers/quantized.py
+++ b/python/mlx/nn/layers/quantized.py
@@ -154,7 +154,7 @@ class QuantizedEmbedding(Module):
     ):
         """Create a :obj:`QuantizedEmbedding` layer from an :obj:`Embedding` layer."""
         embedding_dims, dims = embedding_layer.weight.shape
-        ql = cls(embedding_dims, dims, group_size, bits)
+        ql = cls(embedding_dims, dims, group_size, bits, mode=mode)
         ql.weight, *scales_biases = mx.quantize(
             embedding_layer.weight,
             group_size,
@@ -260,7 +260,7 @@ class QuantizedLinear(Module):
    ):
         """Create a :obj:`QuantizedLinear` layer from a :obj:`Linear` layer."""
         output_dims, input_dims = linear_layer.weight.shape
-        ql = cls(input_dims, output_dims, False, group_size, bits)
+        ql = cls(input_dims, output_dims, False, group_size, bits, mode=mode)
         ql.weight, *scales_biases = mx.quantize(
             linear_layer.weight,
             group_size,
diff --git a/python/tests/cuda_skip.py b/python/tests/cuda_skip.py
index 2cc4a6c17..a662ca1a0 100644
--- a/python/tests/cuda_skip.py
+++ b/python/tests/cuda_skip.py
@@ -49,6 +49,8 @@ cuda_skip = {
     "TestQuantized.test_qmm_shapes",
     "TestQuantized.test_qmm_vjp",
     "TestQuantized.test_qmv",
+    "TestQuantized.test_mxfp4_qmv",
+    "TestQuantized.test_mxfp4_qvm",
     "TestQuantized.test_qvm",
     "TestQuantized.test_qvm_splitk",
     "TestQuantized.test_small_matrix",
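
Note on the python/mlx/nn/layers/quantized.py hunks: before this change, from_embedding and from_linear quantized the weights with the caller's arguments but constructed the layer without forwarding mode, so the layer could fall back to the default mode while holding weights quantized under a different one. The following is a minimal usage sketch of the intended behavior after the fix; the mode="mxfp4" value and group_size=32 are illustrative assumptions, not values mandated by this diff.

    import mlx.core as mx
    import mlx.nn as nn

    # Plain float linear layer to convert.
    linear = nn.Linear(512, 512, bias=False)

    # With this fix, the requested quantization mode is propagated to the
    # constructed QuantizedLinear rather than being dropped on construction.
    # mode="mxfp4" and group_size=32 are assumed example settings.
    qlinear = nn.QuantizedLinear.from_linear(
        linear, group_size=32, bits=4, mode="mxfp4"
    )

    # The converted layer can be used like the original one.
    y = qlinear(mx.zeros((1, 512)))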