mirror of https://github.com/ml-explore/mlx.git
synced 2025-08-28 06:07:46 +08:00

commit 3ce23755ea (parent 8da1c64fe9)

    fix
@@ -1,7 +1,5 @@
 // Copyright © 2023 Apple Inc.
 
-#include <cassert>
-
 #include "mlx/backend/cpu/copy.h"
 #include "mlx/backend/cpu/encoder.h"
 #include "mlx/backend/cpu/simd/simd.h"
@@ -838,8 +836,6 @@ void mxfp4_bs_qmm_dispatch(
 } // namespace
 
 void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
-  assert(inputs.size() == 4);
-
   auto& x_pre = inputs[0];
   auto& w_pre = inputs[1];
   auto& scales_pre = inputs[2];
@@ -892,8 +888,6 @@ void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
 }
 
 void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) {
-  assert(inputs.size() == 6);
-
   auto& x_pre = inputs[0];
   auto& w_pre = inputs[1];
   auto& scales_pre = inputs[2];
@@ -1,7 +1,5 @@
 // Copyright © 2023-2024 Apple Inc.
 
-#include <cassert>
-
 #include "mlx/backend/common/broadcasting.h"
 #include "mlx/backend/common/compiled.h"
 #include "mlx/backend/gpu/copy.h"
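The removed asserts pinned QuantizedMatmul to exactly 4 inputs and GatherQMM to 6, counts that assume every quantized weight carries both scales and biases; a biasless mode such as the mxfp4 one exercised by the tests below ships one array fewer per weight. A minimal sketch of that arity difference, assuming a mode keyword on mx.quantize whose exact form at this commit is inferred from the diff, not confirmed:

    import mlx.core as mx

    w = mx.random.normal((64, 64))

    # Affine quantization (the default) returns three arrays:
    # packed weights, per-group scales, and per-group biases.
    w_q, scales, biases = mx.quantize(w, group_size=32, bits=4)

    # A biasless mode returns one array fewer, so the kernel's input
    # list shrinks and a fixed-size assert becomes wrong.
    # (Assumed call: the mode kwarg is inferred from this commit.)
    packed = mx.quantize(w, group_size=32, bits=4, mode="mxfp4")
    print(len(packed))  # expected 2: packed weights and scales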
@@ -154,7 +154,7 @@ class QuantizedEmbedding(Module):
     ):
         """Create a :obj:`QuantizedEmbedding` layer from an :obj:`Embedding` layer."""
         embedding_dims, dims = embedding_layer.weight.shape
-        ql = cls(embedding_dims, dims, group_size, bits)
+        ql = cls(embedding_dims, dims, group_size, bits, mode=mode)
         ql.weight, *scales_biases = mx.quantize(
             embedding_layer.weight,
             group_size,
@@ -260,7 +260,7 @@ class QuantizedLinear(Module):
     ):
         """Create a :obj:`QuantizedLinear` layer from a :obj:`Linear` layer."""
         output_dims, input_dims = linear_layer.weight.shape
-        ql = cls(input_dims, output_dims, False, group_size, bits)
+        ql = cls(input_dims, output_dims, False, group_size, bits, mode=mode)
         ql.weight, *scales_biases = mx.quantize(
             linear_layer.weight,
             group_size,
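With mode now forwarded to the constructor in both classmethods, a converted layer is built with the same quantization mode its weights were quantized with. A hedged usage sketch; the mode="mxfp4" string follows the test names below and is an assumption, not a confirmed API value at this commit:

    import mlx.nn as nn

    linear = nn.Linear(64, 64)

    # Default affine quantization: unchanged behavior.
    ql = nn.QuantizedLinear.from_linear(linear, group_size=32, bits=4)

    # Before this fix, mode reached mx.quantize but not cls(...), so the
    # layer itself was constructed as affine regardless of the request.
    ql_mx = nn.QuantizedLinear.from_linear(linear, mode="mxfp4")  # assumed mode value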
@@ -49,6 +49,8 @@ cuda_skip = {
     "TestQuantized.test_qmm_shapes",
     "TestQuantized.test_qmm_vjp",
     "TestQuantized.test_qmv",
+    "TestQuantized.test_mxfp4_qmv",
+    "TestQuantized.test_mxfp4_qvm",
     "TestQuantized.test_qvm",
     "TestQuantized.test_qvm_splitk",
     "TestQuantized.test_small_matrix",