mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-29 18:26:41 +08:00
fix
This commit is contained in:
parent
8da1c64fe9
commit
3ce23755ea
@ -1,7 +1,5 @@
|
|||||||
// Copyright © 2023 Apple Inc.
|
// Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
#include <cassert>
|
|
||||||
|
|
||||||
#include "mlx/backend/cpu/copy.h"
|
#include "mlx/backend/cpu/copy.h"
|
||||||
#include "mlx/backend/cpu/encoder.h"
|
#include "mlx/backend/cpu/encoder.h"
|
||||||
#include "mlx/backend/cpu/simd/simd.h"
|
#include "mlx/backend/cpu/simd/simd.h"
|
||||||
@ -838,8 +836,6 @@ void mxfp4_bs_qmm_dispatch(
|
|||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
|
void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 4);
|
|
||||||
|
|
||||||
auto& x_pre = inputs[0];
|
auto& x_pre = inputs[0];
|
||||||
auto& w_pre = inputs[1];
|
auto& w_pre = inputs[1];
|
||||||
auto& scales_pre = inputs[2];
|
auto& scales_pre = inputs[2];
|
||||||
@ -892,8 +888,6 @@ void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 6);
|
|
||||||
|
|
||||||
auto& x_pre = inputs[0];
|
auto& x_pre = inputs[0];
|
||||||
auto& w_pre = inputs[1];
|
auto& w_pre = inputs[1];
|
||||||
auto& scales_pre = inputs[2];
|
auto& scales_pre = inputs[2];
|
||||||
|
@ -1,7 +1,5 @@
|
|||||||
// Copyright © 2023-2024 Apple Inc.
|
// Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
#include <cassert>
|
|
||||||
|
|
||||||
#include "mlx/backend/common/broadcasting.h"
|
#include "mlx/backend/common/broadcasting.h"
|
||||||
#include "mlx/backend/common/compiled.h"
|
#include "mlx/backend/common/compiled.h"
|
||||||
#include "mlx/backend/gpu/copy.h"
|
#include "mlx/backend/gpu/copy.h"
|
||||||
|
@ -154,7 +154,7 @@ class QuantizedEmbedding(Module):
|
|||||||
):
|
):
|
||||||
"""Create a :obj:`QuantizedEmbedding` layer from an :obj:`Embedding` layer."""
|
"""Create a :obj:`QuantizedEmbedding` layer from an :obj:`Embedding` layer."""
|
||||||
embedding_dims, dims = embedding_layer.weight.shape
|
embedding_dims, dims = embedding_layer.weight.shape
|
||||||
ql = cls(embedding_dims, dims, group_size, bits)
|
ql = cls(embedding_dims, dims, group_size, bits, mode=mode)
|
||||||
ql.weight, *scales_biases = mx.quantize(
|
ql.weight, *scales_biases = mx.quantize(
|
||||||
embedding_layer.weight,
|
embedding_layer.weight,
|
||||||
group_size,
|
group_size,
|
||||||
@ -260,7 +260,7 @@ class QuantizedLinear(Module):
|
|||||||
):
|
):
|
||||||
"""Create a :obj:`QuantizedLinear` layer from a :obj:`Linear` layer."""
|
"""Create a :obj:`QuantizedLinear` layer from a :obj:`Linear` layer."""
|
||||||
output_dims, input_dims = linear_layer.weight.shape
|
output_dims, input_dims = linear_layer.weight.shape
|
||||||
ql = cls(input_dims, output_dims, False, group_size, bits)
|
ql = cls(input_dims, output_dims, False, group_size, bits, mode=mode)
|
||||||
ql.weight, *scales_biases = mx.quantize(
|
ql.weight, *scales_biases = mx.quantize(
|
||||||
linear_layer.weight,
|
linear_layer.weight,
|
||||||
group_size,
|
group_size,
|
||||||
|
@ -49,6 +49,8 @@ cuda_skip = {
|
|||||||
"TestQuantized.test_qmm_shapes",
|
"TestQuantized.test_qmm_shapes",
|
||||||
"TestQuantized.test_qmm_vjp",
|
"TestQuantized.test_qmm_vjp",
|
||||||
"TestQuantized.test_qmv",
|
"TestQuantized.test_qmv",
|
||||||
|
"TestQuantized.test_mxfp4_qmv",
|
||||||
|
"TestQuantized.test_mxfp4_qvm",
|
||||||
"TestQuantized.test_qvm",
|
"TestQuantized.test_qvm",
|
||||||
"TestQuantized.test_qvm_splitk",
|
"TestQuantized.test_qvm_splitk",
|
||||||
"TestQuantized.test_small_matrix",
|
"TestQuantized.test_small_matrix",
|
||||||
|
Loading…
Reference in New Issue
Block a user