Mirror of https://github.com/ml-explore/mlx.git
Synced 2025-09-06 16:51:24 +08:00

Compare commits: 7 commits, v0.19.3 ... quantized-
Commits:

- f5b0f11968
- b509c2ad76
- 852336b8a2
- 6649244686
- 047a584e3d
- ef14b1e9c3
- 5824626c0b
@@ -1,10 +1,9 @@
import argparse
import math

import mlx.core as mx
import numpy as np
from mlx.utils import tree_map
from time_utils import time_fn

L = 1024
L = 65536
H = 32
H_k = 32 // 4
D = 128
@@ -23,27 +22,60 @@ def attention(q, k, v):


def sdpa(q, k, v):
    return mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0)
    return mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=None)


def time_self_attention_primitives():
    mx.random.seed(3)
    q = mx.random.uniform(shape=(1, H, 1, D))
    k = mx.random.uniform(shape=(1, H_k, L, D))
    v = mx.random.uniform(shape=(1, H_k, L, D))
    mx.eval(q, k, v)
def quant_sdpa(q, k, v, bits=4):
    return mx.fast.quantized_scaled_dot_product_attention(
        q, *k, *v, scale=1.0, mask=None, bits=bits
    )


def quant_attention(q, k, v, bits=4):
    B, Hq, L, D = q.shape
    Hk = k[0].shape[1]

    q = q.reshape((B, Hk, Hq // Hk, L, D))
    k = tree_map(lambda x: mx.expand_dims(x, axis=2), k)
    v = tree_map(lambda x: mx.expand_dims(x, axis=2), v)

    scores = mx.quantized_matmul(q, *k, transpose=True, bits=bits)
    scores = mx.softmax(scores, axis=-1)

    out = mx.quantized_matmul(scores, *v, transpose=False, bits=bits)
    out = out.reshape((B, Hq, L, D))
    return out


def time_self_attention_primitives(q, k, v):
    time_fn(attention, q, k, v)


def time_self_attention_sdpa():
def time_self_attention_sdpa(q, k, v):
    time_fn(sdpa, q, k, v)


def time_self_attention_quant_sdpa(q, k, v, bits=4):
    time_fn(quant_sdpa, q, k, v, bits)


def time_self_attention_quant_primitives(q, k, v, bits=4):
    time_fn(quant_attention, q, k, v)


if __name__ == "__main__":
    mx.random.seed(3)
    q = mx.random.uniform(shape=(1, H, 1, D))
    k = mx.random.uniform(shape=(1, H_k, L, D))
    v = mx.random.uniform(shape=(1, H_k, L, D))
    mx.eval(q, k, v)
    time_fn(sdpa, q, k, v)

    bits = 4
    k_quant = mx.quantize(k, bits=bits)
    v_quant = mx.quantize(v, bits=bits)
    mx.eval(k_quant, v_quant)

if __name__ == "__main__":
    time_self_attention_sdpa()
    time_self_attention_primitives()
    time_self_attention_sdpa(q, k, v)
    time_self_attention_quant_sdpa(q, k_quant, v_quant, bits)
    time_self_attention_primitives(q, k, v)
    time_self_attention_quant_primitives(q, k_quant, v_quant, bits)
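A note on the benchmark above: `quant_sdpa` and `quant_attention` receive the whole result of `mx.quantize` and splat it with `*k` and `*v`, since `mx.quantize` returns a `(quantized, scales, biases)` tuple. The sketch below spells out that unpacking; the shapes are illustrative, not taken from the benchmark.

```python
# Minimal sketch of the unpacking used by the benchmark, assuming the MLX
# convention that mx.quantize returns a (quantized, scales, biases) tuple.
import mlx.core as mx

k = mx.random.uniform(shape=(1, 8, 4096, 128))
k_quant = mx.quantize(k, group_size=64, bits=4)  # -> (k_q, k_scales, k_biases)
k_q, k_scales, k_biases = k_quant

# quant_sdpa(q, *k_quant, *v_quant, ...) is therefore equivalent to passing
# the three parts of each cache tensor explicitly:
#   mx.fast.quantized_scaled_dot_product_attention(
#       q, k_q, k_scales, k_biases, v_q, v_scales, v_biases,
#       scale=1.0, mask=None, bits=4)
```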
@@ -1,5 +1,6 @@
// Copyright © 2023-2024 Apple Inc.

#include <iostream>
#include <sstream>

#include "mlx/backend/metal/copy.h"

@@ -15,6 +15,7 @@ void CustomKernel::eval_gpu(
  std::vector<array> copies;

  for (auto& out : outputs) {
    // Copy from previous kernel
    out.set_data(allocator::malloc_or_wait(out.nbytes()));
    if (init_value_) {
      copies.emplace_back(init_value_.value(), out.dtype());
@@ -1737,13 +1737,13 @@ template <
}

template <typename T, const int group_size, const int bits>
[[kernel]] void affine_quantize(
    const device T* w [[buffer(0)]],
    device uint8_t* out [[buffer(1)]],
    device T* scales [[buffer(2)]],
    device T* biases [[buffer(3)]],
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
METAL_FUNC void affine_quantize_impl(
    const device T* w,
    device uint8_t* out,
    device T* scales,
    device T* biases,
    uint2 index,
    uint2 grid_dim) {
  constexpr T eps = T(1e-7);
  constexpr int simd_size = 32;
  constexpr int uint8_bits = 8;
@@ -1820,6 +1820,18 @@ template <typename T, const int group_size, const int bits>
  }
}

template <typename T, const int group_size, const int bits>
[[kernel]] void affine_quantize(
    const device T* w [[buffer(0)]],
    device uint8_t* out [[buffer(1)]],
    device T* scales [[buffer(2)]],
    device T* biases [[buffer(3)]],
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  affine_quantize_impl<T, group_size, bits>(
      w, out, scales, biases, index, grid_dim);
}

template <typename T, const int group_size, const int bits>
[[kernel]] void affine_quantize_scales_biases(
    const device T* w [[buffer(0)]],
@@ -1883,3 +1895,41 @@ template <typename T, const int group_size, const int bits>
    out[oindex + i] = scale * d + bias;
  }
}

template <typename T, const int group_size, const int bits>
[[kernel]] void kv_update(
    const device T* new_keys [[buffer(0)]],
    const device T* new_values [[buffer(1)]],
    device uint8_t* keys [[buffer(2)]],
    device T* key_scales [[buffer(3)]],
    device T* key_biases [[buffer(4)]],
    device uint8_t* values [[buffer(5)]],
    device T* value_scales [[buffer(6)]],
    device T* value_biases [[buffer(7)]],
    const constant int& offset,
    const constant int& batch_stride,
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  // Get the right offset in the thing
  // Need to use the head dim too
  constexpr int pack_factor = 8 / bits;
  uint batch_idx = index.y * batch_stride * 4 + offset;
  new_keys += index.y * 128;
  new_values += index.y * 128;
  // uint batch_idx = offset;
  // // Index to correct slice
  uint group_idx = batch_idx * pack_factor / group_size;
  keys += batch_idx;
  key_scales += group_idx;
  key_biases += group_idx;
  values += batch_idx;
  value_scales += group_idx;
  value_biases += group_idx;

  uint2 new_index = {index.x, 0};

  affine_quantize_impl<T, group_size, bits>(
      new_keys, keys, key_scales, key_biases, new_index, grid_dim);
  affine_quantize_impl<T, group_size, bits>(
      new_values, values, value_scales, value_biases, new_index, grid_dim);
}
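The `kv_update` kernel locates where a new row of keys or values lands in the packed cache with byte and group arithmetic: at `bits = 4` there are `8 / bits = 2` quantized elements per output byte, and every `group_size` elements share one scale and bias. The sketch below reproduces that arithmetic in plain Python for illustration only; the helper name and the example values (head dim 128, mirroring the hard-coded constants in this work-in-progress kernel) are assumptions.

```python
# Rough sketch of the kv_update pointer arithmetic, for illustration only.
bits = 4
group_size = 64
head_dim = 128

pack_factor = 8 // bits                   # quantized elements per packed byte
row_bytes = head_dim // pack_factor       # bytes per key/value row (64 here)
groups_per_row = head_dim // group_size   # scales/biases per row (2 here)

def cache_offsets(row: int, byte_in_row: int = 0):
    """Byte offset into the packed cache and index into scales/biases."""
    batch_idx = row * row_bytes + byte_in_row
    group_idx = batch_idx * pack_factor // group_size
    return batch_idx, group_idx

print(cache_offsets(0))  # (0, 0)
print(cache_offsets(1))  # (64, 2): next row starts 64 bytes and 2 groups later
```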
@@ -927,19 +927,7 @@ instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 128, 2, 2);

// SDPA vector instantiations
#define instantiate_sdpa_vector(type, head_dim) \
  template [[host_name("sdpa_vector_" #type "_" #head_dim)]] \
  [[kernel]] void sdpa_vector<type, head_dim>( \
      const device type* queries [[buffer(0)]], \
      const device type* keys [[buffer(1)]], \
      const device type* values [[buffer(2)]], \
      device type* out [[buffer(3)]], \
      const constant int& gqa_factor, \
      const constant int& N, \
      const constant size_t& k_stride, \
      const constant float& scale, \
      uint3 tid [[threadgroup_position_in_grid]], \
      uint simd_gid [[simdgroup_index_in_threadgroup]], \
      uint simd_lid [[thread_index_in_simdgroup]]);
  instantiate_kernel("sdpa_vector_" #type "_" #head_dim, sdpa_vector, type, head_dim)

#define instantiate_sdpa_vector_heads(type) \
  instantiate_sdpa_vector(type, 64) \
@@ -949,4 +937,30 @@ instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 128, 2, 2);
instantiate_sdpa_vector_heads(float)
instantiate_sdpa_vector_heads(bfloat16_t)
instantiate_sdpa_vector_heads(float16_t)

// Quantized SDPA vector instantiations
#define instantiate_quant_sdpa_vector(type, head_dim, group_size, bits) \
  instantiate_kernel( \
      "quant_sdpa_vector_" #type "_" #head_dim "_" #group_size "_" #bits, \
      quant_sdpa_vector, type, head_dim, group_size, bits)

#define instantiate_quant_sdpa_vector_bits(type, heads, group_size) \
  instantiate_quant_sdpa_vector(type, heads, group_size, 4) \
  instantiate_quant_sdpa_vector(type, heads, group_size, 8)

#define instantiate_quant_sdpa_vector_group_size(type, heads) \
  instantiate_quant_sdpa_vector_bits(type, heads, 32) \
  instantiate_quant_sdpa_vector_bits(type, heads, 64) \
  instantiate_quant_sdpa_vector_bits(type, heads, 128)

#define instantiate_quant_sdpa_vector_heads(type) \
  instantiate_quant_sdpa_vector_group_size(type, 64) \
  instantiate_quant_sdpa_vector_group_size(type, 96) \
  instantiate_quant_sdpa_vector_group_size(type, 128)

instantiate_quant_sdpa_vector_heads(float)
instantiate_quant_sdpa_vector_heads(bfloat16_t)
instantiate_quant_sdpa_vector_heads(float16_t)

// clang-format on
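These macros stamp out one kernel specialization per (dtype, head_dim, group_size, bits) combination, and the host side later looks the kernel up by the same concatenated name (see the `quant_sdpa_vector` dispatch function further down, which builds the string with `std::to_string`). A small hypothetical helper, just to show how the pieces of the name line up:

```python
# Hypothetical helper mirroring the kernel-name scheme used by the macros and
# by the host-side dispatch code in this diff.
def quant_sdpa_kernel_name(dtype: str, head_dim: int, group_size: int, bits: int) -> str:
    return f"quant_sdpa_vector_{dtype}_{head_dim}_{group_size}_{bits}"

print(quant_sdpa_kernel_name("float16_t", 128, 64, 4))
# quant_sdpa_vector_float16_t_128_64_4
```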
@@ -113,3 +113,205 @@ template <typename T, int D>
    }
  }
}

template <typename T, typename U, int elem_per_thread, int bits>
METAL_FUNC U load_queries(const device T* queries, thread U* q, U scale) {
  U query_sum = 0;
  if (bits == 4) {
    for (int i = 0; i < elem_per_thread; i += 4) {
      q[i] = scale * queries[i];
      q[i + 1] = scale * queries[i + 1];
      q[i + 2] = scale * queries[i + 2];
      q[i + 3] = scale * queries[i + 3];
      query_sum += q[i] + q[i + 1] + q[i + 2] + q[i + 3];
      q[i + 1] /= 16.0f;
      q[i + 2] /= 256.0f;
      q[i + 3] /= 4096.0f;
    }
  } else if (bits == 8) {
    for (int i = 0; i < elem_per_thread; i++) {
      q[i] = scale * queries[i];
      query_sum += q[i];
    }
  }
  return query_sum;
}

template <typename U, int elem_per_thread, int bits>
METAL_FUNC void load_keys(const device uint32_t* keys, thread U* k) {
  if (bits == 4) {
    auto ks = (const device uint16_t*)keys;
    for (int i = 0; i < elem_per_thread / 4; i++) {
      k[4 * i] = ks[i] & 0x000f;
      k[4 * i + 1] = ks[i] & 0x00f0;
      k[4 * i + 2] = ks[i] & 0x0f00;
      k[4 * i + 3] = ks[i] & 0xf000;
    }
  } else if (bits == 8) {
    auto ks = (const device uint8_t*)keys;
    for (int i = 0; i < elem_per_thread; i++) {
      k[i] = ks[i];
    }
  }
}

template <typename U, int elem_per_thread, int bits>
METAL_FUNC void load_values(
    const device uint32_t* values,
    thread U* v,
    U value_scale,
    U value_bias) {
  auto vs = (const device uint8_t*)values;
  if (bits == 4) {
    U s[2] = {value_scale, value_scale / 16.0f};
    for (int i = 0; i < elem_per_thread / 2; i++) {
      v[2 * i] = s[0] * (vs[i] & 0x0f) + value_bias;
      v[2 * i + 1] = s[1] * (vs[i] & 0xf0) + value_bias;
    }
  } else if (bits == 8) {
    for (int i = 0; i < elem_per_thread; i++) {
      v[i] = value_scale * vs[i] + value_bias;
    }
  }
}

template <typename T, int D, int group_size, int bits>
[[kernel]] void quant_sdpa_vector(
    const device T* queries [[buffer(0)]],
    const device uint32_t* keys [[buffer(1)]],
    const device T* key_scales [[buffer(2)]],
    const device T* key_biases [[buffer(3)]],
    const device uint32_t* values [[buffer(4)]],
    const device T* value_scales [[buffer(5)]],
    const device T* value_biases [[buffer(6)]],
    device T* out [[buffer(7)]],
    const constant int& gqa_factor,
    const constant int& N,
    const constant size_t& k_stride,
    const constant size_t& group_stride,
    const constant float& scale,
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]],
    uint quad_gid [[quadgroup_index_in_threadgroup]],
    uint quad_lid [[thread_index_in_quadgroup]]) {
  constexpr int BN = 32;
  constexpr int BD = 4;
  constexpr int elem_per_thread = D / BD;
  constexpr int pack_factor = 32 / bits;

  const int stride = BN * D;

  typedef float U;

  thread U q[elem_per_thread];
  thread U k[elem_per_thread];
  thread U v[elem_per_thread];
  thread U o[elem_per_thread];

  threadgroup U outputs[BN * BD];
  threadgroup U max_scores[BN];
  threadgroup U sum_exp_scores[BN];

  // Adjust positions
  const int head_idx = tid.y;
  const int kv_head_idx = head_idx / gqa_factor;
  queries += head_idx * D + quad_lid * elem_per_thread;

  const int kv_idx = quad_gid * D + quad_lid * elem_per_thread;
  const int packed_idx = kv_head_idx * k_stride + kv_idx / pack_factor;
  const int group_idx = kv_head_idx * group_stride + kv_idx / group_size;
  keys += packed_idx;
  key_scales += group_idx;
  key_biases += group_idx;
  values += packed_idx;
  value_scales += group_idx;
  value_biases += group_idx;

  out += head_idx * D + simd_gid * elem_per_thread;

  // Read the query and 0 the output accumulator
  U query_sum = load_queries<T, U, elem_per_thread, bits>(
      queries, q, static_cast<U>(scale));
  for (int i = 0; i < elem_per_thread; i++) {
    o[i] = 0;
  }

  U max_score = -INFINITY;
  U sum_exp_score = 0;

  // For each key
  for (int i = quad_gid; i < N; i += BN) {
    load_keys<U, elem_per_thread, bits>(keys, k);

    // Assume D % group_size == 0 so all the keys are in the same group
    U key_scale = key_scales[0];
    U key_bias = key_biases[0];

    // Compute the i-th score
    U score = 0;
    for (int i = 0; i < elem_per_thread; i++) {
      score += q[i] * k[i];
    }
    score = score * key_scale + query_sum * key_bias;
    score = quad_sum(score);

    // Update the accumulators
    U new_max = max(max_score, score);
    U factor = fast::exp(max_score - new_max);
    U exp_score = fast::exp(score - new_max);

    max_score = new_max;
    sum_exp_score = sum_exp_score * factor + exp_score;

    U value_scale = value_scales[0];
    U value_bias = value_biases[0];

    // Load the values
    load_values<U, elem_per_thread, bits>(values, v, value_scale, value_bias);

    // Update the output accumulator
    for (int i = 0; i < elem_per_thread; i++) {
      o[i] = o[i] * factor + exp_score * v[i];
    }

    // Move the pointers to the next kv
    keys += stride / pack_factor;
    key_scales += stride / group_size;
    key_biases += stride / group_size;
    values += stride / pack_factor;
    value_scales += stride / group_size;
    value_biases += stride / group_size;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Each thread has a partial part of the output so we need to combine them.

  // First let's communicate the max and sum_exp
  // Each quadgroup communicates its max score
  if (quad_lid == 0) {
    max_scores[quad_gid] = max_score;
    sum_exp_scores[quad_gid] = sum_exp_score;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  max_score = max_scores[simd_lid];
  U new_max = simd_max(max_score);
  U factor = fast::exp(max_score - new_max);
  sum_exp_score = simd_sum(sum_exp_scores[simd_lid] * factor);

  // Now we need to aggregate all the outputs
  for (int i = 0; i < elem_per_thread; i++) {
    // 128 threads with 32 values per thread
    outputs[simd_gid * BN + simd_lid] = o[i];
    threadgroup_barrier(mem_flags::mem_threadgroup);
    o[i] = simd_sum(outputs[simd_lid * BD + simd_gid] * factor) / sum_exp_score;
    threadgroup_barrier(mem_flags::mem_threadgroup);
  }

  // And write the output
  if (simd_lid == 0) {
    for (int i = 0; i < elem_per_thread; i++) {
      out[i] = static_cast<T>(o[i]);
    }
  }
}
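Two ideas carry this kernel. First, the dequantized dot product never has to materialize the dequantized keys: since each key element is `s * k_hat + b`, the score satisfies `q . (s * k_hat + b) = s * (q . k_hat) + b * sum(q)`, which is why `load_queries` returns `query_sum` and the score is fixed up as `score * key_scale + query_sum * key_bias`. Second, the scores are reduced with an online softmax: a running max and running sum of exponentials are rescaled as each new key arrives. A NumPy sketch of both, assuming (as the kernel does via `D % group_size == 0`) a single scale and bias per row; this is an illustration, not the Metal kernel:

```python
# Illustrative NumPy sketch of the dequantized-dot-product identity and the
# online-softmax accumulation used by quant_sdpa_vector. One scale/bias per
# row is assumed for simplicity.
import numpy as np

def online_quant_attention(q, k_hat, k_scales, k_biases, v_hat, v_scales, v_biases, scale):
    # q: (D,) float; k_hat/v_hat: (N, D) integer codes with per-row scale/bias.
    q = scale * q
    q_sum = q.sum()

    max_score, sum_exp, out = -np.inf, 0.0, np.zeros_like(q)
    for i in range(k_hat.shape[0]):
        # q . (s * k_hat + b) == s * (q . k_hat) + b * sum(q)
        score = k_scales[i] * (q @ k_hat[i]) + k_biases[i] * q_sum

        new_max = max(max_score, score)
        factor = np.exp(max_score - new_max)   # rescales the previous partial sums
        exp_score = np.exp(score - new_max)

        max_score = new_max
        sum_exp = sum_exp * factor + exp_score
        v = v_scales[i] * v_hat[i] + v_biases[i]  # dequantize the value row
        out = out * factor + exp_score * v
    return out / sum_exp
```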
@@ -14,6 +14,8 @@
#include "mlx/scheduler.h"
#include "mlx/utils.h"

#include <iostream>

namespace mlx::core {

template <typename T>
@@ -1,6 +1,7 @@
// Copyright © 2023-2024 Apple Inc.

#include <cassert>
#include <iostream>

#include "mlx/backend/common/compiled.h"
#include "mlx/backend/metal/copy.h"
@@ -354,4 +355,80 @@ void fast::AffineQuantize::eval_gpu(
  d.add_temporaries(std::move(copies), s.index);
}

void fast::KVUpdate::eval_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  auto& s = stream();
  auto& d = metal::device(s.device);

  std::vector<array> copies;
  auto ensure_row_contiguous = [&copies, &s](const array& arr) {
    if (arr.flags().row_contiguous) {
      return arr;
    } else {
      array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
      copy_gpu(arr, arr_copy, CopyType::General, s);
      copies.push_back(arr_copy);
      return arr_copy;
    }
  };

  // Copy from the inputs into the outputs
  const auto& new_keys = ensure_row_contiguous(inputs[0]);
  const auto& new_values = ensure_row_contiguous(inputs[1]);

  // Copy the input KV cache to the output.
  // If the inputs are contiguous, this will be zero-copy.
  for (int i = 0; i < 6; i++) {
    auto in = ensure_row_contiguous(inputs[i + 2]);
    auto out = outputs[i];
    auto ctype = in.flags().contiguous && in.size() == in.data_size()
        ? CopyType::Vector
        : CopyType::General;
    copy_gpu(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, s);
  }

  auto& compute_encoder = d.get_command_encoder(s.index);
  compute_encoder.set_input_array(new_keys, 0);
  compute_encoder.set_input_array(new_values, 1);
  int enc_offset = 2;
  for (auto& out : outputs) {
    compute_encoder.set_output_array(out, enc_offset);
    enc_offset++;
  }
  int offset = offset_ * inputs[2].strides(-2) * 4;
  // std::cout << "offset " << offset << std::endl;
  int batch_stride = inputs[2].shape(-1) * inputs[2].shape(-2);
  // std::cout << "batch stride " << batch_stride << std::endl;
  compute_encoder->setBytes(&offset, sizeof(int), enc_offset);
  compute_encoder->setBytes(&batch_stride, sizeof(int), enc_offset + 1);

  auto type_string = get_type_string(new_keys.dtype());
  // Now launch the kernel
  std::ostringstream kname;
  kname << "kv_update" << "_" << type_string << "_gs_" << group_size_ << "_b_"
        << bits_;
  auto template_def = get_template_definition(
      kname.str(), "kv_update", type_string, group_size_, bits_);
  auto kernel = get_quantized_kernel(d, kname.str(), template_def);
  compute_encoder->setComputePipelineState(kernel);

  int per_thread = 8 / bits_;
  size_t nrows = new_keys.size() / new_keys.shape(-1);
  size_t ncols = new_keys.shape(-1) / per_thread;
  size_t nthreads = nrows * ncols;
  // std::cout << "nthreads " << nthreads << std::endl;
  // std::cout << "nrows " << nrows << std::endl;
  // std::cout << "ncols " << ncols << std::endl;
  NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
  if (thread_group_size > nthreads) {
    thread_group_size = ncols;
  }
  auto group_dims = MTL::Size(thread_group_size, 1, 1);
  MTL::Size grid_dims = MTL::Size(ncols, nrows, 1);
  compute_encoder.dispatchThreads(grid_dims, group_dims);

  d.add_temporaries(std::move(copies), s.index);
}

} // namespace mlx::core
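The launch geometry above gives each thread one packed output byte: a row of new keys of width `shape(-1)` needs `shape(-1) / (8 / bits)` threads, with one grid row per (batch, head, position). A small arithmetic sketch, with the shape chosen for illustration only:

```python
# Sketch of the KVUpdate launch geometry, assuming 4-bit packing and an
# illustrative new_keys shape of (B, H_k, 1, D) = (1, 8, 1, 128).
import math

bits = 4
per_thread = 8 // bits                     # elements packed per thread (2)
new_keys_shape = (1, 8, 1, 128)

nrows = math.prod(new_keys_shape[:-1])     # 8: one grid row per key/value row
ncols = new_keys_shape[-1] // per_thread   # 64 threads across the head dimension
nthreads = nrows * ncols                   # grid_dims = (ncols, nrows, 1)
print(nrows, ncols, nthreads)              # 8 64 512
```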
@@ -185,19 +185,73 @@ void sdpa_vector(
  compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}

void quant_sdpa_vector(
    const Stream& s,
    metal::Device& d,
    const array& q,
    const array& k,
    const array& k_scales,
    const array& k_biases,
    const array& v,
    const array& v_scales,
    const array& v_biases,
    array& out,
    float scale,
    int group_size,
    int bits) {
  // Set the kernel name
  std::string kname;
  kname.reserve(96);
  kname += "quant_sdpa_vector_";
  kname += get_type_string(q.dtype());
  kname += "_";
  kname += std::to_string(q.shape(-1));
  kname += "_";
  kname += std::to_string(group_size);
  kname += "_";
  kname += std::to_string(bits);

  // Compute the necessary sizes
  int gqa_factor = q.shape(1) / k.shape(1);
  int N = k.shape(2);
  int B = q.shape(0) * q.shape(1);
  size_t stride = k.strides()[1];
  size_t group_stride = k_scales.strides()[1];
  MTL::Size group_dims(128, 1, 1);
  MTL::Size grid_dims(1, B, 1);

  // Get the kernel
  auto& compute_encoder = d.get_command_encoder(s.index);
  auto kernel = d.get_kernel(kname);
  compute_encoder->setComputePipelineState(kernel);

  // Set its arguments
  compute_encoder.set_input_array(q.data_shared_ptr() == nullptr ? out : q, 0);
  compute_encoder.set_input_array(k, 1);
  compute_encoder.set_input_array(k_scales, 2);
  compute_encoder.set_input_array(k_biases, 3);
  compute_encoder.set_input_array(v, 4);
  compute_encoder.set_input_array(v_scales, 5);
  compute_encoder.set_input_array(v_biases, 6);
  compute_encoder.set_output_array(out, 7);
  compute_encoder->setBytes(&gqa_factor, sizeof(int), 8);
  compute_encoder->setBytes(&N, sizeof(int), 9);
  compute_encoder->setBytes(&stride, sizeof(size_t), 10);
  compute_encoder->setBytes(&group_stride, sizeof(size_t), 11);
  compute_encoder->setBytes(&scale, sizeof(float), 12);

  // Launch
  compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}

} // namespace

void ScaledDotProductAttention::eval_gpu(
    const std::vector<array>& inputs,
    array& out) {
  assert(inputs.size() == 3);

  auto& s = stream();
  auto& d = metal::device(s.device);

  auto& q_pre = inputs[0];
  auto& k_pre = inputs[1];
  auto& v_pre = inputs[2];
  auto& o = out;

  std::vector<array> copies;
@@ -236,11 +290,25 @@ void ScaledDotProductAttention::eval_gpu(
    return strides[3] == 1 && strides[2] == shape[3];
  };

  // We are in vector mode ie single query
  if (q_pre.shape(2) == 1) {
  if (quantized_) {
    auto& q_pre = inputs[0];
    auto& k_pre = inputs[1];
    auto& k_scales_pre = inputs[2];
    auto& k_biases_pre = inputs[3];
    auto& v_pre = inputs[4];
    auto& v_scales_pre = inputs[5];
    auto& v_biases_pre = inputs[6];

    // Quantized should only be routed here for single queries
    assert(q_pre.shape(2) == 1);

    auto q = copy_unless(is_contiguous, q_pre);
    auto k = copy_unless(is_contiguous_except_seq_len, k_pre);
    auto k_scales = copy_unless(is_contiguous_except_seq_len, k_scales_pre);
    auto k_biases = copy_unless(is_contiguous_except_seq_len, k_biases_pre);
    auto v = copy_unless(is_contiguous_except_seq_len, v_pre);
    auto v_scales = copy_unless(is_contiguous_except_seq_len, v_scales_pre);
    auto v_biases = copy_unless(is_contiguous_except_seq_len, v_biases_pre);

    // Donate the query if possible
    if (q.is_donatable()) {
@@ -249,17 +317,54 @@ void ScaledDotProductAttention::eval_gpu(
      o.set_data(allocator::malloc_or_wait(o.nbytes()));
    }

    sdpa_vector(s, d, q, k, v, o, scale_);
    quant_sdpa_vector(
        s,
        d,
        q,
        k,
        k_scales,
        k_biases,
        v,
        v_scales,
        v_biases,
        o,
        scale_,
        group_size_,
        bits_);

  }

  // Full attention mode
  // Non-quantized
  else {
    auto q = copy_unless(is_matrix_contiguous, q_pre);
    auto k = copy_unless(is_matrix_contiguous, k_pre);
    auto v = copy_unless(is_matrix_contiguous, v_pre);
    o.set_data(allocator::malloc_or_wait(o.nbytes()));
    assert(inputs.size() == 3);
    auto& q_pre = inputs[0];
    auto& k_pre = inputs[1];
    auto& v_pre = inputs[2];

    sdpa_full_self_attention_metal(s, d, q, k, v, scale_, o);
    // We are in vector mode ie single query
    if (q_pre.shape(2) == 1) {
      auto q = copy_unless(is_contiguous, q_pre);
      auto k = copy_unless(is_contiguous_except_seq_len, k_pre);
      auto v = copy_unless(is_contiguous_except_seq_len, v_pre);

      // Donate the query if possible
      if (q.is_donatable()) {
        o.move_shared_buffer(q);
      } else {
        o.set_data(allocator::malloc_or_wait(o.nbytes()));
      }

      sdpa_vector(s, d, q, k, v, o, scale_);
    }
    // Full attention mode
    else {
      auto q = copy_unless(is_matrix_contiguous, q_pre);
      auto k = copy_unless(is_matrix_contiguous, k_pre);
      auto v = copy_unless(is_matrix_contiguous, v_pre);
      o.set_data(allocator::malloc_or_wait(o.nbytes()));

      sdpa_full_self_attention_metal(s, d, q, k, v, scale_, o);
    }
  }

  d.add_temporaries(std::move(copies), s.index);
@@ -123,6 +123,7 @@ NO_GPU_MULTI(RMSNormVJP)
NO_GPU_MULTI(RoPE)
NO_GPU(ScaledDotProductAttention)
NO_GPU_MULTI(AffineQuantize)
NO_GPU_MULTI(KVUpdate)
NO_GPU_MULTI(CustomKernel)
} // namespace fast
mlx/fast.cpp (172 changed lines)
@@ -648,7 +648,7 @@ array scaled_dot_product_attention(
      std::move(out_shape),
      final_type,
      std::make_shared<ScaledDotProductAttention>(
          stream, fallback, scale, false),
          stream, fallback, scale, /*needs_mask=*/false, /*quantized=*/false),
      {q, k, v});
}

@@ -662,7 +662,130 @@ array scaled_dot_product_attention(
bool ScaledDotProductAttention::is_equivalent(const Primitive& other) const {
  const ScaledDotProductAttention& a_other =
      static_cast<const ScaledDotProductAttention&>(other);
  return needs_mask_ == a_other.needs_mask_ && scale_ == a_other.scale_;
  return needs_mask_ == a_other.needs_mask_ && scale_ == a_other.scale_ &&
      quantized_ == a_other.quantized_;
}

array quantized_scaled_dot_product_attention(
    const array& queries,
    const array& keys,
    const array& key_scales,
    const array& key_biases,
    const array& values,
    const array& value_scales,
    const array& value_biases,
    const float scale,
    const std::optional<array>& mask,
    const int group_size,
    const int bits,
    StreamOrDevice s) {
  int el_per_int = 32 / bits;
  int out_dim = values.shape(-1) * el_per_int;

  auto n_q_heads = queries.shape(-3);
  auto n_kv_heads = keys.shape(-3);

  auto out_shape = std::vector<int>(
      {queries.shape(0), queries.shape(1), queries.shape(2), out_dim});
  auto stream = to_stream(s);
  bool needs_mask = mask.has_value();
  auto fallback =
      [scale, needs_mask, n_q_heads, n_kv_heads, group_size, bits, &s](
          const std::vector<array>& inputs) -> std::vector<array> {
    int n_repeats = n_q_heads / n_kv_heads;

    auto q = multiply(array(scale, inputs[0].dtype()), inputs[0], s);

    auto k = inputs[1];
    auto k_scales = inputs[2];
    auto k_biases = inputs[3];
    auto v = inputs[4];
    auto v_scales = inputs[5];
    auto v_biases = inputs[6];

    int B = q.shape(0);
    int L = q.shape(2);

    if (n_repeats > 1) {
      q = reshape(q, {B, n_kv_heads, n_repeats, L, -1}, s);
      k = expand_dims(k, 2, s);
      k_scales = expand_dims(k_scales, 2, s);
      k_biases = expand_dims(k_biases, 2, s);
      v = expand_dims(v, 2, s);
      v_scales = expand_dims(v_scales, 2, s);
      v_biases = expand_dims(v_biases, 2, s);
    }

    array scores = quantized_matmul(
        q,
        k,
        k_scales,
        k_biases,
        /*transpose=*/true,
        /*group_size=*/group_size,
        /*bits=*/bits,
        s);
    if (needs_mask) {
      scores = add(scores, inputs[7], s);
    }
    scores = softmax(scores, std::vector<int>{-1}, true, s);
    array out = quantized_matmul(
        scores,
        v,
        v_scales,
        v_biases,
        /*transpose=*/false,
        /*group_size=*/group_size,
        /*bits=*/bits,
        s);
    if (n_repeats > 1) {
      out = reshape(out, {B, n_q_heads, L, -1}, s);
    }
    return std::vector<array>{out};
  };

  int L = queries.shape(2);
  if (L > 1) {
    if (needs_mask) {
      return fallback(
          {queries,
           keys,
           key_scales,
           key_biases,
           values,
           value_scales,
           value_biases,
           mask.value()})[0];
    } else {
      return fallback(
          {queries,
           keys,
           key_scales,
           key_biases,
           values,
           value_scales,
           value_biases})[0];
    }
  } else {
    return array(
        std::move(out_shape),
        queries.dtype(),
        std::make_shared<ScaledDotProductAttention>(
            stream,
            fallback,
            scale,
            /*needs_mask=*/false,
            /*quantized=*/true,
            group_size,
            bits),
        {queries,
         keys,
         key_scales,
         key_biases,
         values,
         value_scales,
         value_biases});
  }
}

array pack_and_quantize(
@@ -907,6 +1030,51 @@ array affine_dequantize(
  return fallback({w, scales, biases})[0];
}

std::vector<array> kv_update(
    const array& new_keys,
    const array& new_values,
    const array& keys,
    const array& key_scales,
    const array& key_biases,
    const array& values,
    const array& value_scales,
    const array& value_biases,
    int offset,
    int group_size,
    int bits,
    StreamOrDevice s_) {
  auto s = to_stream(s_);

  int el_per_int = 32 / bits;
  auto out_shape = keys.shape();
  out_shape.back() = keys.shape(-1) / el_per_int;
  auto fallback = [](const std::vector<array>& inputs) -> std::vector<array> {
    return {inputs[0], inputs[1]};
  };
  return array::make_arrays(
      {keys.shape(),
       key_scales.shape(),
       key_biases.shape(),
       values.shape(),
       value_scales.shape(),
       value_biases.shape()},
      {keys.dtype(),
       key_scales.dtype(),
       key_biases.dtype(),
       values.dtype(),
       value_scales.dtype(),
       value_biases.dtype()},
      std::make_shared<KVUpdate>(s, fallback, offset, group_size, bits),
      {new_keys,
       new_values,
       keys,
       key_scales,
       key_biases,
       values,
       value_scales,
       value_biases});
}

std::string write_signature(
    std::string func_name,
    const std::string& header,
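In `quantized_scaled_dot_product_attention` above, sequence lengths greater than one (prompt processing, or any masked call) go through the `quantized_matmul` fallback, while single-token decoding is routed to the fused `ScaledDotProductAttention` primitive with `quantized=true`. The shape bookkeeping is worth spelling out; the numbers below are illustrative, chosen to match the benchmark's head configuration:

```python
# Worked example of the shape arithmetic in quantized_scaled_dot_product_attention.
bits = 4
el_per_int = 32 // bits              # 8 elements packed into each uint32
packed_D = 128 // el_per_int         # 16: last dim of the packed values array
out_dim = packed_D * el_per_int      # 128: head dim recovered for the output

n_q_heads, n_kv_heads = 32, 8
n_repeats = n_q_heads // n_kv_heads  # 4: GQA query heads sharing each KV head
```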
mlx/fast.h (29 changed lines)
@@ -41,6 +41,21 @@ array scaled_dot_product_attention(
    const std::optional<int> memory_efficient_threshold = std::nullopt,
    StreamOrDevice s = {});

/** Computes: `O = softmax(Q @ K.T) @ V` where K and V are quantized. **/
array quantized_scaled_dot_product_attention(
    const array& queries,
    const array& keys,
    const array& key_scales,
    const array& key_biases,
    const array& values,
    const array& value_scales,
    const array& value_biases,
    const float scale,
    const std::optional<array>& mask = std::nullopt,
    const int group_size = 64,
    const int bits = 4,
    StreamOrDevice s = {});

std::tuple<array, array, array> affine_quantize(
    const array& w,
    int group_size = 64,
@@ -63,6 +78,20 @@ array affine_dequantize(
    int bits = 4,
    StreamOrDevice s = {});

std::vector<array> kv_update(
    const array& new_keys,
    const array& new_values,
    const array& keys,
    const array& key_scales,
    const array& key_biases,
    const array& values,
    const array& value_scales,
    const array& value_biases,
    int offset,
    int group_size = 64,
    int bits = 4,
    StreamOrDevice s = {});

typedef std::variant<int, bool, Dtype> TemplateArg;

typedef std::function<std::vector<array>(
@@ -190,8 +190,16 @@ class ScaledDotProductAttention : public Custom {
      Stream stream,
      std::function<std::vector<array>(std::vector<array>)> fallback,
      const float scale,
      const bool needs_mask)
      : Custom(stream, fallback), scale_(scale), needs_mask_(needs_mask) {}
      const bool needs_mask,
      const bool quantized,
      const int group_size = 64,
      const int bits = 4)
      : Custom(stream, fallback),
        scale_(scale),
        needs_mask_(needs_mask),
        quantized_(quantized),
        group_size_(group_size),
        bits_(bits) {}

  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override {
@@ -212,6 +220,9 @@ class ScaledDotProductAttention : public Custom {
  std::function<std::vector<array>(std::vector<array>)> fallback_;
  float scale_;
  bool needs_mask_;
  bool quantized_;
  int group_size_;
  int bits_;
};

class AffineQuantize : public Custom {
@@ -244,6 +255,36 @@ class AffineQuantize : public Custom {
  bool dequantize_;
};

class KVUpdate : public Custom {
 public:
  explicit KVUpdate(
      Stream stream,
      std::function<std::vector<array>(std::vector<array>)> fallback,
      int offset,
      int group_size,
      int bits)
      : Custom(stream, fallback),
        offset_(offset),
        group_size_(group_size),
        bits_(bits) {}

  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override {
    throw std::runtime_error("NYI");
  }

  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override;

  DEFINE_PRINT(KVUpdate);

 private:
  std::function<std::vector<array>(std::vector<array>)> fallback_;
  int offset_;
  int group_size_;
  int bits_;
};

struct CustomKernelShapeInfo {
  bool shape = false;
  bool strides = false;
@@ -150,6 +150,45 @@ void init_fast(nb::module_& parent_module) {
        array: The output array.
      )pbdoc");

  m.def(
      "quantized_scaled_dot_product_attention",
      &fast::quantized_scaled_dot_product_attention,
      "q"_a,
      "k"_a,
      "k_scales"_a,
      "k_biases"_a,
      "v"_a,
      "v_scales"_a,
      "v_biases"_a,
      nb::kw_only(),
      "scale"_a,
      "mask"_a = nb::none(),
      "group_size"_a = 64,
      "bits"_a = 4,
      "stream"_a = nb::none(),
      nb::sig(
          "def quantized_scaled_dot_product_attention(q: array, k: array, k_scales: array, k_biases: array, v: array, v_scales: array, v_biases: array, *, scale: float, mask: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
      R"pbdoc(
        A fast implementation of multi-head attention where the keys and values are quantized.

        See :func:`scaled_dot_product_attention` for more details.

        Args:
            q (array): Input query array.
            k (array): Input keys array.
            k_scales (array): Scales for the quantized keys array.
            k_biases (array): Biases for the quantized keys array.
            v (array): Input values array.
            v_scales (array): Scales for the quantized values array.
            v_biases (array): Biases for the quantized values array.
            scale (float): Scale for queries (typically ``1.0 / sqrt(q.shape(-1))``)
            mask (array, optional): An additive mask to apply to the query-key scores.
            group_size (int): The group size used in the KV quantization.
            bits (int): The bits used in the KV quantization.

        Returns:
            array: The output array.
      )pbdoc");

  m.def(
      "affine_quantize",
      nb::overload_cast<
@@ -193,6 +232,44 @@ void init_fast(nb::module_& parent_module) {
        array: The quantized version of ``w``
      )pbdoc");

  m.def(
      "quantized_kv_update",
      &fast::kv_update,
      "new_keys"_a,
      "new_values"_a,
      "keys"_a,
      "key_scales"_a,
      "key_biases"_a,
      "values"_a,
      "value_scales"_a,
      "value_biases"_a,
      "offset"_a = 64,
      "group_size"_a = 64,
      "bits"_a = 4,
      nb::kw_only(),
      "stream"_a = nb::none(),
      nb::sig(
          "def quantized_kv_update(new_keys: array, new_values: array, keys: array, key_scales: array, key_biases: array, values: array, value_scales: array, value_biases: array, offset: int = 64, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"),
      R"pbdoc(
        Fused update for a quantized KV cache.

        Quantizes ``new_keys`` and ``new_values`` and writes them into the packed
        ``keys`` and ``values`` caches (and their scales and biases) at ``offset``.

        Args:
            new_keys (array): New keys to quantize and insert into the cache.
            new_values (array): New values to quantize and insert into the cache.
            keys (array): The packed, quantized keys cache.
            key_scales (array): Scales for the quantized keys cache.
            key_biases (array): Biases for the quantized keys cache.
            values (array): The packed, quantized values cache.
            value_scales (array): Scales for the quantized values cache.
            value_biases (array): Biases for the quantized values cache.
            offset (int): Position in the cache at which to write the update.
            group_size (int, optional): The group size used in the KV quantization. (default: ``64``)
            bits (int, optional): The number of bits used in the KV quantization. (default: ``4``)

        Returns:
            The updated ``keys``, ``key_scales``, ``key_biases``, ``values``,
            ``value_scales``, and ``value_biases`` arrays.
      )pbdoc");

  m.def(
      "metal_kernel",
      [](const std::string& name,
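Putting the two bindings together, a decoding step would quantize the incoming key/value, splice it into the packed cache, and then attend over the quantized cache. The sketch below is hedged: it uses the argument names from the nanobind definitions in this branch, assumes `mx.quantize` returns a `(quantized, scales, biases)` tuple, and picks cache shapes purely for illustration.

```python
# Hedged end-to-end sketch of the Python API added in this branch; shapes and
# offset handling are illustrative, not taken from a shipped example.
import math
import mlx.core as mx

B, H, H_k, D, S = 1, 32, 8, 128, 1024
q = mx.random.uniform(shape=(B, H, 1, D))
new_k = mx.random.uniform(shape=(B, H_k, 1, D))
new_v = mx.random.uniform(shape=(B, H_k, 1, D))

# Quantized KV cache: (packed, scales, biases) per tensor.
k_q, k_s, k_b = mx.quantize(mx.zeros((B, H_k, S, D)), group_size=64, bits=4)
v_q, v_s, v_b = mx.quantize(mx.zeros((B, H_k, S, D)), group_size=64, bits=4)

# Write the new step into the cache at position `offset`.
k_q, k_s, k_b, v_q, v_s, v_b = mx.fast.quantized_kv_update(
    new_k, new_v, k_q, k_s, k_b, v_q, v_s, v_b,
    offset=0, group_size=64, bits=4)

# Attend over the quantized cache.
out = mx.fast.quantized_scaled_dot_product_attention(
    q, k_q, k_s, k_b, v_q, v_s, v_b,
    scale=1.0 / math.sqrt(D), mask=None, group_size=64, bits=4)
```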