#include <metal_simdgroup>
#include <metal_stdlib>

using namespace metal;

#define MLX_MTL_CONST static constant constexpr const

MLX_MTL_CONST int SIMD_SIZE = 32;
MLX_MTL_CONST int QUAD_SIZE = 4;
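
// Load values_per_thread elements of x into registers, pre-dividing each
// lane by the place value of its position within a packed byte (4^k for
// 2-bit, 16^k for 4-bit, and so on). This lets the dot-product helpers below
// multiply by the raw masked bits with no shifts. The plain sum of the x
// values is also returned so callers can fold in the quantization bias once
// at the end.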
template <typename T, typename U, int values_per_thread, int bits>
inline U load_vector(const device T* x, thread U* x_thread) {
  static_assert(
      bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
      "Template undefined for bits not in {2, 3, 4, 6, 8}");

  U sum = 0;

  if (bits == 2) {
    for (int i = 0; i < values_per_thread; i += 4) {
      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
      x_thread[i] = x[i];
      x_thread[i + 1] = x[i + 1] / 4.0f;
      x_thread[i + 2] = x[i + 2] / 16.0f;
      x_thread[i + 3] = x[i + 3] / 64.0f;
    }
  }

  else if (bits == 3) {
    for (int i = 0; i < values_per_thread; i += 8) {
      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] +
          x[i + 6] + x[i + 7];
      x_thread[i] = x[i];
      x_thread[i + 1] = x[i + 1] / 8.0f;
      x_thread[i + 2] = x[i + 2] / 64.0f;
      x_thread[i + 3] = x[i + 3] / 2.0f;
      x_thread[i + 4] = x[i + 4] / 16.0f;
      x_thread[i + 5] = x[i + 5] / 128.0f;
      x_thread[i + 6] = x[i + 6] / 4.0f;
      x_thread[i + 7] = x[i + 7] / 32.0f;
    }
  }

  else if (bits == 4) {
    for (int i = 0; i < values_per_thread; i += 4) {
      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
      x_thread[i] = x[i];
      x_thread[i + 1] = x[i + 1] / 16.0f;
      x_thread[i + 2] = x[i + 2] / 256.0f;
      x_thread[i + 3] = x[i + 3] / 4096.0f;
    }
  }

  else if (bits == 6) {
    for (int i = 0; i < values_per_thread; i += 4) {
      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
      x_thread[i] = x[i];
      x_thread[i + 1] = x[i + 1] / 64.0f;
      x_thread[i + 2] = x[i + 2] / 16.0f;
      x_thread[i + 3] = x[i + 3] / 4.0f;
    }
  }

  else if (bits == 8) {
    for (int i = 0; i < values_per_thread; i++) {
      sum += x[i];
      x_thread[i] = x[i];
    }
  }

  return sum;
}
 
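// Same as load_vector but reads only the first N elements of x; the
// remaining thread-local slots are zeroed so the packed dot product over the
// full values_per_thread stays valid at the ragged edge of the reduction.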
template <typename T, typename U, int values_per_thread, int bits>
inline U load_vector_safe(const device T* x, thread U* x_thread, int N) {
  static_assert(
      bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
      "Template undefined for bits not in {2, 3, 4, 6, 8}");

  U sum = 0;

  if (bits == 2) {
    for (int i = 0; i < N; i += 4) {
      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
      x_thread[i] = x[i];
      x_thread[i + 1] = x[i + 1] / 4.0f;
      x_thread[i + 2] = x[i + 2] / 16.0f;
      x_thread[i + 3] = x[i + 3] / 64.0f;
    }
  }

  else if (bits == 3) {
    for (int i = 0; i < N; i += 8) {
      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] +
          x[i + 6] + x[i + 7];
      x_thread[i] = x[i];
      x_thread[i + 1] = x[i + 1] / 8.0f;
      x_thread[i + 2] = x[i + 2] / 64.0f;
      x_thread[i + 3] = x[i + 3] / 2.0f;
      x_thread[i + 4] = x[i + 4] / 16.0f;
      x_thread[i + 5] = x[i + 5] / 128.0f;
      x_thread[i + 6] = x[i + 6] / 4.0f;
      x_thread[i + 7] = x[i + 7] / 32.0f;
    }
  }

  else if (bits == 4) {
    for (int i = 0; i < N; i += 4) {
      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
      x_thread[i] = x[i];
      x_thread[i + 1] = x[i + 1] / 16.0f;
      x_thread[i + 2] = x[i + 2] / 256.0f;
      x_thread[i + 3] = x[i + 3] / 4096.0f;
    }
  }

  else if (bits == 6) {
    for (int i = 0; i < N; i += 4) {
      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
      x_thread[i] = x[i];
      x_thread[i + 1] = x[i + 1] / 64.0f;
      x_thread[i + 2] = x[i + 2] / 16.0f;
      x_thread[i + 3] = x[i + 3] / 4.0f;
    }
  }

  else if (bits == 8) {
    for (int i = 0; i < N; i++) {
      sum += x[i];
      x_thread[i] = x[i];
    }
  }

  for (int i = N; i < values_per_thread; i++) {
    x_thread[i] = 0;
  }

  return sum;
}
 
 
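// Dot product of the pre-scaled thread-local x values with a packed slice of
// w. The affine dequantization (w_q * scale + bias) is hoisted out of the
// inner loop: raw masked bits are multiplied directly and corrected once at
// the end via scale * accum + sum * bias. In the 3-bit layout, 8 values span
// 3 bytes:
//   byte 0: v0[2:0] v1[2:0] v2[1:0]
//   byte 1: v2[2] v3[2:0] v4[2:0] v5[0]
//   byte 2: v5[2:1] v6[2:0] v7[2:0]
// hence the extra x_thread[2] and x_thread[5] terms weighted by 256.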
template <typename U, int values_per_thread, int bits>
inline U qdot(
    const device uint8_t* w,
    const thread U* x_thread,
    U scale,
    U bias,
    U sum) {
  static_assert(
      bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
      "Template undefined for bits not in {2, 3, 4, 6, 8}");

  U accum = 0;

  if (bits == 2) {
    for (int i = 0; i < (values_per_thread / 4); i++) {
      accum +=
          (x_thread[4 * i] * (w[i] & 0x03) +
           x_thread[4 * i + 1] * (w[i] & 0x0c) +
           x_thread[4 * i + 2] * (w[i] & 0x30) +
           x_thread[4 * i + 3] * (w[i] & 0xc0));
    }
  }

  else if (bits == 3) {
    for (int i = 0; i < (values_per_thread / 8); i++) {
      x_thread += 8 * i;
      w += 3 * i;

      accum += (w[0] & 0x07) * x_thread[0];
      accum += (w[0] & 0x38) * x_thread[1];
      accum += (w[0] & 0xc0) * x_thread[2];
      accum += (w[1] & 0x01) * (x_thread[2] * 256.0f);

      accum += (w[1] & 0x0e) * x_thread[3];
      accum += (w[1] & 0x70) * x_thread[4];
      accum += (w[1] & 0x80) * x_thread[5];
      accum += (w[2] & 0x03) * (x_thread[5] * 256.0f);

      accum += (w[2] & 0x1c) * x_thread[6];
      accum += (w[2] & 0xe0) * x_thread[7];
    }
  }

  else if (bits == 4) {
    const device uint16_t* ws = (const device uint16_t*)w;
    for (int i = 0; i < (values_per_thread / 4); i++) {
      accum +=
          (x_thread[4 * i] * (ws[i] & 0x000f) +
           x_thread[4 * i + 1] * (ws[i] & 0x00f0) +
           x_thread[4 * i + 2] * (ws[i] & 0x0f00) +
           x_thread[4 * i + 3] * (ws[i] & 0xf000));
    }
  }

  else if (bits == 6) {
    for (int i = 0; i < (values_per_thread / 4); i++) {
      x_thread += 4 * i;
      w += 3 * i;

      accum += (w[0] & 0x3f) * x_thread[0];

      accum += (w[0] & 0xc0) * x_thread[1];
      accum += (w[1] & 0x0f) * (x_thread[1] * 256.0f);

      accum += (w[1] & 0xf0) * x_thread[2];
      accum += (w[2] & 0x03) * (x_thread[2] * 256.0f);

      accum += (w[2] & 0xfc) * x_thread[3];
    }
  }

  else if (bits == 8) {
    for (int i = 0; i < values_per_thread; i++) {
      accum += x_thread[i] * w[i];
    }
  }

  return scale * accum + sum * bias;
}
 
 
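// Bounds-checked variant of qdot used for the final partial block: identical
// arithmetic, but only the first N values participate.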
template <typename U, int values_per_thread, int bits>
inline U qdot_safe(
    const device uint8_t* w,
    const thread U* x_thread,
    U scale,
    U bias,
    U sum,
    int N) {
  static_assert(
      bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
      "Template undefined for bits not in {2, 3, 4, 6, 8}");

  U accum = 0;

  if (bits == 2) {
    for (int i = 0; i < (N / 4); i++) {
      accum +=
          (x_thread[4 * i] * (w[i] & 0x03) +
           x_thread[4 * i + 1] * (w[i] & 0x0c) +
           x_thread[4 * i + 2] * (w[i] & 0x30) +
           x_thread[4 * i + 3] * (w[i] & 0xc0));
    }
  }

  else if (bits == 3) {
    for (int i = 0; i < (N / 8); i++) {
      x_thread += 8 * i;
      w += 3 * i;

      accum += (w[0] & 0x07) * x_thread[0];
      accum += (w[0] & 0x38) * x_thread[1];
      accum += (w[0] & 0xc0) * x_thread[2];
      accum += (w[1] & 0x01) * (x_thread[2] * 256.0f);

      accum += (w[1] & 0x0e) * x_thread[3];
      accum += (w[1] & 0x70) * x_thread[4];
      accum += (w[1] & 0x80) * x_thread[5];
      accum += (w[2] & 0x03) * (x_thread[5] * 256.0f);

      accum += (w[2] & 0x1c) * x_thread[6];
      accum += (w[2] & 0xe0) * x_thread[7];
    }
  }

  else if (bits == 4) {
    const device uint16_t* ws = (const device uint16_t*)w;
    for (int i = 0; i < (N / 4); i++) {
      accum +=
          (x_thread[4 * i] * (ws[i] & 0x000f) +
           x_thread[4 * i + 1] * (ws[i] & 0x00f0) +
           x_thread[4 * i + 2] * (ws[i] & 0x0f00) +
           x_thread[4 * i + 3] * (ws[i] & 0xf000));
    }
  }

  else if (bits == 6) {
    for (int i = 0; i < (N / 4); i++) {
      x_thread += 4 * i;
      w += 3 * i;

      accum += (w[0] & 0x3f) * x_thread[0];

      accum += (w[0] & 0xc0) * x_thread[1];
      accum += (w[1] & 0x0f) * (x_thread[1] * 256.0f);

      accum += (w[1] & 0xf0) * x_thread[2];
      accum += (w[2] & 0x03) * (x_thread[2] * 256.0f);

      accum += (w[2] & 0xfc) * x_thread[3];
    }
  }

  else if (bits == 8) {
    for (int i = 0; i < N; i++) {
      accum += x_thread[i] * w[i];
    }
  }

  return scale * accum + sum * bias;
}
 
 
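// Outer-product accumulation used by the qvm kernels: one scalar x times a
// packed row of w, dequantized on the fly and added into result.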
template <typename U, int values_per_thread, int bits>
inline void
qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) {
  static_assert(
      bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
      "Template undefined for bits not in {2, 3, 4, 6, 8}");

  if (bits == 2) {
    U s[4] = {scale, scale / 4.0f, scale / 16.0f, scale / 64.0f};
    for (int i = 0; i < (values_per_thread / 4); i++) {
      result[4 * i] += x * (s[0] * (w[i] & 0x03) + bias);
      result[4 * i + 1] += x * (s[1] * (w[i] & 0x0c) + bias);
      result[4 * i + 2] += x * (s[2] * (w[i] & 0x30) + bias);
      result[4 * i + 3] += x * (s[3] * (w[i] & 0xc0) + bias);
    }
  }

  else if (bits == 3) {
    for (int i = 0; i < (values_per_thread / 8); i++) {
      uint8_t w0 = w[3 * i];
      uint8_t w1 = w[3 * i + 1];
      uint8_t w2 = w[3 * i + 2];

      result[8 * i] += x * ((w0 & 0x7) * scale + bias);
      result[8 * i + 1] += x * (((w0 & 0x38) >> 3) * scale + bias);
      result[8 * i + 2] +=
          x * ((((w0 & 0xc0) >> 6) + ((w1 & 0x1) << 2)) * scale + bias);
      result[8 * i + 3] += x * (((w1 & 0xe) >> 1) * scale + bias);
      result[8 * i + 4] += x * (((w1 & 0x70) >> 4) * scale + bias);
      result[8 * i + 5] +=
          x * ((((w1 & 0x80) >> 7) + ((w2 & 0x3) << 1)) * scale + bias);
      result[8 * i + 6] += x * (((w2 & 0x1c) >> 2) * scale + bias);
      result[8 * i + 7] += x * (((w2 & 0xe0) >> 5) * scale + bias);
    }
  }

  else if (bits == 4) {
    U s[2] = {scale, scale / 16.0f};
    for (int i = 0; i < (values_per_thread / 2); i++) {
      result[2 * i] += x * (s[0] * (w[i] & 0x0f) + bias);
      result[2 * i + 1] += x * (s[1] * (w[i] & 0xf0) + bias);
    }
  } else if (bits == 6) {
    for (int i = 0; i < (values_per_thread / 4); i++) {
      uint8_t w0 = w[3 * i];
      uint8_t w1 = w[3 * i + 1];
      uint8_t w2 = w[3 * i + 2];

      result[4 * i] += x * ((w0 & 0x3f) * scale + bias);
      result[4 * i + 1] +=
          x * ((((w0 >> 6) & 0x03) + ((w1 & 0x0f) << 2)) * scale + bias);
      result[4 * i + 2] +=
          x * ((((w1 >> 4) & 0x0f) + ((w2 & 0x03) << 4)) * scale + bias);
      result[4 * i + 3] += x * (((w2 >> 2) & 0x3f) * scale + bias);
    }
  }

  else if (bits == 8) {
    for (int i = 0; i < values_per_thread; i++) {
      result[i] += x * (scale * w[i] + bias);
    }
  }
}
 
 
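// Unpack N packed values of w into threadgroup memory as scale * w_q + bias.
// The 6-bit layout packs 4 values into 3 bytes:
//   v0 = w0[5:0]                  v1 = w0[7:6] | w1[3:0] << 2
//   v2 = w1[7:4] | w2[1:0] << 4   v3 = w2[7:2]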
template <typename U, int N, int bits>
inline void
dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) {
  static_assert(
      bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
      "Template undefined for bits not in {2, 3, 4, 6, 8}");

  if (bits == 2) {
    U s[4] = {
        scale,
        scale / static_cast<U>(4.0f),
        scale / static_cast<U>(16.0f),
        scale / static_cast<U>(64.0f)};
    for (int i = 0; i < (N / 4); i++) {
      w_local[4 * i] = s[0] * (w[i] & 0x03) + bias;
      w_local[4 * i + 1] = s[1] * (w[i] & 0x0c) + bias;
      w_local[4 * i + 2] = s[2] * (w[i] & 0x30) + bias;
      w_local[4 * i + 3] = s[3] * (w[i] & 0xc0) + bias;
    }
  }

  else if (bits == 3) {
    for (int i = 0; i < (N / 8); i++) {
      w_local += 8 * i;
      w += 3 * i;

      w_local[0] = (w[0] & 0x7) * scale + bias;
      w_local[1] = ((w[0] & 0x38) >> 3) * scale + bias;
      w_local[2] = (((w[0] & 0xc0) >> 6) + ((w[1] & 0x1) << 2)) * scale + bias;
      w_local[3] = ((w[1] & 0xe) >> 1) * scale + bias;
      w_local[4] = ((w[1] & 0x70) >> 4) * scale + bias;
      w_local[5] = (((w[1] & 0x80) >> 7) + ((w[2] & 0x3) << 1)) * scale + bias;
      w_local[6] = ((w[2] & 0x1c) >> 2) * scale + bias;
      w_local[7] = ((w[2] & 0xe0) >> 5) * scale + bias;
    }
  }

  else if (bits == 4) {
    U s[2] = {scale, scale / static_cast<U>(16.0f)};
    for (int i = 0; i < (N / 2); i++) {
      w_local[2 * i] = s[0] * (w[i] & 0x0f) + bias;
      w_local[2 * i + 1] = s[1] * (w[i] & 0xf0) + bias;
    }
  }

  else if (bits == 6) {
    for (int i = 0; i < (N / 4); i++) {
      w_local += 4 * i;
      w += 3 * i;

      w_local[0] = (w[0] & 0x3f) * scale + bias;
      w_local[1] = (((w[0] >> 6) & 0x03) + ((w[1] & 0x0f) << 2)) * scale + bias;
      w_local[2] = (((w[1] >> 4) & 0x0f) + ((w[2] & 0x03) << 4)) * scale + bias;
      w_local[3] = ((w[2] >> 2) & 0x3f) * scale + bias;
    }
  }

  else if (bits == 8) {
    for (int i = 0; i < N; i++) {
      w_local[i] = scale * w[i] + bias;
    }
  }
}
 
 
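// Threadgroup loader for the qmm kernels: each step reads a tile of packed
// weights, dequantizes it with the current group's scale and bias, and
// stages the result in threadgroup memory for the BlockMMA that consumes it.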
template <
    typename T,
    short BROWS,
    short BCOLS,
    short dst_ld,
    short reduction_dim,
    short tgp_size,
    short group_size,
    short bits>
struct QuantizedBlockLoader {
  static_assert(
      BCOLS <= group_size,
      "The group size should be larger than the columns");
  static_assert(
      group_size % BCOLS == 0,
      "The group size should be divisible by the columns");
  static_assert(
      bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
      "Template undefined for bits not in {2, 3, 4, 6, 8}");

  // ... (pack_factor / bytes_per_pack / n_reads constants and the dst,
  // scales, and biases members are elided)

  const device uint8_t* src;

  QuantizedBlockLoader(
      const device uint8_t* src_,
      const device T* scales_,
      const device T* biases_,
      // ...
      ushort simd_group_id [[simdgroup_index_in_threadgroup]],
      ushort simd_lane_id [[thread_index_in_simdgroup]])
      : // ...
        thread_idx(simd_group_id * 32 + simd_lane_id),
        // ...
  {}

  void load_unsafe() const {
    // ...
    for (int i = 0; i < n_reads; i++) {
      // ... (dequantize one strip of the tile into threadgroup memory)
    }
  }

  void load_safe(short2 src_tile_dim) const {
    if (reduction_dim == 1 && bi >= src_tile_dim.y) {
      return;
    }

    if (reduction_dim == 0 && bi >= src_tile_dim.x) {
      return;
    }

    // ...
    for (int i = 0; i < n_reads; i++) {
      // ...
    }
  }

  void next() {
    // ... (advance src, scales, and biases by one tile)
    if (reduction_dim == 1) {
      // ...
    }
  }
};
 
 
 
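// Matrix-vector product for short input vectors: each quadgroup (QUAD_SIZE =
// 4 threads) cooperates on results_per_quadgroup output rows, with the full
// input length D split as D / QUAD_SIZE values per thread.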
template <typename T, int group_size, int bits, int D>
METAL_FUNC void qmv_quad_impl(
    const device uint32_t* w,
    const device T* scales,
    const device T* biases,
    const device T* x,
    device T* y,
    const constant int& in_vec_size,
    const constant int& out_vec_size,
    uint3 tid [[threadgroup_position_in_grid]],
    uint quad_gid [[quadgroup_index_in_threadgroup]],
    uint quad_lid [[thread_index_in_quadgroup]]) {
  constexpr int quads_per_simd = SIMD_SIZE / QUAD_SIZE;
  constexpr int pack_factor = 32 / bits;
  constexpr int values_per_thread = D / QUAD_SIZE;
  constexpr int packs_per_thread = values_per_thread / pack_factor;
  constexpr int scale_step_per_thread = group_size / values_per_thread;
  constexpr int results_per_quadgroup = 8;

  typedef float U;

  thread U x_thread[values_per_thread];
  thread U result[results_per_quadgroup] = {0};

  // Adjust positions
  const int in_vec_size_w = in_vec_size / pack_factor;
  const int in_vec_size_g = in_vec_size / group_size;
  const int out_row = tid.x * quads_per_simd * results_per_quadgroup + quad_gid;

  w += out_row * in_vec_size_w + quad_lid * packs_per_thread;
  scales += out_row * in_vec_size_g + quad_lid / scale_step_per_thread;
  biases += out_row * in_vec_size_g + quad_lid / scale_step_per_thread;
  x += tid.y * in_vec_size + quad_lid * values_per_thread;
  y += tid.y * out_vec_size + out_row;

  U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);

  for (int row = 0; row < results_per_quadgroup; row++) {
    auto wl = (const device uint8_t*)(w + row * in_vec_size_w * quads_per_simd);
    const device T* sl = scales + row * in_vec_size_g * quads_per_simd;
    const device T* bl = biases + row * in_vec_size_g * quads_per_simd;

    U s = sl[0];
    U b = bl[0];
    if (row * quads_per_simd + out_row < out_vec_size) {
      result[row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
    }
  }

  for (int row = 0; row < results_per_quadgroup; row++) {
    result[row] = quad_sum(result[row]);
    if (quad_lid == 0 && row * quads_per_simd + out_row < out_vec_size) {
      y[row * quads_per_simd] = static_cast<T>(result[row]);
    }
  }
}
 
 
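// Fast matrix-vector path. Both matrix dimensions are assumed to be
// multiples of the tile (2 simdgroups of 4 rows each, block_size values of x
// per step), so the hot loop runs without bounds checks.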
template <typename T, int group_size, int bits>
METAL_FUNC void qmv_fast_impl(
    const device uint32_t* w,
    const device T* scales,
    const device T* biases,
    const device T* x,
    device T* y,
    const constant int& in_vec_size,
    const constant int& out_vec_size,
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
  constexpr int packs_per_thread = bits == 2 ? 1 : 2;
  constexpr int num_simdgroups = 2;
  constexpr int results_per_simdgroup = 4;
  constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
  constexpr int bytes_per_pack = power_of_2_bits ? 4 : 3;
  constexpr int values_per_thread = pack_factor * packs_per_thread;
  constexpr int block_size = values_per_thread * SIMD_SIZE;
  constexpr int scale_step_per_thread = group_size / values_per_thread;

  const device uint8_t* ws = (const device uint8_t*)w;

  typedef float U;

  thread U x_thread[values_per_thread];
  thread U result[results_per_simdgroup] = {0};

  // Adjust positions
  const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
  const int in_vec_size_g = in_vec_size / group_size;
  const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) +
      simd_gid * results_per_simdgroup;

  ws += out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack;
  scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
  biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
  x += tid.x * in_vec_size + simd_lid * values_per_thread;
  y += tid.x * out_vec_size + out_row;

  for (int k = 0; k < in_vec_size; k += block_size) {
    U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);

    for (int row = 0; row < results_per_simdgroup; row++) {
      auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
      const device T* sl = scales + row * in_vec_size_g;
      const device T* bl = biases + row * in_vec_size_g;

      U s = sl[0];
      U b = bl[0];
      result[row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
    }

    ws += block_size * bytes_per_pack / pack_factor;
    scales += block_size / group_size;
    biases += block_size / group_size;
    x += block_size;
  }

  for (int row = 0; row < results_per_simdgroup; row++) {
    result[row] = simd_sum(result[row]);
    if (simd_lid == 0) {
      y[row] = static_cast<T>(result[row]);
    }
  }
}
 
 
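// General matrix-vector path. If fewer output rows than one tile exist,
// every access is guarded; otherwise the last tile is shifted back to
// used_out_row (recomputing a few rows) so full tiles run unguarded, and the
// leftover input is finished with load_vector_safe / qdot_safe.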
template <typename T, int group_size, int bits>
METAL_FUNC void qmv_impl(
    const device uint32_t* w,
    const device T* scales,
    const device T* biases,
    const device T* x,
    device T* y,
    const constant int& in_vec_size,
    const constant int& out_vec_size,
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
  constexpr int num_simdgroups = 2;
  constexpr int results_per_simdgroup = 4;
  constexpr int packs_per_thread = 1;
  constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
  constexpr int bytes_per_pack = power_of_2_bits ? 4 : 3;
  constexpr int values_per_thread = pack_factor * packs_per_thread;
  constexpr int block_size = values_per_thread * SIMD_SIZE;
  constexpr int scale_step_per_thread = group_size / values_per_thread;

  const device uint8_t* ws = (const device uint8_t*)w;

  typedef float U;

  thread U x_thread[values_per_thread];
  thread U result[results_per_simdgroup] = {0};

  // Adjust positions
  const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
  const int in_vec_size_g = in_vec_size / group_size;
  const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) +
      simd_gid * results_per_simdgroup;
  const int used_out_row = min(out_vec_size - results_per_simdgroup, out_row);

  if (out_row >= out_vec_size) {
    return;
  }

  // Fewer output rows than one tile: guard every read
  if (out_vec_size < (num_simdgroups * results_per_simdgroup)) {
    ws +=
        out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack;
    scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
    biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
    x += tid.x * in_vec_size + simd_lid * values_per_thread;
    y += tid.x * out_vec_size + out_row;

    int k = 0;
    for (; k < in_vec_size - block_size; k += block_size) {
      U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);

      for (int row = 0; out_row + row < out_vec_size; row++) {
        auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
        const device T* sl = scales + row * in_vec_size_g;
        const device T* bl = biases + row * in_vec_size_g;

        U s = sl[0];
        U b = bl[0];
        result[row] +=
            qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
      }

      ws += block_size * bytes_per_pack / pack_factor;
      scales += block_size / group_size;
      biases += block_size / group_size;
      x += block_size;
    }
    const int remaining = clamp(
        static_cast<int>(in_vec_size - k - simd_lid * values_per_thread),
        0,
        values_per_thread);
    if (remaining > 0) {
      U sum = load_vector_safe<T, U, values_per_thread, bits>(
          x, x_thread, remaining);

      for (int row = 0; out_row + row < out_vec_size; row++) {
        auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
        const device T* sl = scales + row * in_vec_size_g;
        const device T* bl = biases + row * in_vec_size_g;

        U s = sl[0];
        U b = bl[0];
        result[row] +=
            qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
      }
    }

    for (int row = 0; out_row + row < out_vec_size; row++) {
      result[row] = simd_sum(result[row]);
      if (simd_lid == 0) {
        y[row] = static_cast<T>(result[row]);
      }
    }
  }

  // Otherwise move the last tile back and redo some output values
  else {
    ws += used_out_row * in_vec_size_w +
        simd_lid * packs_per_thread * bytes_per_pack;
    scales += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
    biases += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
    x += tid.x * in_vec_size + simd_lid * values_per_thread;
    y += tid.x * out_vec_size + used_out_row;

    int k = 0;
    for (; k < in_vec_size - block_size; k += block_size) {
      U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);

      for (int row = 0; row < results_per_simdgroup; row++) {
        auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
        const device T* sl = scales + row * in_vec_size_g;
        const device T* bl = biases + row * in_vec_size_g;

        U s = sl[0];
        U b = bl[0];
        result[row] +=
            qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
      }

      ws += block_size * bytes_per_pack / pack_factor;
      scales += block_size / group_size;
      biases += block_size / group_size;
      x += block_size;
    }
    const int remaining = clamp(
        static_cast<int>(in_vec_size - k - simd_lid * values_per_thread),
        0,
        values_per_thread);
    if (remaining > 0) {
      U sum = load_vector_safe<T, U, values_per_thread, bits>(
          x, x_thread, remaining);

      for (int row = 0; row < results_per_simdgroup; row++) {
        auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
        const device T* sl = scales + row * in_vec_size_g;
        const device T* bl = biases + row * in_vec_size_g;

        U s = sl[0];
        U b = bl[0];
        result[row] += qdot_safe<U, values_per_thread, bits>(
            wl, x_thread, s, b, sum, remaining);
      }
    }
    for (int row = 0; row < results_per_simdgroup; row++) {
      result[row] = simd_sum(result[row]);
      if (simd_lid == 0) {
        y[row] = static_cast<T>(result[row]);
      }
    }
  }
}
 
 
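// Vector-matrix product (y = x^T * W). Each thread owns one x element per
// step and a strip of tn * pack_factor packed output columns; qouter
// accumulates the dequantized outer product and simd_sum reduces across the
// simdgroup before the store.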
template <typename T, const int group_size, const int bits>
METAL_FUNC void qvm_impl(
    const device uint32_t* w,
    const device T* scales,
    const device T* biases,
    const device T* x,
    device T* y,
    const int in_vec_size,
    const int out_vec_size,
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
  constexpr int num_simdgroups = 2;
  constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
  constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3;
  constexpr int tn = 32 / pack_factor;
  constexpr int block_size = SIMD_SIZE;

  // W_T is uint32_t for power-of-two widths (whole 4-byte packs) and
  // uint8_t otherwise (3-byte packs)
  using W_T =
      typename ConditionalType<power_of_2_bits, uint32_t, uint8_t>::type;
  const device W_T* ws = (const device W_T*)w;

  typedef float U;
  typedef struct {
    W_T wi[tn * bytes_per_pack];
  } vec_w;

  thread vec_w w_local;
  thread U result[tn * pack_factor] = {0};
  thread U scale = 1;
  thread U bias = 0;
  thread U x_local = 0;

  // Adjust positions
  const int out_vec_size_w = out_vec_size * bytes_per_pack / pack_factor;
  const int out_vec_size_g = out_vec_size / group_size;
  int out_col = pack_factor * tn * (tid.y * num_simdgroups + simd_gid);
  ws += out_col * bytes_per_pack / pack_factor + simd_lid * out_vec_size_w;
  scales += out_col / group_size + simd_lid * out_vec_size_g;
  biases += out_col / group_size + simd_lid * out_vec_size_g;
  x += tid.x * in_vec_size + simd_lid;
  y += tid.x * out_vec_size + out_col;

  if (out_col >= out_vec_size) {
    return;
  }

  // Loop over in_vec in blocks of block_size
  int remaining = in_vec_size % block_size;
  if (remaining == 0) {
    for (int i = 0; i < in_vec_size; i += block_size) {
      x_local = *x;
      scale = *scales;
      bias = *biases;
      w_local = *((device vec_w*)ws);
      qouter<U, tn * pack_factor, bits>(
          (thread uint8_t*)&w_local, x_local, scale, bias, result);

      x += block_size;
      scales += block_size * out_vec_size_g;
      biases += block_size * out_vec_size_g;
      ws += block_size * out_vec_size_w;
    }
  } else {
    for (int i = block_size; i < in_vec_size; i += block_size) {
      x_local = *x;
      scale = *scales;
      bias = *biases;
      w_local = *((device vec_w*)ws);
      qouter<U, tn * pack_factor, bits>(
          (thread uint8_t*)&w_local, x_local, scale, bias, result);

      x += block_size;
      scales += block_size * out_vec_size_g;
      biases += block_size * out_vec_size_g;
      ws += block_size * out_vec_size_w;
    }
    if (static_cast<int>(simd_lid) < remaining) {
      x_local = *x;
      scale = *scales;
      bias = *biases;
      w_local = *((device vec_w*)ws);
    } else {
      x_local = 0;
      scale = 0;
      bias = 0;
    }
    qouter<U, tn * pack_factor, bits>(
        (thread uint8_t*)&w_local, x_local, scale, bias, result);
  }

// Accumulate in the simdgroup
#pragma clang loop unroll(full)
  for (int k = 0; k < tn * pack_factor; k++) {
    result[k] = simd_sum(result[k]);
  }

  // Store the result
  if (simd_lid == 0) {
#pragma clang loop unroll(full)
    for (int k = 0; k < tn * pack_factor; k++) {
      y[k] = static_cast<T>(result[k]);
    }
  }
}
 
 
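// Tiled GEMM against transposed packed weights (y = x * w^T). X tiles are
// staged with a steel BlockLoader and W tiles with the QuantizedBlockLoader
// above; the four branches pick load_safe / load_unsafe per operand
// depending on whether the M and N edges are ragged.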
template <
    typename T,
    const int group_size,
    const int bits,
    const bool aligned_N,
    const int BM = 32,
    const int BK = 32,
    const int BN = 32>
METAL_FUNC void qmm_t_impl(
    const device uint32_t* w,
    const device T* scales,
    const device T* biases,
    const device T* x,
    device T* y,
    threadgroup T* Xs,
    threadgroup T* Ws,
    const constant int& K,
    const constant int& N,
    const constant int& M,
    uint3 tid [[threadgroup_position_in_grid]],
    uint lid [[thread_index_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE");
  static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE");

  (void)lid;

  constexpr int WM = 2;
  constexpr int WN = 2;
  constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
  constexpr int BK_padded = (BK + 16 / sizeof(T));
  constexpr int bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1;

  // Instantiate the appropriate BlockMMA and Loader
  using mma_t = mlx::steel::
      BlockMMA<T, T, BM, BN, BK, WM, WN, false, true, BK_padded, BK_padded>;
  using loader_x_t =
      mlx::steel::BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE>;
  using loader_w_t = QuantizedBlockLoader<
      T,
      BN,
      BK,
      BK_padded,
      1,
      WM * WN * SIMD_SIZE,
      group_size,
      bits>;

  // Set the block
  const int K_w = K * bytes_per_pack / pack_factor;
  const int K_g = K / group_size;
  const int y_row = tid.y * BM;
  const int y_col = tid.x * BN;

  auto wl = (const device uint8_t*)w;

  x += y_row * K;
  wl += y_col * K_w;
  scales += y_col * K_g;
  biases += y_col * K_g;
  y += y_row * N + y_col;

  // Make the x loader and mma operation
  const short num_els = min(BM, M - y_row);
  const short num_outs = min(BN, N - y_col);
  loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid);
  loader_w_t loader_w(wl, scales, biases, K, Ws, simd_gid, simd_lid);
  mma_t mma_op(simd_gid, simd_lid);

  if (num_els < BM) {
    if (!aligned_N && num_outs < BN) {
      for (int k = 0; k < K; k += BK) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_x.load_safe(short2(BK, num_els));
        loader_w.load_safe(short2(BK, num_outs));
        threadgroup_barrier(mem_flags::mem_threadgroup);
        mma_op.mma(Xs, Ws);
        loader_x.next();
        loader_w.next();
      }
    } else {
      for (int k = 0; k < K; k += BK) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_x.load_safe(short2(BK, num_els));
        loader_w.load_unsafe();
        threadgroup_barrier(mem_flags::mem_threadgroup);
        mma_op.mma(Xs, Ws);
        loader_x.next();
        loader_w.next();
      }
    }
  } else {
    if (!aligned_N && num_outs < BN) {
      for (int k = 0; k < K; k += BK) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_x.load_unsafe();
        loader_w.load_safe(short2(BK, num_outs));
        threadgroup_barrier(mem_flags::mem_threadgroup);
        mma_op.mma(Xs, Ws);
        loader_x.next();
        loader_w.next();
      }
    } else {
      for (int k = 0; k < K; k += BK) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_x.load_unsafe();
        loader_w.load_unsafe();
        threadgroup_barrier(mem_flags::mem_threadgroup);
        mma_op.mma(Xs, Ws);
        loader_x.next();
        loader_w.next();
      }
    }
  }

  // Store results to device memory
  threadgroup_barrier(mem_flags::mem_threadgroup);
  if (num_els < BM || num_outs < BN) {
    mma_op.store_result_safe(y, N, short2(num_outs, num_els));
  } else {
    mma_op.store_result(y, N);
  }
}
 
 
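// Tiled GEMM against non-transposed packed weights (y = x * w). Unlike
// qmm_t_impl the ragged edge is along K, handled by one final guarded tile
// when K % BK != 0.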
template <
    typename T,
    const int group_size,
    const int bits,
    const int BM = 32,
    const int BK = 32,
    const int BN = 32>
METAL_FUNC void qmm_n_impl(
    const device uint32_t* w,
    const device T* scales,
    const device T* biases,
    const device T* x,
    device T* y,
    threadgroup T* Xs,
    threadgroup T* Ws,
    const constant int& K,
    const constant int& N,
    const constant int& M,
    uint3 tid [[threadgroup_position_in_grid]],
    uint lid [[thread_index_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE");
  static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE");

  (void)lid;

  constexpr int WM = 2;
  constexpr int WN = 2;
  constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
  constexpr int BK_padded = (BK + 16 / sizeof(T));
  constexpr int BN_padded = (BN + 16 / sizeof(T));
  constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
  constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3;

  // Instantiate the appropriate BlockMMA and Loader
  using mma_t = mlx::steel::
      BlockMMA<T, T, BM, BN, BK, WM, WN, false, false, BK_padded, BN_padded>;
  using loader_x_t = mlx::steel::
      BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE, 1, 4>;
  using loader_w_t = QuantizedBlockLoader<
      T,
      BK,
      BN,
      BN_padded,
      0,
      WM * WN * SIMD_SIZE,
      group_size,
      bits>;

  auto wl = (const device uint8_t*)w;

  // Set the block
  const int y_row = tid.y * BM;
  const int y_col = tid.x * BN;
  x += y_row * K;
  wl += y_col * bytes_per_pack / pack_factor;
  scales += y_col / group_size;
  biases += y_col / group_size;
  y += y_row * N + y_col;

  // Make the x loader and mma operation
  const short num_els = min(BM, M - y_row);
  loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid);
  loader_w_t loader_w(wl, scales, biases, N, Ws, simd_gid, simd_lid);
  mma_t mma_op(simd_gid, simd_lid);

  if (num_els < BM) {
    if ((K % BK) != 0) {
      const int k_blocks = K / BK;
      for (int k = 0; k < k_blocks; k++) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_x.load_safe(short2(BK, num_els));
        loader_w.load_unsafe();
        threadgroup_barrier(mem_flags::mem_threadgroup);
        mma_op.mma(Xs, Ws);
        loader_x.next();
        loader_w.next();
      }
      const short num_k = K - k_blocks * BK;
      threadgroup_barrier(mem_flags::mem_threadgroup);
      loader_x.load_safe(short2(num_k, num_els));
      loader_w.load_safe(short2(BN, num_k));
      threadgroup_barrier(mem_flags::mem_threadgroup);
      mma_op.mma(Xs, Ws);
    } else {
      for (int k = 0; k < K; k += BK) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_x.load_safe(short2(BK, num_els));
        loader_w.load_unsafe();
        threadgroup_barrier(mem_flags::mem_threadgroup);
        mma_op.mma(Xs, Ws);
        loader_x.next();
        loader_w.next();
      }
    }
  } else {
    if ((K % BK) != 0) {
      const int k_blocks = K / BK;
      for (int k = 0; k < k_blocks; k++) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_x.load_unsafe();
        loader_w.load_unsafe();
        threadgroup_barrier(mem_flags::mem_threadgroup);
        mma_op.mma(Xs, Ws);
        loader_x.next();
        loader_w.next();
      }
      const short num_k = K - k_blocks * BK;
      threadgroup_barrier(mem_flags::mem_threadgroup);
      loader_x.load_safe(short2(num_k, BM));
      loader_w.load_safe(short2(BN, num_k));
      threadgroup_barrier(mem_flags::mem_threadgroup);
      mma_op.mma(Xs, Ws);
    } else {
      for (int k = 0; k < K; k += BK) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_x.load_unsafe();
        loader_w.load_unsafe();
        threadgroup_barrier(mem_flags::mem_threadgroup);
        mma_op.mma(Xs, Ws);
        loader_x.next();
        loader_w.next();
      }
    }
  }

  // Store results to device memory
  threadgroup_barrier(mem_flags::mem_threadgroup);
  if (num_els < BM) {
    mma_op.store_result_safe(y, N, short2(BN, num_els));
  } else {
    mma_op.store_result(y, N);
  }
}
 
 
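// Advance x / w / scales / biases / y to the matrices of batch element
// tid.z, with fast paths for singly-strided batches. The second overload
// additionally gathers the batch element through the lhs/rhs index arrays;
// it backs the bs_* kernels below.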
template <typename T>
METAL_FUNC void adjust_matrix_offsets(
    const device T*& x,
    const device uint32_t*& w,
    const device T*& scales,
    const device T*& biases,
    device T*& y,
    int output_stride,
    const constant int& x_batch_ndims,
    const constant int* x_shape,
    const constant int64_t* x_strides,
    const constant int& w_batch_ndims,
    const constant int* w_shape,
    const constant int64_t* w_strides,
    const constant int64_t* s_strides,
    const constant int64_t* b_strides,
    uint3 tid [[threadgroup_position_in_grid]]) {
  // Set the input/output matrices
  uint32_t x_idx = tid.z;
  uint32_t w_idx = tid.z;
  if (x_batch_ndims == 1) {
    x += x_idx * x_strides[0];
  } else {
    x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims);
  }
  if (w_batch_ndims == 1) {
    w += w_idx * w_strides[0];
    scales += w_idx * s_strides[0];
    biases += w_idx * b_strides[0];
  } else {
    ulong3 idx = elem_to_loc_broadcast(
        w_idx, w_shape, w_strides, s_strides, b_strides, w_batch_ndims);
    w += idx.x;
    scales += idx.y;
    biases += idx.z;
  }
  y += tid.z * output_stride;
}
 
 
template <typename T>
METAL_FUNC void adjust_matrix_offsets(
    const device T*& x,
    const device uint32_t*& w,
    const device T*& scales,
    const device T*& biases,
    const device uint32_t* lhs_indices,
    const device uint32_t* rhs_indices,
    device T*& y,
    int output_stride,
    const constant int& batch_ndims,
    const constant int* batch_shape,
    const constant int64_t* lhs_strides,
    const constant int64_t* rhs_strides,
    const constant int& x_batch_ndims,
    const constant int* x_shape,
    const constant int64_t* x_strides,
    const constant int& w_batch_ndims,
    const constant int* w_shape,
    const constant int64_t* w_strides,
    const constant int64_t* s_strides,
    const constant int64_t* b_strides,
    uint3 tid [[threadgroup_position_in_grid]]) {
  // Set the input/output matrices
  uint32_t x_idx;
  uint32_t w_idx;
  if (batch_ndims == 1) {
    x_idx = lhs_indices[tid.z * lhs_strides[0]];
    w_idx = rhs_indices[tid.z * rhs_strides[0]];
  } else {
    ulong2 idx = elem_to_loc_broadcast(
        tid.z, batch_shape, lhs_strides, rhs_strides, batch_ndims);
    x_idx = lhs_indices[idx.x];
    w_idx = rhs_indices[idx.y];
  }
  if (x_batch_ndims == 1) {
    x += x_idx * x_strides[0];
  } else {
    x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims);
  }
  if (w_batch_ndims == 1) {
    w += w_idx * w_strides[0];
    scales += w_idx * s_strides[0];
    biases += w_idx * b_strides[0];
  } else {
    ulong3 idx = elem_to_loc_broadcast(
        w_idx, w_shape, w_strides, s_strides, b_strides, w_batch_ndims);
    w += idx.x;
    scales += idx.y;
    biases += idx.z;
  }
  y += tid.z * output_stride;
}
 
 
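// Kernel entry points. Each wrapper optionally applies the batch offsets
// above and forwards to its *_impl counterpart; qvm_split_k also splits the
// reduction across tid.z and trims the last partial block to
// final_block_size.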
template <typename T, int group_size, int bits, int D, bool batched>
[[kernel]] void qmv_quad(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& in_vec_size [[buffer(5)]],
    const constant int& out_vec_size [[buffer(6)]],
    const constant int& x_batch_ndims [[buffer(7)]],
    const constant int* x_shape [[buffer(8)]],
    const constant int64_t* x_strides [[buffer(9)]],
    const constant int& w_batch_ndims [[buffer(10)]],
    const constant int* w_shape [[buffer(11)]],
    const constant int64_t* w_strides [[buffer(12)]],
    const constant int64_t* s_strides [[buffer(13)]],
    const constant int64_t* b_strides [[buffer(14)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint quad_gid [[quadgroup_index_in_threadgroup]],
    uint quad_lid [[thread_index_in_quadgroup]]) {
  if (batched) {
    int M = x_shape[x_batch_ndims];
    adjust_matrix_offsets<T>(
        x, w, scales, biases, y, out_vec_size * M, x_batch_ndims, x_shape,
        x_strides, w_batch_ndims, w_shape, w_strides, s_strides, b_strides,
        tid);
  }
  qmv_quad_impl<T, group_size, bits, D>(
      w, scales, biases, x, y, in_vec_size, out_vec_size, tid, quad_gid,
      quad_lid);
}
 
 
template <typename T, int group_size, int bits, bool batched>
[[kernel]] void qmv_fast(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& in_vec_size [[buffer(5)]],
    const constant int& out_vec_size [[buffer(6)]],
    const constant int& x_batch_ndims [[buffer(7)]],
    const constant int* x_shape [[buffer(8)]],
    const constant int64_t* x_strides [[buffer(9)]],
    const constant int& w_batch_ndims [[buffer(10)]],
    const constant int* w_shape [[buffer(11)]],
    const constant int64_t* w_strides [[buffer(12)]],
    const constant int64_t* s_strides [[buffer(13)]],
    const constant int64_t* b_strides [[buffer(14)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  if (batched) {
    int M = x_shape[x_batch_ndims];
    adjust_matrix_offsets<T>(
        x, w, scales, biases, y, out_vec_size * M, x_batch_ndims, x_shape,
        x_strides, w_batch_ndims, w_shape, w_strides, s_strides, b_strides,
        tid);
  }
  qmv_fast_impl<T, group_size, bits>(
      w, scales, biases, x, y, in_vec_size, out_vec_size, tid, simd_gid,
      simd_lid);
}
 
 
template <typename T, const int group_size, const int bits, bool batched>
[[kernel]] void qmv(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& in_vec_size [[buffer(5)]],
    const constant int& out_vec_size [[buffer(6)]],
    const constant int& x_batch_ndims [[buffer(7)]],
    const constant int* x_shape [[buffer(8)]],
    const constant int64_t* x_strides [[buffer(9)]],
    const constant int& w_batch_ndims [[buffer(10)]],
    const constant int* w_shape [[buffer(11)]],
    const constant int64_t* w_strides [[buffer(12)]],
    const constant int64_t* s_strides [[buffer(13)]],
    const constant int64_t* b_strides [[buffer(14)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  if (batched) {
    int M = x_shape[x_batch_ndims];
    adjust_matrix_offsets<T>(
        x, w, scales, biases, y, out_vec_size * M, x_batch_ndims, x_shape,
        x_strides, w_batch_ndims, w_shape, w_strides, s_strides, b_strides,
        tid);
  }
  qmv_impl<T, group_size, bits>(
      w, scales, biases, x, y, in_vec_size, out_vec_size, tid, simd_gid,
      simd_lid);
}
 
 
template <typename T, const int group_size, const int bits, bool batched>
[[kernel]] void qvm(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& in_vec_size [[buffer(5)]],
    const constant int& out_vec_size [[buffer(6)]],
    const constant int& x_batch_ndims [[buffer(7)]],
    const constant int* x_shape [[buffer(8)]],
    const constant int64_t* x_strides [[buffer(9)]],
    const constant int& w_batch_ndims [[buffer(10)]],
    const constant int* w_shape [[buffer(11)]],
    const constant int64_t* w_strides [[buffer(12)]],
    const constant int64_t* s_strides [[buffer(13)]],
    const constant int64_t* b_strides [[buffer(14)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  if (batched) {
    int M = x_shape[x_batch_ndims];
    adjust_matrix_offsets<T>(
        x, w, scales, biases, y, out_vec_size * M, x_batch_ndims, x_shape,
        x_strides, w_batch_ndims, w_shape, w_strides, s_strides, b_strides,
        tid);
  }
  qvm_impl<T, group_size, bits>(
      w, scales, biases, x, y, in_vec_size, out_vec_size, tid, simd_gid,
      simd_lid);
}
 
 
template <typename T, const int group_size, const int bits, int split_k = 32>
[[kernel]] void qvm_split_k(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& in_vec_size [[buffer(5)]],
    const constant int& out_vec_size [[buffer(6)]],
    const constant int& x_batch_ndims [[buffer(7)]],
    const constant int* x_shape [[buffer(8)]],
    const constant int64_t* x_strides [[buffer(9)]],
    const constant int& w_batch_ndims [[buffer(10)]],
    const constant int* w_shape [[buffer(11)]],
    const constant int64_t* w_strides [[buffer(12)]],
    const constant int64_t* s_strides [[buffer(13)]],
    const constant int64_t* b_strides [[buffer(14)]],
    const constant int& final_block_size [[buffer(15)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  int M = x_shape[x_batch_ndims];
  adjust_matrix_offsets<T>(
      x, w, scales, biases, y, out_vec_size * M, x_batch_ndims, x_shape,
      x_strides, w_batch_ndims, w_shape, w_strides, s_strides, b_strides,
      tid);

  // The final partial block of the split reduction may be smaller
  int in_vec_size_adj =
      tid.z % split_k == split_k - 1 ? final_block_size : in_vec_size;

  qvm_impl<T, group_size, bits>(
      w, scales, biases, x, y, in_vec_size_adj, out_vec_size, tid, simd_gid,
      simd_lid);
}
 
 
template <
    typename T,
    const int group_size,
    const int bits,
    const bool aligned_N,
    const bool batched,
    const int BM = 32,
    const int BK = 32,
    const int BN = 32>
[[kernel]] void qmm_t(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& K [[buffer(5)]],
    const constant int& N [[buffer(6)]],
    const constant int& M [[buffer(7)]],
    const constant int& x_batch_ndims [[buffer(8)]],
    const constant int* x_shape [[buffer(9)]],
    const constant int64_t* x_strides [[buffer(10)]],
    const constant int& w_batch_ndims [[buffer(11)]],
    const constant int* w_shape [[buffer(12)]],
    const constant int64_t* w_strides [[buffer(13)]],
    const constant int64_t* s_strides [[buffer(14)]],
    const constant int64_t* b_strides [[buffer(15)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint lid [[thread_index_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  constexpr int BK_padded = (BK + 16 / sizeof(T));

  threadgroup T Xs[BM * BK_padded];
  threadgroup T Ws[BN * BK_padded];

  if (batched) {
    adjust_matrix_offsets<T>(
        x, w, scales, biases, y, M * N, x_batch_ndims, x_shape, x_strides,
        w_batch_ndims, w_shape, w_strides, s_strides, b_strides, tid);
  }
  qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
      w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
}
 
 
template <
    typename T,
    const int group_size,
    const int bits,
    const bool batched,
    const int BM = 32,
    const int BK = 32,
    const int BN = 32>
[[kernel]] void qmm_n(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& K [[buffer(5)]],
    const constant int& N [[buffer(6)]],
    const constant int& M [[buffer(7)]],
    const constant int& x_batch_ndims [[buffer(8)]],
    const constant int* x_shape [[buffer(9)]],
    const constant int64_t* x_strides [[buffer(10)]],
    const constant int& w_batch_ndims [[buffer(11)]],
    const constant int* w_shape [[buffer(12)]],
    const constant int64_t* w_strides [[buffer(13)]],
    const constant int64_t* s_strides [[buffer(14)]],
    const constant int64_t* b_strides [[buffer(15)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint lid [[thread_index_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  constexpr int BK_padded = (BK + 16 / sizeof(T));
  constexpr int BN_padded = (BN + 16 / sizeof(T));

  threadgroup T Xs[BM * BK_padded];
  threadgroup T Ws[BK * BN_padded];

  if (batched) {
    adjust_matrix_offsets<T>(
        x, w, scales, biases, y, M * N, x_batch_ndims, x_shape, x_strides,
        w_batch_ndims, w_shape, w_strides, s_strides, b_strides, tid);
  }
  qmm_n_impl<T, group_size, bits, BM, BK, BN>(
      w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
}
 
 
template <typename T, int group_size, int bits>
[[kernel]] void bs_qmv_fast(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& in_vec_size [[buffer(5)]],
    const constant int& out_vec_size [[buffer(6)]],
    const constant int& x_batch_ndims [[buffer(7)]],
    const constant int* x_shape [[buffer(8)]],
    const constant int64_t* x_strides [[buffer(9)]],
    const constant int& w_batch_ndims [[buffer(10)]],
    const constant int* w_shape [[buffer(11)]],
    const constant int64_t* w_strides [[buffer(12)]],
    const constant int64_t* s_strides [[buffer(13)]],
    const constant int64_t* b_strides [[buffer(14)]],
    const constant int& batch_ndims [[buffer(15)]],
    const constant int* batch_shape [[buffer(16)]],
    const device uint32_t* lhs_indices [[buffer(17)]],
    const device uint32_t* rhs_indices [[buffer(18)]],
    const constant int64_t* lhs_strides [[buffer(19)]],
    const constant int64_t* rhs_strides [[buffer(20)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  int M = x_shape[x_batch_ndims];
  adjust_matrix_offsets<T>(
      x, w, scales, biases, lhs_indices, rhs_indices, y, out_vec_size * M,
      batch_ndims, batch_shape, lhs_strides, rhs_strides, x_batch_ndims,
      x_shape, x_strides, w_batch_ndims, w_shape, w_strides, s_strides,
      b_strides, tid);
  qmv_fast_impl<T, group_size, bits>(
      w, scales, biases, x, y, in_vec_size, out_vec_size, tid, simd_gid,
      simd_lid);
}
 
 
template <typename T, int group_size, int bits>
[[kernel]] void bs_qmv(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& in_vec_size [[buffer(5)]],
    const constant int& out_vec_size [[buffer(6)]],
    const constant int& x_batch_ndims [[buffer(7)]],
    const constant int* x_shape [[buffer(8)]],
    const constant int64_t* x_strides [[buffer(9)]],
    const constant int& w_batch_ndims [[buffer(10)]],
    const constant int* w_shape [[buffer(11)]],
    const constant int64_t* w_strides [[buffer(12)]],
    const constant int64_t* s_strides [[buffer(13)]],
    const constant int64_t* b_strides [[buffer(14)]],
    const constant int& batch_ndims [[buffer(15)]],
    const constant int* batch_shape [[buffer(16)]],
    const device uint32_t* lhs_indices [[buffer(17)]],
    const device uint32_t* rhs_indices [[buffer(18)]],
    const constant int64_t* lhs_strides [[buffer(19)]],
    const constant int64_t* rhs_strides [[buffer(20)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  int M = x_shape[x_batch_ndims];
  adjust_matrix_offsets<T>(
      x, w, scales, biases, lhs_indices, rhs_indices, y, out_vec_size * M,
      batch_ndims, batch_shape, lhs_strides, rhs_strides, x_batch_ndims,
      x_shape, x_strides, w_batch_ndims, w_shape, w_strides, s_strides,
      b_strides, tid);
  qmv_impl<T, group_size, bits>(
      w, scales, biases, x, y, in_vec_size, out_vec_size, tid, simd_gid,
      simd_lid);
}
 
 
template <typename T, int group_size, int bits>
[[kernel]] void bs_qvm(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& in_vec_size [[buffer(5)]],
    const constant int& out_vec_size [[buffer(6)]],
    const constant int& x_batch_ndims [[buffer(7)]],
    const constant int* x_shape [[buffer(8)]],
    const constant int64_t* x_strides [[buffer(9)]],
    const constant int& w_batch_ndims [[buffer(10)]],
    const constant int* w_shape [[buffer(11)]],
    const constant int64_t* w_strides [[buffer(12)]],
    const constant int64_t* s_strides [[buffer(13)]],
    const constant int64_t* b_strides [[buffer(14)]],
    const constant int& batch_ndims [[buffer(15)]],
    const constant int* batch_shape [[buffer(16)]],
    const device uint32_t* lhs_indices [[buffer(17)]],
    const device uint32_t* rhs_indices [[buffer(18)]],
    const constant int64_t* lhs_strides [[buffer(19)]],
    const constant int64_t* rhs_strides [[buffer(20)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  int M = x_shape[x_batch_ndims];
  adjust_matrix_offsets<T>(
      x, w, scales, biases, lhs_indices, rhs_indices, y, out_vec_size * M,
      batch_ndims, batch_shape, lhs_strides, rhs_strides, x_batch_ndims,
      x_shape, x_strides, w_batch_ndims, w_shape, w_strides, s_strides,
      b_strides, tid);
  qvm_impl<T, group_size, bits>(
      w, scales, biases, x, y, in_vec_size, out_vec_size, tid, simd_gid,
      simd_lid);
}
 
 
template <
    typename T,
    const int group_size,
    const int bits,
    const bool aligned_N,
    const int BM = 32,
    const int BK = 32,
    const int BN = 32>
[[kernel]] void bs_qmm_t(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& K [[buffer(5)]],
    const constant int& N [[buffer(6)]],
    const constant int& M [[buffer(7)]],
    const constant int& x_batch_ndims [[buffer(8)]],
    const constant int* x_shape [[buffer(9)]],
    const constant int64_t* x_strides [[buffer(10)]],
    const constant int& w_batch_ndims [[buffer(11)]],
    const constant int* w_shape [[buffer(12)]],
    const constant int64_t* w_strides [[buffer(13)]],
    const constant int64_t* s_strides [[buffer(14)]],
    const constant int64_t* b_strides [[buffer(15)]],
    const constant int& batch_ndims [[buffer(16)]],
    const constant int* batch_shape [[buffer(17)]],
    const device uint32_t* lhs_indices [[buffer(18)]],
    const device uint32_t* rhs_indices [[buffer(19)]],
    const constant int64_t* lhs_strides [[buffer(20)]],
    const constant int64_t* rhs_strides [[buffer(21)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint lid [[thread_index_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  constexpr int BK_padded = (BK + 16 / sizeof(T));

  threadgroup T Xs[BM * BK_padded];
  threadgroup T Ws[BN * BK_padded];

  adjust_matrix_offsets<T>(
      x, w, scales, biases, lhs_indices, rhs_indices, y, M * N, batch_ndims,
      batch_shape, lhs_strides, rhs_strides, x_batch_ndims, x_shape,
      x_strides, w_batch_ndims, w_shape, w_strides, s_strides, b_strides,
      tid);
  qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
      w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
}
 
 
template <
    typename T,
    const int group_size,
    const int bits,
    const int BM = 32,
    const int BK = 32,
    const int BN = 32>
[[kernel]] void bs_qmm_n(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& K [[buffer(5)]],
    const constant int& N [[buffer(6)]],
    const constant int& M [[buffer(7)]],
    const constant int& x_batch_ndims [[buffer(8)]],
    const constant int* x_shape [[buffer(9)]],
    const constant int64_t* x_strides [[buffer(10)]],
    const constant int& w_batch_ndims [[buffer(11)]],
    const constant int* w_shape [[buffer(12)]],
    const constant int64_t* w_strides [[buffer(13)]],
    const constant int64_t* s_strides [[buffer(14)]],
    const constant int64_t* b_strides [[buffer(15)]],
    const constant int& batch_ndims [[buffer(16)]],
    const constant int* batch_shape [[buffer(17)]],
    const device uint32_t* lhs_indices [[buffer(18)]],
    const device uint32_t* rhs_indices [[buffer(19)]],
    const constant int64_t* lhs_strides [[buffer(20)]],
    const constant int64_t* rhs_strides [[buffer(21)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint lid [[thread_index_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  constexpr int BK_padded = (BK + 16 / sizeof(T));
  constexpr int BN_padded = (BN + 16 / sizeof(T));

  threadgroup T Xs[BM * BK_padded];
  threadgroup T Ws[BK * BN_padded];

  adjust_matrix_offsets<T>(
      x, w, scales, biases, lhs_indices, rhs_indices, y, M * N, batch_ndims,
      batch_shape, lhs_strides, rhs_strides, x_batch_ndims, x_shape,
      x_strides, w_batch_ndims, w_shape, w_strides, s_strides, b_strides,
      tid);
  qmm_n_impl<T, group_size, bits, BM, BK, BN>(
      w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
}
 
 
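// Quantize one group of group_size weights per simdgroup. The initial scale
// is (w_max - w_min) / n_bins; the edge farther from zero is then snapped to
// an exact quantization level (scale = edge / round(edge / scale), bias =
// edge) so the extreme value round-trips exactly. E.g. bits = 4, w_min =
// -1.9, w_max = 1.0: scale = 2.9 / 15 ~ 0.193, q0 = round(-1.9 / 0.193) =
// -10, so scale becomes 0.19 and bias -1.9.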
template <typename T, const int group_size, const int bits>
[[kernel]] void affine_quantize(
    const device T* w [[buffer(0)]],
    device uint8_t* out [[buffer(1)]],
    device T* scales [[buffer(2)]],
    device T* biases [[buffer(3)]],
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  constexpr T eps = T(1e-7);
  constexpr int simd_size = 32;
  constexpr T n_bins = (1 << bits) - 1;
  constexpr int packs_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
  constexpr int values_per_reduce = group_size / simd_size;
  constexpr int writes_per_reduce = packs_per_int / values_per_reduce;
  constexpr int writes_per_pack =
      writes_per_reduce > 1 ? 1 : values_per_reduce / packs_per_int;
  constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
  constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3;

  static_assert(
      group_size % simd_size == 0,
      "Group size must be divisible by simd size.");

  size_t offset = index.x + grid_dim.x * size_t(index.y);
  size_t in_index = offset * values_per_reduce;
  size_t out_index = power_of_2_bits
      ? offset * writes_per_pack
      : offset * bytes_per_pack / writes_per_reduce;

  T w_thread[values_per_reduce];
  T w_min = Limits<T>::max;
  T w_max = 0;

#pragma clang loop unroll(full)
  for (int i = 0; i < values_per_reduce; i++) {
    T val = w[in_index + i];
    w_thread[i] = val;
    w_min = min(w_min, val);
    w_max = max(w_max, val);
  }

  w_min = simd_min(w_min);
  w_max = simd_max(w_max);

  T scale = max((w_max - w_min) / n_bins, eps);
  bool side = abs(w_min) > abs(w_max);
  scale = side ? scale : -scale;
  T edge = side ? w_min : w_max;
  T q0 = round(edge / scale);
  bool at_zero = q0 == 0.0f;
  scale = at_zero ? scale : edge / q0;
  T bias = at_zero ? T(0) : edge;

  // Write out the scales and biases
  size_t gindex = in_index / group_size;
  if (in_index % group_size == 0) {
    scales[gindex] = scale;
    biases[gindex] = bias;
  }

  // Accumulate up to 3 bytes worth of packed values
  uint32_t output = 0;

#pragma clang loop unroll(full)
  for (int i = 0; i < values_per_reduce; i++) {
    uint8_t val = min(round((w_thread[i] - bias) / scale), n_bins);
    if (bits == 8) {
      output = val;
    } else {
      output += val << (bits * (i % packs_per_int));
    }

    if (packs_per_int < values_per_reduce &&
        i % packs_per_int == packs_per_int - 1) {
      out[out_index + i / packs_per_int] = output;
      output = 0;
    } else {
#pragma clang loop unroll(full)
      for (int j = 1; j < writes_per_reduce; j++) {
        uint8_t sval = simd_shuffle_down(val, j);
        output += sval << (bits * (j * values_per_reduce + i));
      }
    }
  }
  if (bits == 3 || bits == 6) {
    if (in_index % packs_per_int == 0 && out_index % bytes_per_pack == 0) {
      out[out_index] = output & 0xff;
      out[out_index + 1] = (output & 0xff00) >> 8;
      out[out_index + 2] = (output & 0xff0000) >> 16;
    }
  } else {
    if (writes_per_reduce > 0 && out_index % writes_per_reduce == 0) {
      out[out_index / writes_per_reduce] = output;
    }
  }
}
 
 
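// Inverse of affine_quantize: each thread unpacks packs_per_int values from
// one 1- or 3-byte pack and emits w_q * scale + bias with its group's scale
// and bias.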
template <typename T, const int group_size, const int bits>
[[kernel]] void affine_dequantize(
    const device uint8_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    device T* out [[buffer(3)]],
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  constexpr int packs_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
  constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
  constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3;

  size_t offset = index.x + grid_dim.x * size_t(index.y);
  size_t oindex = offset * packs_per_int;
  size_t gindex = oindex / group_size;
  T scale = scales[gindex];
  T bias = biases[gindex];

  out += oindex;

  if (bits == 3) {
    w += offset * bytes_per_pack;
    out[0] = (w[0] & 0x7) * scale + bias;
    out[1] = ((w[0] & 0x38) >> 3) * scale + bias;
    out[2] = (((w[0] & 0xc0) >> 6) + ((w[1] & 0x1) << 2)) * scale + bias;
    out[3] = ((w[1] & 0xe) >> 1) * scale + bias;
    out[4] = ((w[1] & 0x70) >> 4) * scale + bias;
    out[5] = (((w[1] & 0x80) >> 7) + ((w[2] & 0x3) << 1)) * scale + bias;
    out[6] = ((w[2] & 0x1c) >> 2) * scale + bias;
    out[7] = ((w[2] & 0xe0) >> 5) * scale + bias;
  } else if (bits == 6) {
    w += offset * bytes_per_pack;
    out[0] = (w[0] & 0x3f) * scale + bias;
    out[1] = (((w[0] >> 6) & 0x03) + ((w[1] & 0x0f) << 2)) * scale + bias;
    out[2] = (((w[1] >> 4) & 0x0f) + ((w[2] & 0x03) << 4)) * scale + bias;
    out[3] = ((w[2] >> 2) & 0x3f) * scale + bias;
  } else {
    uint val = w[offset];
#pragma clang loop unroll(full)
    for (int i = 0; i < packs_per_int; i++) {
      uint8_t d;
      if (bits == 2) {
        d = (val >> (bits * i)) & 0x03;
      } else if (bits == 4) {
        d = (val >> (bits * i)) & 0x0f;
      } else if (bits == 8) {
        d = val;
      }
      out[i] = scale * d + bias;
    }
  }
}
 
 
Definition quantized.h:1882
 
void qmm_n(const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const constant int &K, const constant int &N, const constant int &M, const constant int &x_batch_ndims, const constant int *x_shape, const constant int64_t *x_strides, const constant int &w_batch_ndims, const constant int *w_shape, const constant int64_t *w_strides, const constant int64_t *s_strides, const constant int64_t *b_strides, uint3 tid, uint lid, uint simd_gid, uint simd_lid)
Definition quantized.h:1636
 
static constant constexpr const int QUAD_SIZE
Definition quantized.h:11
 
void qmv(const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const constant int &in_vec_size, const constant int &out_vec_size, const constant int &x_batch_ndims, const constant int *x_shape, const constant int64_t *x_strides, const constant int &w_batch_ndims, const constant int *w_shape, const constant int64_t *w_strides, const constant int64_t *s_strides, const constant int64_t *b_strides, uint3 tid, uint simd_gid, uint simd_lid)
Definition quantized.h:1410
 
void qmm_t(const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const constant int &K, const constant int &N, const constant int &M, const constant int &x_batch_ndims, const constant int *x_shape, const constant int64_t *x_strides, const constant int &w_batch_ndims, const constant int *w_shape, const constant int64_t *w_strides, const constant int64_t *s_strides, const constant int64_t *b_strides, uint3 tid, uint lid, uint simd_gid, uint simd_lid)
Definition quantized.h:1578
 
U load_vector(const device T *x, thread U *x_thread)
Definition quantized.h:14
 
METAL_FUNC void qmv_impl(const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const constant int &in_vec_size, const constant int &out_vec_size, uint3 tid, uint simd_gid, uint simd_lid)
Definition quantized.h:688
 
void qmv_quad(const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const constant int &in_vec_size, const constant int &out_vec_size, const constant int &x_batch_ndims, const constant int *x_shape, const constant int64_t *x_strides, const constant int &w_batch_ndims, const constant int *w_shape, const constant int64_t *w_strides, const constant int64_t *s_strides, const constant int64_t *b_strides, uint3 tid, uint quad_gid, uint quad_lid)
Definition quantized.h:1306
 
U load_vector_safe(const device T *x, thread U *x_thread, int N)
Definition quantized.h:77
 
void qvm_split_k(const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const constant int &in_vec_size, const constant int &out_vec_size, const constant int &x_batch_ndims, const constant int *x_shape, const constant int64_t *x_strides, const constant int &w_batch_ndims, const constant int *w_shape, const constant int64_t *w_strides, const constant int64_t *s_strides, const constant int64_t *b_strides, const constant int &final_block_size, uint3 tid, uint simd_gid, uint simd_lid)
Definition quantized.h:1514
 
void bs_qmv(const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const constant int &in_vec_size, const constant int &out_vec_size, const constant int &x_batch_ndims, const constant int *x_shape, const constant int64_t *x_strides, const constant int &w_batch_ndims, const constant int *w_shape, const constant int64_t *w_strides, const constant int64_t *s_strides, const constant int64_t *b_strides, const constant int &batch_ndims, const constant int *batch_shape, const device uint32_t *lhs_indices, const device uint32_t *rhs_indices, const constant int64_t *lhs_strides, const constant int64_t *rhs_strides, uint3 tid, uint simd_gid, uint simd_lid)
Definition quantized.h:1751
 
U qdot(const device uint8_t *w, const thread U *x_thread, U scale, U bias, U sum)
Definition quantized.h:145
 
METAL_FUNC void qmv_fast_impl(const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const constant int &in_vec_size, const constant int &out_vec_size, uint3 tid, uint simd_gid, uint simd_lid)
Definition quantized.h:620
 
METAL_FUNC void qmv_quad_impl(const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, constant int &in_vec_size, const constant int &out_vec_size, uint3 tid, uint quad_gid, uint quad_lid)
Definition quantized.h:563
 
void qouter(const thread uint8_t *w, U x, U scale, U bias, thread U *result)
Definition quantized.h:307
 
void dequantize(const device uint8_t *w, U scale, U bias, threadgroup U *w_local)
Definition quantized.h:372
 
METAL_FUNC void qmm_t_impl(const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, threadgroup T *Xs, threadgroup T *Ws, const constant int &K, const constant int &N, const constant int &M, uint3 tid, uint lid, uint simd_gid, uint simd_lid)
Definition quantized.h:958
 
U type
Definition utils.h:417
 
static const constant U max
Definition utils.h:24
 
Definition quantized.h:443
 
const int group_stride
Definition quantized.h:464
 
static constant constexpr const short BCOLS_PACKED
Definition quantized.h:456
 
const device T * biases
Definition quantized.h:473
 
short group_step_cnt
Definition quantized.h:463
 
static constant constexpr const short group_steps
Definition quantized.h:459
 
const short thread_idx
Definition quantized.h:466
 
QuantizedBlockLoader(const device uint8_t *src_, const device T *scales_, const device T *biases_, const int src_ld_, threadgroup T *dst_, ushort simd_group_id, ushort simd_lane_id)
Definition quantized.h:475
 
const device T * scales
Definition quantized.h:472
 
static constant constexpr const short n_reads
Definition quantized.h:457
 
void next()
Definition quantized.h:541
 
void load_safe(short2 src_tile_dim) const
Definition quantized.h:511
 
const int src_ld
Definition quantized.h:461
 
const short bi
Definition quantized.h:467
 
void load_unsafe() const
Definition quantized.h:498
 
static constant constexpr const short pack_factor
Definition quantized.h:454
 
threadgroup T * dst
Definition quantized.h:470
 
const device uint8_t * src
Definition quantized.h:471
 
const int tile_stride
Definition quantized.h:462
 
static constant constexpr const short bytes_per_pack
Definition quantized.h:455
 
const short bj
Definition quantized.h:468