mirror of https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
An initial quantized matmul implementation (#205)
* Add quantized matvec
* Add quantized matrix matrix with 2nd matrix transposed
* Add quantized matmul tests
* Add a slow cpu quantized matmul
* Add a slightly faster vectorized cpu version

This commit is contained in:
parent e6872a4149
commit dfa9f4bc58
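As a rough guide to what the new kernels compute, here is a pure-Python reference of the transposed quantized matmul, inferred from the scalar CPU path added in mlx/backend/common/quantized.cpp. It is an illustrative sketch only (not part of the commit) and assumes, as the kernels do, that K is a multiple of groups and groups is a multiple of 32 // width.

# Reference semantics of y = x @ dequantize(w_packed).T, following the scalar
# CPU kernel's M, N, K naming. Illustrative sketch, not library code.
def qmm_t_reference(x, w_packed, scales, biases, width, groups):
    pack_factor = 32 // width          # elements packed into each uint32
    bitmask = (1 << width) - 1
    M, K = len(x), len(x[0])
    N = len(w_packed)                  # one packed row per output column
    y = [[0.0] * N for _ in range(M)]
    for m in range(M):
        for n in range(N):
            acc = 0.0
            for k in range(0, K, groups):
                # One scale/bias pair is shared by `groups` consecutive weights.
                scale = scales[n][k // groups]
                bias = biases[n][k // groups]
                for kw in range(k, k + groups, pack_factor):
                    wi = w_packed[n][kw // pack_factor]
                    for p in range(pack_factor):
                        # Element p occupies bits [p*width, (p+1)*width).
                        q = (wi >> (p * width)) & bitmask
                        acc += x[m][kw + p] * (scale * q + bias)
            y[m][n] = acc
    return y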
MLX benchmark script (path not shown in this view):

@@ -23,6 +23,16 @@ def none_or_list(x):
         return [int(xi) for xi in x.split(",")]
 
 
+def dtype_from_str(x):
+    if x == "":
+        return mx.float32
+    else:
+        dt = getattr(mx, x)
+        if not isinstance(dt, mx.Dtype):
+            raise ValueError(f"{x} is not an mlx dtype")
+        return dt
+
+
 def bench(f, *args):
     for i in range(10):
         f(*args)
@@ -49,6 +59,15 @@ def matmul(x, y):
     mx.eval(ys)
 
 
+def quant_matmul(x, w, s, b):
+    groups = x.shape[-1] // s.shape[-1]
+    width = 32 // (x.shape[-1] // w.shape[0])
+    ys = []
+    for i in range(10):
+        ys.append(mx.quantized_matmul(x, w, s, b, groups=groups, width=width))
+    mx.eval(ys)
+
+
 def conv1d(x, y):
     ys = []
     for i in range(10):
@@ -296,9 +315,7 @@ if __name__ == "__main__":
     parser.add_argument(
         "--fused", action="store_true", help="Use fused functions where possible"
     )
-    parser.add_argument(
-        "--dtype", choices=["float32", "float16", "bfloat16"], default="float32"
-    )
+    parser.add_argument("--dtype", type=dtype_from_str, default=[], action="append")
 
     args = parser.parse_args()
 
@@ -315,11 +332,15 @@ if __name__ == "__main__":
         mx.set_default_device(mx.cpu)
     else:
         mx.set_default_device(mx.gpu)
-    dtype = dict(float32=mx.float32, float16=mx.float16, bfloat16=mx.bfloat16)[
-        args.dtype
-    ]
+    types = args.dtype
+    if not types:
+        types = [mx.float32]
+    if len(types) < len(args.size):
+        types = types + [types[0]] * (len(args.size) - len(types))
+
     xs = []
-    for size in args.size:
+    for size, dtype in zip(args.size, types):
         xs.append(mx.random.normal(size).astype(dtype))
     for i, t in enumerate(args.transpose):
         if t is None:
@@ -335,6 +356,9 @@ if __name__ == "__main__":
     elif args.benchmark == "matmul":
         print(bench(matmul, *xs))
 
+    elif args.benchmark == "quant_matmul":
+        print(bench(quant_matmul, *xs))
+
     elif args.benchmark == "linear":
         print(bench(linear, *xs))
PyTorch benchmark script (path not shown in this view):

@@ -22,6 +22,16 @@ def none_or_list(x):
         return [int(xi) for xi in x.split(",")]
 
 
+def dtype_from_str(x):
+    if x == "":
+        return torch.float32
+    else:
+        dt = getattr(torch, x)
+        if not isinstance(dt, torch.dtype):
+            raise ValueError(f"{x} is not a torch dtype")
+        return dt
+
+
 def bench(f, *args):
     for i in range(10):
         f(*args)
@@ -312,7 +322,7 @@ if __name__ == "__main__":
     parser.add_argument(
         "--fused", action="store_true", help="Use fused functions where possible"
    )
-    parser.add_argument("--dtype", choices=["float32", "float16"], default="float32")
+    parser.add_argument("--dtype", type=dtype_from_str, default=[], action="append")
 
     args = parser.parse_args()
 
@@ -327,9 +337,15 @@ if __name__ == "__main__":
 
     torch.set_num_threads(1)
     device = "cpu" if args.cpu else "mps"
-    dtype = dict(float32=torch.float32, float16=torch.float16)[args.dtype]
+    types = args.dtype
+    if not types:
+        types = [torch.float32]
+    if len(types) < len(args.size):
+        types = types + [types[0]] * (len(args.size) - len(types))
+
     xs = []
-    for size in args.size:
+    for size, dtype in zip(args.size, types):
         xs.append(torch.randn(*size).to(device).to(dtype))
     for i, t in enumerate(args.transpose):
         if t is None:
@@ -4,6 +4,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
 )
mlx/backend/accelerate/quantized.cpp (new file, +107)
@@ -0,0 +1,107 @@

// Copyright © 2023 Apple Inc.

#include <cassert>

#include <simd/vector.h>

#include "mlx/primitives.h"

namespace mlx::core {

namespace {

void _qmm_t_4_64(
    float* result,
    const float* x,
    const uint32_t* w,
    const float* scales,
    const float* biases,
    int M,
    int N,
    int K) {
  constexpr int width = 4;
  constexpr int groups = 64;
  constexpr int bitmask = (1 << width) - 1;
  constexpr int pack_factor = 32 / width;
  constexpr int packs_in_group = groups / pack_factor;
  const int Kg = K / groups;
  const int Kw = K / pack_factor;

  for (int m = 0; m < M; m++) {
    const uint32_t* w_local = w;
    const float* scales_local = scales;
    const float* biases_local = biases;

    for (int n = 0; n < N; n++) {
      const simd_float16* x_local = (simd_float16*)x;
      simd_float16 sum = 0;
      for (int k = 0; k < K; k += groups) {
        float scale = *scales_local++;
        float bias = *biases_local++;

        for (int kw = 0; kw < packs_in_group; kw += 2) {
          // TODO: vectorize this properly
          simd_uint16 wi;
          for (int e = 0; e < 2; e++) {
            uint32_t wii = *w_local++;
            for (int p = 0; p < 8; p++) {
              wi[e * 8 + p] = wii & bitmask;
              wii >>= width;
            }
          }
          simd_float16 wf = simd_float(wi);
          wf *= scale;
          wf += bias;

          sum += (*x_local) * wf;
          x_local++;
        }
      }

      *result = simd_reduce_add(sum);
      result++;
    }

    x += K;
  }
}

} // namespace

void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 4);

  auto& x = inputs[0];
  auto& w = inputs[1];
  auto& scales = inputs[2];
  auto& biases = inputs[3];

  if (w.strides()[0] != 1) {
    throw std::runtime_error("The quantized weight should be transposed");
  }

  if (!x.flags().row_contiguous || !scales.flags().row_contiguous ||
      !biases.flags().row_contiguous) {
    throw std::runtime_error("x, scales and biases should be row contiguous.");
  }

  if (x.dtype() == float32 && width_ == 4 && groups_ == 64) {
    out.set_data(allocator::malloc_or_wait(out.nbytes()));
    int K = x.shape(-1);
    int M = x.size() / K;
    int N = w.shape(1);
    _qmm_t_4_64(
        out.data<float>(),
        x.data<float>(),
        w.data<uint32_t>(),
        scales.data<float>(),
        biases.data<float>(),
        M,
        N,
        K);
  } else {
    eval(inputs, out);
  }
}

} // namespace mlx::core
@@ -8,6 +8,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/erf.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
@@ -62,6 +62,7 @@ DEFAULT(NotEqual)
 DEFAULT(Pad)
 DEFAULT(Partition)
 DEFAULT(Power)
+DEFAULT(QuantizedMatmul)
 DEFAULT(RandomBits)
 DEFAULT(Reduce)
 DEFAULT(Reshape)
mlx/backend/common/quantized.cpp (new file, +183)
@@ -0,0 +1,183 @@

// Copyright © 2023 Apple Inc.

#include <cassert>

#include "mlx/primitives.h"

namespace mlx::core {

namespace {

template <typename T, int width, int groups>
void _qmm_t(
    T* result,
    const T* x,
    const uint32_t* w,
    const T* scales,
    const T* biases,
    int M,
    int N,
    int K) {
  constexpr int bitmask = (1 << width) - 1;
  constexpr int pack_factor = 32 / width;
  constexpr int packs_in_group = groups / pack_factor;
  const int Kg = K / groups;
  const int Kw = K / pack_factor;

  for (int m = 0; m < M; m++) {
    const uint32_t* w_local = w;
    const T* scales_local = scales;
    const T* biases_local = biases;

    for (int n = 0; n < N; n++) {
      const T* x_local = x;
      T sum = 0;
      for (int k = 0; k < K; k += groups) {
        T scale = *scales_local++;
        T bias = *biases_local++;

        for (int kw = 0; kw < packs_in_group; kw++) {
          uint32_t wi = *w_local++;

#pragma clang loop unroll(full)
          for (int p = 0; p < pack_factor; p++) {
            sum += (*x_local++) * (scale * static_cast<T>(wi & bitmask) + bias);
            wi >>= width;
          }
        }
      }
      *result = sum;
      result++;
    }

    x += K;
  }
}

template <typename T>
void _qmm_t_dispatch_typed(
    T* result,
    const T* x,
    const uint32_t* w,
    const T* scales,
    const T* biases,
    int M,
    int N,
    int K,
    int width,
    int groups) {
  switch (width) {
    case 2: {
      switch (groups) {
        case 64:
          return _qmm_t<T, 2, 64>(result, x, w, scales, biases, M, N, K);
        case 128:
          return _qmm_t<T, 2, 128>(result, x, w, scales, biases, M, N, K);
      }
    }
    case 4: {
      switch (groups) {
        case 64:
          return _qmm_t<T, 4, 64>(result, x, w, scales, biases, M, N, K);
        case 128:
          return _qmm_t<T, 4, 128>(result, x, w, scales, biases, M, N, K);
      }
    }
    case 8: {
      switch (groups) {
        case 64:
          return _qmm_t<T, 8, 64>(result, x, w, scales, biases, M, N, K);
        case 128:
          return _qmm_t<T, 8, 128>(result, x, w, scales, biases, M, N, K);
      }
    }
  }
  std::ostringstream msg;
  msg << "Quantization type not supported. Provided bit width=" << width
      << " and groups=" << groups << ". The supported options are width in "
      << "{2, 4, 8} and groups in {64, 128}.";
  throw std::invalid_argument(msg.str());
}

void _qmm_t_dispatch(
    array out,
    const array& x,
    const array& w,
    const array& scales,
    const array& biases,
    int width,
    int groups) {
  int K = x.shape(-1);
  int M = x.size() / K;
  int N = w.shape(1);

  switch (x.dtype()) {
    case float32:
      _qmm_t_dispatch_typed<float>(
          out.data<float>(),
          x.data<float>(),
          w.data<uint32_t>(),
          scales.data<float>(),
          biases.data<float>(),
          M,
          N,
          K,
          width,
          groups);
      break;
    case float16:
      _qmm_t_dispatch_typed<float16_t>(
          out.data<float16_t>(),
          x.data<float16_t>(),
          w.data<uint32_t>(),
          scales.data<float16_t>(),
          biases.data<float16_t>(),
          M,
          N,
          K,
          width,
          groups);
      break;
    case bfloat16:
      _qmm_t_dispatch_typed<bfloat16_t>(
          out.data<bfloat16_t>(),
          x.data<bfloat16_t>(),
          w.data<uint32_t>(),
          scales.data<bfloat16_t>(),
          biases.data<bfloat16_t>(),
          M,
          N,
          K,
          width,
          groups);
      break;
    default:
      throw std::invalid_argument(
          "[quantized_matmul] only floating types are supported");
  }
}

} // namespace

void QuantizedMatmul::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 4);

  auto& x = inputs[0];
  auto& w = inputs[1];
  auto& scales = inputs[2];
  auto& biases = inputs[3];

  if (w.strides()[0] != 1) {
    throw std::runtime_error("The quantized weight should be transposed");
  }

  if (!x.flags().row_contiguous || !scales.flags().row_contiguous ||
      !biases.flags().row_contiguous) {
    throw std::runtime_error("x, scales and biases should be row contiguous.");
  }

  out.set_data(allocator::malloc_or_wait(out.nbytes()));
  _qmm_t_dispatch(out, x, w, scales, biases, width_, groups_);
}

} // namespace mlx::core
@@ -10,6 +10,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
@@ -18,6 +18,7 @@ set(
   "copy"
   "gemm"
   "gemv"
+  "quantized"
   "random"
   "reduce"
   "scan"
mlx/backend/metal/kernels/quantized.metal (new file, +287)
@@ -0,0 +1,287 @@

// Copyright © 2023 Apple Inc.

#include <metal_stdlib>
#include <metal_simdgroup>

#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/gemm/gemm.h"
#include "mlx/backend/metal/kernels/utils.h"

using namespace metal;

#define MLX_MTL_CONST static constant constexpr const

MLX_MTL_CONST int SIMD_SIZE = 32;

template <typename T, const int BM, const int BN, const int groups, const int width>
[[kernel]] void qmv(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& in_vec_size [[buffer(5)]],
    const constant int& out_vec_size [[buffer(6)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint lid [[thread_index_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {

  static_assert(BN == SIMD_SIZE, "qmv expects BN to be equal to SIMD_SIZE");

  constexpr int bitmask = (1 << width) - 1;
  constexpr int el_per_thread = 32 / width;
  constexpr int colgroup = BN * el_per_thread;
  constexpr int groups_per_block = colgroup / groups;
  constexpr int simdgroups_fetching_vec = colgroup / SIMD_SIZE;

  threadgroup T scales_block[BM * groups_per_block];
  threadgroup T biases_block[BM * groups_per_block];
  threadgroup T x_block[colgroup];

  thread uint32_t w_local;
  thread T result = 0;
  thread T scale = 1;
  thread T bias = 0;
  thread T x_thread[el_per_thread];

  // Adjust positions
  const int in_vec_size_w = in_vec_size / el_per_thread;
  const int in_vec_size_g = in_vec_size / groups;
  int out_row = tid.y * BM + simd_gid;
  w += out_row * in_vec_size_w;
  scales += out_row * in_vec_size_g;
  biases += out_row * in_vec_size_g;
  x += tid.z * in_vec_size;
  y += tid.z * out_vec_size;

  // Loop over in_vec in blocks of colgroup
  for (int i=0; i<in_vec_size; i+=colgroup) {
    // Load the vec to shared memory
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (simd_gid < simdgroups_fetching_vec) {
      x_block[lid] = x[lid + i];
    }
    if (simd_lid == 0) {
      #pragma clang loop unroll(full)
      for (int j=0; j<groups_per_block; j++) {
        scales_block[simd_gid * groups_per_block + j] = scales[i / groups + j];
      }
      #pragma clang loop unroll(full)
      for (int j=0; j<groups_per_block; j++) {
        biases_block[simd_gid * groups_per_block + j] = biases[i / groups + j];
      }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Load in_vec, scale, bias to registers
    #pragma clang loop unroll(full)
    for (int j=0; j<el_per_thread; j++) {
      x_thread[j] = x_block[simd_lid*el_per_thread + j];
    }
    scale = scales_block[simd_gid * groups_per_block + simd_lid * el_per_thread / groups];
    bias = biases_block[simd_gid * groups_per_block + simd_lid * el_per_thread / groups];

    // Load the matrix elements
    w_local = w[i / el_per_thread + simd_lid];

    // Do all the work.
    #pragma clang loop unroll(full)
    for (int k=0; k<el_per_thread; k++) {
      result += (scale * static_cast<T>(w_local & bitmask) + bias) * x_thread[k];
      w_local >>= width;
    }
  }

  // Accumulate in the simdgroup
  result = simd_sum(result);

  // Store the result
  if (simd_lid == 0) {
    y[out_row] = result;
  }
}


template <typename T, const int BM, const int BK, const int BN, const int groups, const int width>
[[kernel]] void qmm_t(
    const device T* x [[buffer(0)]],
    const device uint32_t* w [[buffer(1)]],
    const device T* scales [[buffer(2)]],
    const device T* biases [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& M [[buffer(5)]],
    const constant int& N [[buffer(6)]],
    const constant int& K [[buffer(7)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint lid [[thread_index_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {

  static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE");
  static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE");

  const uint lidy = lid / SIMD_SIZE;

  constexpr int WM = 2;
  constexpr int WN = 2;
  constexpr int bitmask = (1 << width) - 1;
  constexpr int el_per_int = 32 / width;
  constexpr int ints_per_block = BK / el_per_int;
  constexpr int groups_per_block = (BK / groups > 0) ? (BK / groups) : 1;
  constexpr int groups_per_simd = BN / (WM * WN);
  constexpr int w_els_per_thread = (BN * BK / el_per_int) / (SIMD_SIZE * WM * WN);

  // Using the kernel just as a type to instantiate the appropriate BlockMMA
  // and constexpr size calculations
  using mma_t = BlockMMA<T, BM, BN, BK, WM, WN, false, true>;
  using loader_x_t = BlockLoader<T, BM, BK, BK, 4, WM * WN * SIMD_SIZE, false, true, 0>;

  threadgroup T scales_block[BN * groups_per_block];
  threadgroup T biases_block[BN * groups_per_block];
  threadgroup T Xs[BM * BK];
  threadgroup T Ws[BN * BK];

  // Set the block
  const int K_w = K / el_per_int;
  const int K_g = K / groups;
  const int y_row = tid.y * BM;
  const int y_col = tid.x * BN;
  x += y_row * K;
  w += y_col * K_w;
  scales += y_col * K_g;
  biases += y_col * K_g;
  y += y_row * N + y_col;

  // Make the x loader and mma operation
  const short num_els = min(BM, M - y_row);
  loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid);
  mma_t mma_op(simd_gid, simd_lid);

  for (int k=0; k<K; k += BK) {
    threadgroup_barrier(mem_flags::mem_threadgroup);
    // Load the x tile
    if (num_els < BM) {
      loader_x.load_safe(short2(BK, num_els));
    } else {
      loader_x.load_unsafe();
    }

    // Load the scale and bias
    if (simd_lid == 0) {
      threadgroup T *scales_block_local = scales_block + lidy * groups_per_block * groups_per_simd;
      threadgroup T *biases_block_local = biases_block + lidy * groups_per_block * groups_per_simd;
      const device T *scales_local = scales + lidy * groups_per_simd * K_g + k / groups;
      const device T *biases_local = biases + lidy * groups_per_simd * K_g + k / groups;
      #pragma clang loop unroll(full)
      for (int gs=0; gs<groups_per_simd; gs++) {
        #pragma clang loop unroll(full)
        for (int gc=0; gc<groups_per_block; gc++) {
          scales_block_local[gc] = scales_local[gc];
          biases_block_local[gc] = biases_local[gc];
        }
        scales_block_local += groups_per_block;
        scales_local += K_g;
        biases_block_local += groups_per_block;
        biases_local += K_g;
      }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Load the w tile
    {
      for (int wo=0; wo<w_els_per_thread; wo++) {
        int offset = lid * w_els_per_thread + wo;
        int offset_row = offset / (BK / el_per_int);
        int offset_col = offset % (BK / el_per_int);
        const device uint32_t * w_local = w + offset_row * K_w + offset_col;
        threadgroup T * Ws_local = Ws + offset_row * BK + offset_col * el_per_int;

        uint32_t wi = *w_local;
        T scale = scales_block[offset_row * groups_per_block + offset_col / (groups / el_per_int)];
        T bias = biases_block[offset_row * groups_per_block + offset_col / (groups / el_per_int)];

        #pragma clang loop unroll(full)
        for (int t=0; t<el_per_int; t++) {
          Ws_local[t] = scale * static_cast<T>(wi & bitmask) + bias;
          wi >>= width;
        }
      }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Multiply and accumulate threadgroup elements
    mma_op.mma(Xs, Ws);

    // Prepare for next iteration
    loader_x.next();
    w += ints_per_block;
    // scales and biases cannot be advanced because they would have to be
    // advanced every other iteration or sth.
  }

  // Store results to device memory
  threadgroup_barrier(mem_flags::mem_threadgroup);
  if (num_els < BM) {
    mma_op.store_result_safe(y, N, short2(BN, num_els));
  } else {
    mma_op.store_result(y, N);
  }
}


#define instantiate_qmv(name, itype, groups, width) \
  template [[host_name("qmv_n_" #name "_groups_" #groups "_width_" #width)]] \
  [[kernel]] void qmv<itype, 32, 32, groups, width>( \
      const device uint32_t* w [[buffer(0)]], \
      const device itype* scales [[buffer(1)]], \
      const device itype* biases [[buffer(2)]], \
      const device itype* x [[buffer(3)]], \
      device itype* y [[buffer(4)]], \
      const constant int& in_vec_size [[buffer(5)]], \
      const constant int& out_vec_size [[buffer(6)]], \
      uint3 tid [[threadgroup_position_in_grid]], \
      uint lid [[thread_index_in_threadgroup]], \
      uint simd_gid [[simdgroup_index_in_threadgroup]], \
      uint simd_lid [[thread_index_in_simdgroup]]);

#define instantiate_qmv_types(groups, width) \
  instantiate_qmv(float32, float, groups, width) \
  instantiate_qmv(float16, half, groups, width) \
  instantiate_qmv(bfloat16, bfloat16_t, groups, width)

instantiate_qmv_types(128, 2)
instantiate_qmv_types(128, 4)
instantiate_qmv_types(128, 8)
instantiate_qmv_types( 64, 2)
instantiate_qmv_types( 64, 4)
instantiate_qmv_types( 64, 8)

#define instantiate_qmm_t(name, itype, groups, width) \
  template [[host_name("qmm_t_" #name "_groups_" #groups "_width_" #width)]] \
  [[kernel]] void qmm_t<itype, 32, 64, 32, groups, width>( \
      const device itype* x [[buffer(0)]], \
      const device uint32_t* w [[buffer(1)]], \
      const device itype* scales [[buffer(2)]], \
      const device itype* biases [[buffer(3)]], \
      device itype* y [[buffer(4)]], \
      const constant int& M [[buffer(5)]], \
      const constant int& N [[buffer(6)]], \
      const constant int& K [[buffer(7)]], \
      uint3 tid [[threadgroup_position_in_grid]], \
      uint lid [[thread_index_in_threadgroup]], \
      uint simd_gid [[simdgroup_index_in_threadgroup]], \
      uint simd_lid [[thread_index_in_simdgroup]]);

#define instantiate_qmm_t_types(groups, width) \
  instantiate_qmm_t(float32, float, groups, width) \
  instantiate_qmm_t(float16, half, groups, width) \
  instantiate_qmm_t(bfloat16, bfloat16_t, groups, width)

instantiate_qmm_t_types(128, 2)
instantiate_qmm_t_types(128, 4)
instantiate_qmm_t_types(128, 8)
instantiate_qmm_t_types( 64, 2)
instantiate_qmm_t_types( 64, 4)
instantiate_qmm_t_types( 64, 8)
mlx/backend/metal/quantized.cpp (new file, +123)
@@ -0,0 +1,123 @@

// Copyright © 2023 Apple Inc.

#include <cassert>
#include <iostream>

#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"

namespace mlx::core {

void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 4);

  out.set_data(allocator::malloc_or_wait(out.nbytes()));
  auto& s = stream();
  auto& d = metal::device(s.device);

  auto& x_pre = inputs[0];
  auto& w_pre = inputs[1];
  auto& scales_pre = inputs[2];
  auto& biases_pre = inputs[3];

  std::vector<array> copies;
  auto check_transpose = [&copies, &s](const array& arr) {
    auto stx = arr.strides()[arr.ndim() - 2];
    auto sty = arr.strides()[arr.ndim() - 1];
    if (stx == arr.shape(-1) && sty == 1) {
      return std::make_tuple(false, stx, arr);
    } else if (stx == 1 && sty == arr.shape(-2)) {
      return std::make_tuple(true, sty, arr);
    } else {
      array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
      copy_gpu(arr, arr_copy, CopyType::General, s);
      copies.push_back(arr_copy);
      size_t stx = arr.shape(-1);
      return std::make_tuple(false, stx, arr_copy);
    }
  };
  auto [x_transposed, x_cols, x] = check_transpose(x_pre);
  auto [w_transposed, w_cols, w] = check_transpose(w_pre);
  auto [scales_transposed, scales_cols, scales] = check_transpose(scales_pre);
  auto [biases_transposed, biases_cols, biases] = check_transpose(biases_pre);

  if (!w_transposed) {
    throw std::runtime_error("The quantized weight should be transposed.");
  }

  if (x_transposed || scales_transposed || biases_transposed) {
    throw std::runtime_error("x, scales and biases should be row contiguous.");
  }

  int D = x.shape(-1);
  int B = x.size() / D;

  // Route to the qmv kernel
  if (B == 1) {
    std::ostringstream kname;
    kname << "qmv_" << (w_transposed ? "n_" : "t_") << type_to_name(out)
          << "_groups_" << groups_ << "_width_" << width_;

    // Encode and dispatch kernel
    auto compute_encoder = d.get_command_encoder(s.index);
    auto kernel = d.get_kernel(kname.str());
    compute_encoder->setComputePipelineState(kernel);

    int O = w.size() / w_cols;

    int bo = 32;
    int bd = 32;
    MTL::Size group_dims = MTL::Size(bd, bo, 1);
    MTL::Size grid_dims = MTL::Size(1, O / bo, B);

    set_array_buffer(compute_encoder, w, 0);
    set_array_buffer(compute_encoder, scales, 1);
    set_array_buffer(compute_encoder, biases, 2);
    set_array_buffer(compute_encoder, x, 3);
    set_array_buffer(compute_encoder, out, 4);
    compute_encoder->setBytes(&D, sizeof(int), 5);
    compute_encoder->setBytes(&O, sizeof(int), 6);

    compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
  }

  // Route to the qmm kernel
  else {
    std::ostringstream kname;
    kname << "qmm_" << (w_transposed ? "t_" : "n_") << type_to_name(out)
          << "_groups_" << groups_ << "_width_" << width_;

    // Encode and dispatch kernel
    auto compute_encoder = d.get_command_encoder(s.index);
    auto kernel = d.get_kernel(kname.str());
    compute_encoder->setComputePipelineState(kernel);

    int O = w.size() / w_cols;

    int wn = 2;
    int wm = 2;
    int bm = 32;
    int bn = 32;
    int bk = 64;
    MTL::Size group_dims = MTL::Size(32, wn, wm);
    MTL::Size grid_dims = MTL::Size(O / bn, (B + bm - 1) / bm, 1);

    set_array_buffer(compute_encoder, x, 0);
    set_array_buffer(compute_encoder, w, 1);
    set_array_buffer(compute_encoder, scales, 2);
    set_array_buffer(compute_encoder, biases, 3);
    set_array_buffer(compute_encoder, out, 4);
    compute_encoder->setBytes(&B, sizeof(int), 5);
    compute_encoder->setBytes(&O, sizeof(int), 6);
    compute_encoder->setBytes(&D, sizeof(int), 7);

    compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
  }

  d.get_command_buffer(s.index)->addCompletedHandler(
      [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
}

} // namespace mlx::core
@@ -58,6 +58,7 @@ NO_GPU(NotEqual)
 NO_GPU(Pad)
 NO_GPU(Partition)
 NO_GPU(Power)
+NO_GPU(QuantizedMatmul)
 NO_GPU(RandomBits)
 NO_GPU(Reduce)
 NO_GPU(Reshape)
mlx/ops.cpp (+71)
@@ -2564,4 +2564,75 @@ array conv2d(
       {in, wt});
 }
 
+array quantized_matmul(
+    const array& in_x,
+    const array& w,
+    const array& scales,
+    const array& biases,
+    int groups /* = 128 */,
+    int width /* = 4 */,
+    StreamOrDevice s /* = {} */) {
+  auto x = in_x;
+
+  if (w.dtype() != uint32) {
+    std::ostringstream msg;
+    msg << "[quantized_matmul] The weight matrix should be uint32 "
+        << "but received" << w.dtype();
+    throw std::invalid_argument(msg.str());
+  }
+  if (w.ndim() != 2) {
+    std::ostringstream msg;
+    msg << "[quantized_matmul] Batched quantized matmul is not supported for now "
+        << "received w with shape " << w.shape();
+    throw std::invalid_argument(msg.str());
+  }
+
+  // Keep x's batch dimensions to reshape it back after the matmul
+  auto original_shape = x.shape();
+  int x_inner_dims = original_shape.back();
+  original_shape.pop_back();
+
+  // Reshape x into a matrix if it isn't already one
+  if (x.ndim() != 2) {
+    x = reshape(x, {-1, x_inner_dims}, s);
+  }
+
+  int w_inner_dims = w.shape(0) * (32 / width);
+  if (w_inner_dims != x_inner_dims) {
+    std::ostringstream msg;
+    msg << "[quantized_matmul] Last dimension of first input with "
+        << "shape (..., " << x_inner_dims
+        << ") does not match the expanded first "
+        << "dimension of the quantized matrix " << w_inner_dims
+        << ", computed from shape " << w.shape() << " with groups=" << groups
+        << " and width=" << width;
+    throw std::invalid_argument(msg.str());
+  }
+
+  int n_groups = x_inner_dims / groups;
+  if (scales.shape(-1) != n_groups || biases.shape(-1) != n_groups) {
+    std::ostringstream msg;
+    msg << "[quantized_matmul] Scales and biases provided do not match the "
+        << "quantization arguments (groups=" << groups << ", width=" << width
+        << "). Expected shapes (" << w.shape(1) << ", " << x_inner_dims / groups
+        << "), but got scales.shape=" << scales.shape()
+        << " and biases.shape=" << biases.shape();
+    throw std::invalid_argument(msg.str());
+  }
+
+  auto out = array(
+      {x.shape(0), w.shape(1)},
+      x.dtype(),
+      std::make_unique<QuantizedMatmul>(to_stream(s), groups, width),
+      {x, w, scales, biases});
+
+  // If needed reshape x to the original batch shape
+  if (original_shape.size() != 1) {
+    original_shape.push_back(w.shape(1));
+    out = reshape(out, original_shape, s);
+  }
+
+  return out;
+}
+
 } // namespace mlx::core
mlx/ops.h (+10)
@@ -1028,4 +1028,14 @@ array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s = {});
 /** Load array from file in .npy format */
 array load(const std::string& file, StreamOrDevice s = {});
 
+/** Quantized matmul multiplies x with a quantized matrix w */
+array quantized_matmul(
+    const array& x,
+    const array& w,
+    const array& scales,
+    const array& biases,
+    int groups = 128,
+    int width = 4,
+    StreamOrDevice s = {});
+
 } // namespace mlx::core
@@ -1696,6 +1696,31 @@ std::pair<array, int> Power::vmap(
   return {power(a, b, stream()), to_ax};
 }
 
+std::pair<array, int> QuantizedMatmul::vmap(
+    const std::vector<array>& inputs,
+    const std::vector<int>& axes) {
+  throw std::runtime_error("QuantizedMatmul::vmap NYI");
+}
+
+std::vector<array> QuantizedMatmul::vjp(
+    const std::vector<array>& primals,
+    const array& cotan,
+    const std::vector<int>& argnums) {
+  throw std::runtime_error("QuantizedMatmul::vjp NYI");
+}
+
+array QuantizedMatmul::jvp(
+    const std::vector<array>& primals,
+    const std::vector<array>& tangents,
+    const std::vector<int>& argnums) {
+  throw std::runtime_error("QuantizedMatmul::vjp NYI");
+}
+
+bool QuantizedMatmul::is_equivalent(const Primitive& other) const {
+  const QuantizedMatmul& qm_other = static_cast<const QuantizedMatmul&>(other);
+  return groups_ == qm_other.groups_ && width_ == qm_other.width_;
+}
+
 std::pair<array, int> RandomBits::vmap(
     const std::vector<array>& inputs,
     const std::vector<int>& axes) {
@@ -1110,6 +1110,29 @@ class Power : public Primitive {
   void eval(const std::vector<array>& inputs, array& out);
 };
 
+class QuantizedMatmul : public Primitive {
+ public:
+  explicit QuantizedMatmul(Stream stream, int groups, int width)
+      : Primitive(stream), groups_(groups), width_(width){};
+
+  void eval_cpu(const std::vector<array>& inputs, array& out) override;
+  void eval_gpu(const std::vector<array>& inputs, array& out) override;
+
+  std::pair<array, int> vmap(
+      const std::vector<array>& inputs,
+      const std::vector<int>& axes) override;
+
+  DEFINE_GRADS()
+  DEFINE_PRINT(QuantizedMatmul)
+  bool is_equivalent(const Primitive& other) const override;
+
+ private:
+  int groups_;
+  int width_;
+
+  void eval(const std::vector<array>& inputs, array& out);
+};
+
 class RandomBits : public Primitive {
  public:
  explicit RandomBits(Stream stream, const std::vector<int>& shape, int width)
@@ -2977,4 +2977,36 @@ void init_ops(py::module_& m) {
       Returns:
           result (array): An array of the same type as ``a`` rounded to the given number of decimals.
   )pbdoc");
+  m.def(
+      "quantized_matmul",
+      &quantized_matmul,
+      "x"_a,
+      "w"_a,
+      py::pos_only(),
+      "scales"_a,
+      "biases"_a,
+      "groups"_a = 128,
+      "width"_a = 4,
+      py::kw_only(),
+      "stream"_a = none,
+      R"pbdoc(
+        quantized_matmul(x: array, w: array, scales: array, biases: array, /, groups: int = 128, width: int = 4, *, stream: Union[None, Stream, Device] = None) -> array
+
+        Perform the matrix multiplication with the quantized matrix ``w``. The
+        quantization uses one floating point scale and bias per ``groups`` of
+        elements. Each element in ``w`` takes ``width`` bits and is packed in an
+        unsigned 32 bit integer.
+
+        Args:
+          x (array): Input array
+          w (array): Quantized matrix packed in unsigned integers
+          scales (array): The scales to use per ``groups`` elements of ``w``
+          biases (array): The biases to use per ``groups`` elements of ``w``
+          groups (int): The size of the group in ``w`` that shares a scale and
+            bias. (default: 128)
+          width (int): The bitwidth of the elements in ``w``. (default: 4)
+
+        Returns:
+          result (array): The result of the multiplication of ``x`` with ``w``.
+      )pbdoc");
 }
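To make the shape conventions in the docstring concrete, here is a small Python sketch that mirrors the quantize helper from python/tests/test_quantized.py (shown in full below). It is illustrative only; the quantize function is test code in this commit, not a library API.

import mlx.core as mx

# Pack a (N, K) float matrix into width-bit values, one scale/bias per
# `groups` elements, mirroring the helper in python/tests/test_quantized.py.
def quantize(w, width, groups):
    w = w.reshape(len(w), -1, groups)
    w_min = w.min(-1, keepdims=True)
    delta = (w.max(-1, keepdims=True) - w_min) / (2**width - 1)
    w_int = mx.round((w - w_min) / delta).astype(mx.uint32)
    shifts = mx.array([2**i for i in range(0, 32, width)], dtype=mx.uint32)
    packed = (w_int.reshape(len(w), -1, 32 // width) * shifts[None, None]).sum(-1)
    return packed, delta.squeeze(-1), w_min.squeeze(-1)

N, K, width, groups = 512, 1024, 4, 128
w = mx.random.normal(shape=(N, K))
w_q, scales, biases = quantize(w, width, groups)
# w_q packs 32 // width = 8 elements per uint32 -> shape (N, K // 8)
# scales/biases hold one value per group        -> shape (N, K // groups)
x = mx.random.normal(shape=(3, K))
# The op expects the packed matrix transposed, i.e. shape (K // 8, N).
y = mx.quantized_matmul(x, w_q.T, scales, biases, groups=groups, width=width)
print(y.shape)  # (3, N)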
python/tests/test_quantized.py (new file, +112)
@@ -0,0 +1,112 @@

# Copyright © 2023 Apple Inc.

import unittest

import mlx.core as mx
import mlx_tests


def select_bits(w, width, start):
    shift_left = 32 - (start + width)
    shift_right = shift_left + start
    return (w * (2**shift_left)) // (2**shift_right)


def dequantize(w, scales, biases, width):
    w_full = mx.concatenate(
        [select_bits(w, width, i)[..., None] for i in range(0, 32, width)], axis=-1
    )
    w_full = w_full.reshape(len(w), scales.shape[-1], -1)
    w_full = scales[..., None] * w_full + biases[..., None]
    w_full = w_full.reshape(len(w), -1)

    return w_full


def quantize(w, width, groups):
    w = w.reshape(len(w), -1, groups)
    w_max = w.max(-1, keepdims=True)
    w_min = w.min(-1, keepdims=True)
    delta = (w_max - w_min) / (2**width - 1)

    w_int = mx.round((w - w_min) / delta).astype(mx.uint32)
    scales = delta.squeeze(-1)
    biases = w_min.squeeze(-1)

    shifts = mx.array([2**i for i in range(0, 32, width)], dtype=mx.uint32)
    w_int = w_int.reshape(len(w), -1, 32 // width)
    w_int = w_int * shifts[None, None]
    packed_w = w_int.sum(-1)

    return packed_w, scales, biases


class TestQuantized(mlx_tests.MLXTestCase):
    def test_quantize_dequantize(self):
        w = mx.random.normal(shape=(128, 128))
        w_q, scales, biases = quantize(w, 4, 64)
        w_hat = dequantize(w_q, scales, biases, 4)
        w_hat2 = dequantize(*quantize(w_hat, 4, 64), 4)
        self.assertLess((w_hat - w_hat2).abs().max(), 1e-6)

    def test_qmm(self):
        key = mx.random.key(0)
        k1, k2 = mx.random.split(key)
        for groups in [128, 64]:
            for width in [2, 4, 8]:
                for M in [8, 32, 33, 64]:
                    for N in [512, 1024]:
                        for K in [512, 1024]:
                            with self.subTest(
                                shape=(M, N, K), groups=groups, width=width
                            ):
                                x = mx.random.normal(shape=(M, K), key=k1)
                                w = mx.random.normal(shape=(N, K), key=k2)
                                w_q, scales, biases = quantize(w, width, groups)
                                w_hat = dequantize(w_q, scales, biases, width)
                                y_q = mx.quantized_matmul(
                                    x, w_q.T, scales, biases, width=width, groups=groups
                                )
                                y_hat = x @ w_hat.T
                                self.assertEqual(y_q.shape, y_hat.shape)
                                self.assertLess((y_q - y_hat).abs().max(), 0.1)

    def test_qmm_shapes(self):
        key = mx.random.key(0)
        k1, k2 = mx.random.split(key)
        groups = 64
        width = 4
        w = mx.random.normal(shape=(32, 128), key=k2)
        w_q, scales, biases = quantize(w, width, groups)
        w_hat = dequantize(w_q, scales, biases, width)
        for s in [(3, 128), (2, 1, 7, 128)]:
            x = mx.random.normal(shape=(3, 128), key=k1)
            y_q = mx.quantized_matmul(
                x, w_q.T, scales, biases, width=width, groups=groups
            )
            y_hat = x @ w_hat.T
            self.assertEqual(y_q.shape, y_hat.shape)
            self.assertLess((y_q - y_hat).abs().max(), 0.1)

    def test_qmv(self):
        key = mx.random.key(0)
        k1, k2 = mx.random.split(key)
        for groups in [128, 64]:
            for width in [2, 4, 8]:
                for M in [512, 1024]:
                    for N in [512, 1024]:
                        # with self.subTest(shape=(M, N), groups=groups, width=width):
                        x = mx.random.normal(shape=(1, N), key=k1)
                        w = mx.random.normal(shape=(M, N), key=k2)
                        w_q, scales, biases = quantize(w, width, groups)
                        w_hat = dequantize(w_q, scales, biases, width)
                        y_q = mx.quantized_matmul(
                            x, w_q.T, scales, biases, width=width, groups=groups
                        )
                        y_hat = x @ w_hat.T
                        self.assertEqual(y_q.shape, y_hat.shape)
                        self.assertLess((y_q - y_hat).abs().max(), 0.1)


if __name__ == "__main__":
    unittest.main()