mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
066336b60e |
@@ -712,65 +712,6 @@ void winograd_conv_2D_gpu(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void depthwise_conv_2D_gpu(
|
|
||||||
const Stream& s,
|
|
||||||
metal::Device& d,
|
|
||||||
const array& in,
|
|
||||||
const array& wt,
|
|
||||||
array out,
|
|
||||||
const MLXConvParams<2>& conv_params) {
|
|
||||||
std::ostringstream kname;
|
|
||||||
kname << "depthwise_conv_2d_" << type_to_name(out);
|
|
||||||
std::string base_name = kname.str();
|
|
||||||
|
|
||||||
const int N = conv_params.N;
|
|
||||||
const int ker_h = conv_params.wS[0];
|
|
||||||
const int ker_w = conv_params.wS[1];
|
|
||||||
const int str_h = conv_params.str[0];
|
|
||||||
const int str_w = conv_params.str[1];
|
|
||||||
const int tc = 8;
|
|
||||||
const int tw = 8;
|
|
||||||
const int th = 4;
|
|
||||||
const bool do_flip = conv_params.flip;
|
|
||||||
|
|
||||||
metal::MTLFCList func_consts = {
|
|
||||||
{&ker_h, MTL::DataType::DataTypeInt, 00},
|
|
||||||
{&ker_w, MTL::DataType::DataTypeInt, 01},
|
|
||||||
{&str_h, MTL::DataType::DataTypeInt, 10},
|
|
||||||
{&str_w, MTL::DataType::DataTypeInt, 11},
|
|
||||||
{&th, MTL::DataType::DataTypeInt, 100},
|
|
||||||
{&tw, MTL::DataType::DataTypeInt, 101},
|
|
||||||
{&do_flip, MTL::DataType::DataTypeBool, 200},
|
|
||||||
};
|
|
||||||
|
|
||||||
// clang-format off
|
|
||||||
kname << "_ker_h_" << ker_h
|
|
||||||
<< "_ker_w_" << ker_w
|
|
||||||
<< "_str_h_" << str_h
|
|
||||||
<< "_str_w_" << str_w
|
|
||||||
<< "_tgp_h_" << th
|
|
||||||
<< "_tgp_w_" << tw
|
|
||||||
<< "_do_flip_" << (do_flip ? 't' : 'n'); // clang-format on
|
|
||||||
|
|
||||||
std::string hash_name = kname.str();
|
|
||||||
|
|
||||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
|
||||||
auto kernel = d.get_kernel(base_name, "mlx", hash_name, func_consts);
|
|
||||||
compute_encoder.set_compute_pipeline_state(kernel);
|
|
||||||
|
|
||||||
compute_encoder.set_input_array(in, 0);
|
|
||||||
compute_encoder.set_input_array(wt, 1);
|
|
||||||
compute_encoder.set_output_array(out, 2);
|
|
||||||
|
|
||||||
compute_encoder.set_bytes(conv_params, 3);
|
|
||||||
|
|
||||||
MTL::Size group_dims = MTL::Size(tc, tw, th);
|
|
||||||
MTL::Size grid_dims = MTL::Size(
|
|
||||||
conv_params.C / tc, conv_params.oS[1] / tw, (conv_params.oS[0] / th) * N);
|
|
||||||
|
|
||||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
|
||||||
}
|
|
||||||
|
|
||||||
void conv_2D_gpu(
|
void conv_2D_gpu(
|
||||||
const Stream& s,
|
const Stream& s,
|
||||||
metal::Device& d,
|
metal::Device& d,
|
||||||
@@ -813,20 +754,11 @@ void conv_2D_gpu(
|
|||||||
bool is_kdil_one = conv_params.kdil[0] == 1 && conv_params.kdil[1] == 1;
|
bool is_kdil_one = conv_params.kdil[0] == 1 && conv_params.kdil[1] == 1;
|
||||||
bool is_idil_one = conv_params.idil[0] == 1 && conv_params.idil[1] == 1;
|
bool is_idil_one = conv_params.idil[0] == 1 && conv_params.idil[1] == 1;
|
||||||
|
|
||||||
if (is_idil_one && groups > 1) {
|
if (groups > 1) {
|
||||||
const int C_per_group = conv_params.C / groups;
|
const int C_per_group = conv_params.C / groups;
|
||||||
const int O_per_group = conv_params.O / groups;
|
const int O_per_group = conv_params.O / groups;
|
||||||
|
|
||||||
if (C_per_group == 1 && O_per_group == 1 && is_kdil_one &&
|
if (is_idil_one && (C_per_group <= 4 || C_per_group % 16 == 0) &&
|
||||||
conv_params.wS[0] <= 7 && conv_params.wS[1] <= 7 &&
|
|
||||||
conv_params.str[0] <= 2 && conv_params.str[1] <= 2 &&
|
|
||||||
conv_params.oS[0] % 8 == 0 && conv_params.oS[1] % 8 == 0 &&
|
|
||||||
conv_params.wt_strides[1] == conv_params.wS[1] &&
|
|
||||||
conv_params.C % 16 == 0 && conv_params.C == conv_params.O) {
|
|
||||||
return depthwise_conv_2D_gpu(s, d, in, wt, out, conv_params);
|
|
||||||
}
|
|
||||||
|
|
||||||
if ((C_per_group <= 4 || C_per_group % 16 == 0) &&
|
|
||||||
(O_per_group <= 16 || O_per_group % 16 == 0)) {
|
(O_per_group <= 16 || O_per_group % 16 == 0)) {
|
||||||
return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
|
return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@@ -275,128 +275,6 @@ instantiate_naive_conv_2d_blocks(float32, float);
|
|||||||
instantiate_naive_conv_2d_blocks(float16, half);
|
instantiate_naive_conv_2d_blocks(float16, half);
|
||||||
instantiate_naive_conv_2d_blocks(bfloat16, bfloat16_t);
|
instantiate_naive_conv_2d_blocks(bfloat16, bfloat16_t);
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
/// Depthwise convolution kernels
///////////////////////////////////////////////////////////////////////////////

// Specialization values supplied through function constants when the
// pipeline is created.
constant int ker_h [[function_constant(00)]];
constant int ker_w [[function_constant(01)]];
constant int str_h [[function_constant(10)]];
constant int str_w [[function_constant(11)]];
constant int tgp_h [[function_constant(100)]];
constant int tgp_w [[function_constant(101)]];
constant bool do_flip [[function_constant(200)]];

// Extent of the input patch one threadgroup reads to produce its output tile.
constant int span_h = tgp_h * str_h + ker_h - 1;
constant int span_w = tgp_w * str_w + ker_w - 1;
constant int span_hw = span_h * span_w;

// Depthwise 2D convolution.
//
// Phase 1: the threadgroup cooperatively copies the input patch it needs into
// threadgroup memory, writing zeros for positions outside the image.
// Phase 2: every thread accumulates one output element in float and stores it
// at (batch, output row, output column, channel).
//
// Buffers: 0 = input, 1 = weights, 2 = output, 3 = convolution parameters.
template <typename T>
[[kernel]] void depthwise_conv_2d(
    const device T* in [[buffer(0)]],
    const device T* wt [[buffer(1)]],
    device T* out [[buffer(2)]],
    const constant MLXConvParams<2>& params [[buffer(3)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint3 gid [[thread_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  // Threadgroup tile: channels x output columns x output rows.
  constexpr int tc = 8;
  constexpr int tw = 8;
  constexpr int th = 4;

  // Channels copied by one thread per patch position during staging.
  constexpr int c_per_thr = 8;

  // Static capacity of the staging buffer. It equals span_h x span_w when the
  // stride is 2 and the filter is 7 wide, and is larger for smaller values.
  constexpr int TGH = th * 2 + 6;
  constexpr int TGW = tw * 2 + 6;
  constexpr int TGC = tc;

  threadgroup T ins[TGH * TGW * TGC];

  // grid.z enumerates (batch element, row tile) pairs.
  const int row_tiles = params.oS[0] / th;
  const int batch = tid.z / row_tiles;
  const int row_tile = tid.z % row_tiles;
  const int out_row = row_tile * th + lid.z;
  const int out_col = gid.y;
  const int channel = gid.x;

  in += batch * params.in_strides[0];

  // Stage the input patch into threadgroup memory.
  {
    constexpr int n_threads = th * tw * tc;
    constexpr int thr_per_hw = tc / c_per_thr;
    constexpr int hw_per_group = n_threads / thr_per_hw;

    // Top-left input coordinate and first channel of this threadgroup's patch.
    const int patch_row0 = (row_tile * th) * str_h - params.pad[0];
    const int patch_col0 = (tid.y * tw) * str_w - params.pad[1];
    const int patch_chan0 = tid.x * tc;

    // Linear thread index, computed with 32 lanes per simdgroup.
    const int flat_tid = simd_gid * 32 + simd_lid;
    const int thr_c = flat_tid % thr_per_hw;
    const int thr_hw = flat_tid / thr_per_hw;
    const int chan_off = c_per_thr * thr_c;

    for (int hw = thr_hw; hw < span_hw; hw += hw_per_group) {
      const int h = hw / span_w;
      const int w = hw % span_w;

      const int ih = patch_row0 + h;
      const int iw = patch_col0 + w;

      const auto staged = &ins[h * span_w * TGC + w * TGC + chan_off];

      const bool inside =
          ih >= 0 && ih < params.iS[0] && iw >= 0 && iw < params.iS[1];

      if (!inside) {
        // Padding region: contribute zeros.
        MLX_MTL_PRAGMA_UNROLL
        for (int cc = 0; cc < c_per_thr; ++cc) {
          staged[cc] = T(0);
        }
        continue;
      }

      const auto in_load = in + ih * params.in_strides[1] +
          iw * params.in_strides[2] + patch_chan0;

      MLX_MTL_PRAGMA_UNROLL
      for (int cc = 0; cc < c_per_thr; ++cc) {
        staged[cc] = in_load[chan_off + cc];
      }
    }
  }

  // All staged values must be visible before any thread reads them.
  threadgroup_barrier(mem_flags::mem_threadgroup);

  wt += channel * params.wt_strides[0];

  // First staged element under this thread's receptive field.
  const auto window =
      &ins[lid.z * str_h * span_w * TGC + lid.y * str_w * TGC + lid.x];

  float acc = 0.;
  for (int kh = 0; kh < ker_h; ++kh) {
    for (int kw = 0; kw < ker_w; ++kw) {
      // With do_flip the filter is traversed mirrored along both axes.
      const int wt_h = do_flip ? ker_h - kh - 1 : kh;
      const int wt_w = do_flip ? ker_w - kw - 1 : kw;

      auto inv = window[kh * span_w * TGC + kw * TGC];
      auto wtv = wt[wt_h * ker_w + wt_w];
      acc += inv * wtv;
    }
  }
  threadgroup_barrier(mem_flags::mem_none);

  out += batch * params.out_strides[0] + out_row * params.out_strides[1] +
      out_col * params.out_strides[2];
  out[channel] = static_cast<T>(acc);
}

// Instantiates the kernel under the host-visible name
// "depthwise_conv_2d_<iname>".
#define instantiate_depthconv2d(iname, itype) \
  instantiate_kernel("depthwise_conv_2d_" #iname, depthwise_conv_2d, itype)

instantiate_depthconv2d(float32, float);
instantiate_depthconv2d(float16, half);
instantiate_depthconv2d(bfloat16, bfloat16_t);
|
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
/// Winograd kernels
|
/// Winograd kernels
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|||||||
51
mlx/fast.cpp
51
mlx/fast.cpp
@@ -567,9 +567,8 @@ array scaled_dot_product_attention(
|
|||||||
const array& keys,
|
const array& keys,
|
||||||
const array& values,
|
const array& values,
|
||||||
const float scale,
|
const float scale,
|
||||||
const std::string& mask_mode /* = "" */,
|
const std::variant<std::monostate, std::string, array>& mask /* = {}*/,
|
||||||
const std::vector<array>& mask_arrs /* = {} */,
|
StreamOrDevice s) {
|
||||||
StreamOrDevice s /* = {}*/) {
|
|
||||||
for (const auto& tensor : {queries, keys, values}) {
|
for (const auto& tensor : {queries, keys, values}) {
|
||||||
if (tensor.ndim() != 4) {
|
if (tensor.ndim() != 4) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
@@ -578,49 +577,29 @@ array scaled_dot_product_attention(
|
|||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Check valid mask
|
|
||||||
if (mask_mode != "" && mask_mode != "causal" && mask_mode != "array") {
|
|
||||||
std::ostringstream msg;
|
|
||||||
msg << "[scaled_dot_product_attention] Invalid mask_mode " << mask_mode
|
|
||||||
<< ". mask_mode must be 'causal', 'array' or ''.";
|
|
||||||
throw std::invalid_argument(msg.str());
|
|
||||||
}
|
|
||||||
|
|
||||||
bool do_causal = false;
|
bool do_causal = false;
|
||||||
bool has_mask = false;
|
bool has_mask = !std::holds_alternative<std::monostate>(mask);
|
||||||
bool has_arr_mask = false;
|
bool has_str_mask = has_mask && std::holds_alternative<std::string>(mask);
|
||||||
|
bool has_arr_mask = has_mask && std::holds_alternative<array>(mask);
|
||||||
bool has_bool_mask = false;
|
bool has_bool_mask = false;
|
||||||
|
|
||||||
if (mask_mode == "causal") {
|
if (has_str_mask) {
|
||||||
has_mask = true;
|
if (std::get<std::string>(mask) != "causal") {
|
||||||
do_causal = true;
|
|
||||||
|
|
||||||
if (!mask_arrs.empty()) {
|
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[scaled_dot_product_attention] Invalid mask_arrs for mask_mode "
|
msg << "[scaled_dot_product_attention] invalid mask option '"
|
||||||
<< "'casusal'. No array masks supported.";
|
<< std::get<std::string>(mask) << "'. Must be 'causal', or an array.";
|
||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
|
} else {
|
||||||
|
do_causal = true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (mask_mode == "array" || (mask_mode == "" && !mask_arrs.empty())) {
|
if (has_arr_mask && (std::get<array>(mask)).ndim() > 4) {
|
||||||
if (mask_arrs.size() != 1) {
|
|
||||||
std::ostringstream msg;
|
|
||||||
msg << "[scaled_dot_product_attention] Invalid mask_arrs for mask_mode "
|
|
||||||
<< "'" << mask_mode << "'. Only 1 mask array is supported, got "
|
|
||||||
<< mask_arrs.size() << "arrays.";
|
|
||||||
throw std::invalid_argument(msg.str());
|
|
||||||
}
|
|
||||||
|
|
||||||
has_mask = true;
|
|
||||||
has_arr_mask = true;
|
|
||||||
has_bool_mask = mask_arrs[0].dtype() == bool_;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (has_arr_mask && (mask_arrs[0]).ndim() > 4) {
|
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[scaled_dot_product_attention] the mask with shape "
|
msg << "[scaled_dot_product_attention] the mask with shape "
|
||||||
<< mask_arrs[0].shape() << " expected to have at most rank 4.";
|
<< (std::get<array>(mask)).shape()
|
||||||
|
<< " expected to have at most rank 4";
|
||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -757,7 +736,7 @@ array scaled_dot_product_attention(
|
|||||||
std::vector<array> inputs = {q, k, v};
|
std::vector<array> inputs = {q, k, v};
|
||||||
if (has_arr_mask) {
|
if (has_arr_mask) {
|
||||||
// Check type
|
// Check type
|
||||||
auto mask_arr = mask_arrs[0];
|
auto mask_arr = std::get<array>(mask);
|
||||||
has_bool_mask = mask_arr.dtype() == bool_;
|
has_bool_mask = mask_arr.dtype() == bool_;
|
||||||
if (promote_types(mask_arr.dtype(), final_type) != final_type) {
|
if (promote_types(mask_arr.dtype(), final_type) != final_type) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
|
|||||||
@@ -48,8 +48,7 @@ array scaled_dot_product_attention(
|
|||||||
const array& keys,
|
const array& keys,
|
||||||
const array& values,
|
const array& values,
|
||||||
const float scale,
|
const float scale,
|
||||||
const std::string& mask_mode = "",
|
const std::variant<std::monostate, std::string, array>& mask = {},
|
||||||
const std::vector<array>& mask_arrs = {},
|
|
||||||
StreamOrDevice s = {});
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
std::tuple<array, array, array> affine_quantize(
|
std::tuple<array, array, array> affine_quantize(
|
||||||
|
|||||||
@@ -219,7 +219,7 @@ std::unordered_map<std::string, array> load_arrays(gguf_ctx* ctx) {
|
|||||||
|
|
||||||
while (gguf_get_tensor(ctx, &tensor)) {
|
while (gguf_get_tensor(ctx, &tensor)) {
|
||||||
if (tensor.type == GGUF_TYPE_Q4_0 || tensor.type == GGUF_TYPE_Q4_1 ||
|
if (tensor.type == GGUF_TYPE_Q4_0 || tensor.type == GGUF_TYPE_Q4_1 ||
|
||||||
tensor.type == GGUF_TYPE_Q8_0) {
|
tensor.type == GGUF_TYPE_Q4_K || tensor.type == GGUF_TYPE_Q8_0) {
|
||||||
gguf_load_quantized(array_map, tensor);
|
gguf_load_quantized(array_map, tensor);
|
||||||
} else {
|
} else {
|
||||||
std::string name(tensor.name, tensor.namelen);
|
std::string name(tensor.name, tensor.namelen);
|
||||||
|
|||||||
@@ -70,6 +70,65 @@ void extract_q4_1_data(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Extracts (weight, scales, biases) from Q4_K tensors.
|
||||||
|
// Data layout is:
|
||||||
|
// * |FP16 s_of_scales | +
|
||||||
|
// * |FP16 s_of_mins | +
|
||||||
|
// * |16 6 bit integers d,m pairs, one per sub-block of 32 ele | +
|
||||||
|
// * |256 x 4bit weights|
|
||||||
|
void extract_q4_k_data(
|
||||||
|
const gguf_tensor& tensor,
|
||||||
|
array& weights_arr,
|
||||||
|
array& scales_arr,
|
||||||
|
array& biases_arr) {
|
||||||
|
auto data = static_cast<uint8_t*>(tensor.weights_data);
|
||||||
|
auto weights = weights_arr.data<int8_t>();
|
||||||
|
auto scales = scales_arr.data<float16_t>();
|
||||||
|
auto biases = biases_arr.data<float16_t>();
|
||||||
|
for (int64_t g = 0; g < scales_arr.size() / 8; ++g) {
|
||||||
|
auto scales_scale = *((float16_t*)data);
|
||||||
|
auto mins_scale = *((float16_t*)data + 1);
|
||||||
|
data += 4;
|
||||||
|
|
||||||
|
/* Scale scales/mins. */
|
||||||
|
for (int j = 0; j < 8; j++) {
|
||||||
|
uint8_t d, m;
|
||||||
|
if (j < 4) {
|
||||||
|
d = data[j] & 63;
|
||||||
|
m = data[j + 4] & 63;
|
||||||
|
} else {
|
||||||
|
d = (data[j + 4] & 0xF) | ((data[j - 4] >> 6) << 4);
|
||||||
|
m = (data[j + 4] >> 4) | ((data[j - 0] >> 6) << 4);
|
||||||
|
}
|
||||||
|
scales[g * 8 + j] = d * scales_scale;
|
||||||
|
biases[g * 8 + j] = -(m * mins_scale);
|
||||||
|
}
|
||||||
|
data += 12;
|
||||||
|
for (int i = 0; i < 8; i += 2) {
|
||||||
|
std::fill_n(weights, 32, 0);
|
||||||
|
|
||||||
|
// First 32 weights are in the lower bits
|
||||||
|
for (int j = 0; j < 32; ++j) {
|
||||||
|
uint8_t x = (data[j] & 0x0F);
|
||||||
|
if (j % 2 != 0) {
|
||||||
|
x <<= 4;
|
||||||
|
}
|
||||||
|
weights[j / 2] += x;
|
||||||
|
}
|
||||||
|
// Last 32 weights are in the higher bits
|
||||||
|
for (int j = 0; j < 32; ++j) {
|
||||||
|
uint8_t x = (data[j] >> 4);
|
||||||
|
if (j % 2 != 0) {
|
||||||
|
x <<= 4;
|
||||||
|
}
|
||||||
|
weights[16 + j / 2] += x;
|
||||||
|
}
|
||||||
|
weights += 32;
|
||||||
|
data += 32;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Extracts (weight, scales, biases) from Q8_0 tensors.
|
// Extracts (weight, scales, biases) from Q8_0 tensors.
|
||||||
// Data layout is: |16 bit scale|32 x 8bit weights|.
|
// Data layout is: |16 bit scale|32 x 8bit weights|.
|
||||||
void extract_q8_0_data(
|
void extract_q8_0_data(
|
||||||
@@ -100,11 +159,12 @@ void extract_q8_0_data(
|
|||||||
void gguf_load_quantized(
|
void gguf_load_quantized(
|
||||||
std::unordered_map<std::string, array>& a,
|
std::unordered_map<std::string, array>& a,
|
||||||
const gguf_tensor& tensor) {
|
const gguf_tensor& tensor) {
|
||||||
uint64_t weights_per_byte;
|
int bits;
|
||||||
if (tensor.type == GGUF_TYPE_Q4_0 || tensor.type == GGUF_TYPE_Q4_1) {
|
if (tensor.type == GGUF_TYPE_Q4_0 || tensor.type == GGUF_TYPE_Q4_1 ||
|
||||||
weights_per_byte = 2;
|
tensor.type == GGUF_TYPE_Q4_K) {
|
||||||
|
bits = 4;
|
||||||
} else { // tensor.type == GGUF_TYPE_Q8_0
|
} else { // tensor.type == GGUF_TYPE_Q8_0
|
||||||
weights_per_byte = 1;
|
bits = 8;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string name(tensor.name, tensor.namelen);
|
std::string name(tensor.name, tensor.namelen);
|
||||||
@@ -119,7 +179,7 @@ void gguf_load_quantized(
|
|||||||
}
|
}
|
||||||
|
|
||||||
auto weights_shape = shape;
|
auto weights_shape = shape;
|
||||||
weights_shape.back() /= (weights_per_byte * 4);
|
weights_shape.back() = weights_shape.back() * bits / 32;
|
||||||
auto w_nbytes = uint32.size() *
|
auto w_nbytes = uint32.size() *
|
||||||
std::accumulate(weights_shape.begin(),
|
std::accumulate(weights_shape.begin(),
|
||||||
weights_shape.end(),
|
weights_shape.end(),
|
||||||
@@ -139,6 +199,8 @@ void gguf_load_quantized(
|
|||||||
extract_q4_0_data(tensor, weights, scales, biases);
|
extract_q4_0_data(tensor, weights, scales, biases);
|
||||||
} else if (tensor.type == GGUF_TYPE_Q4_1) {
|
} else if (tensor.type == GGUF_TYPE_Q4_1) {
|
||||||
extract_q4_1_data(tensor, weights, scales, biases);
|
extract_q4_1_data(tensor, weights, scales, biases);
|
||||||
|
} else if (tensor.type == GGUF_TYPE_Q4_K) {
|
||||||
|
extract_q4_k_data(tensor, weights, scales, biases);
|
||||||
} else if (tensor.type == GGUF_TYPE_Q8_0) {
|
} else if (tensor.type == GGUF_TYPE_Q8_0) {
|
||||||
extract_q8_0_data(tensor, weights, scales, biases);
|
extract_q8_0_data(tensor, weights, scales, biases);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,7 +4,7 @@
|
|||||||
|
|
||||||
#define MLX_VERSION_MAJOR 0
|
#define MLX_VERSION_MAJOR 0
|
||||||
#define MLX_VERSION_MINOR 24
|
#define MLX_VERSION_MINOR 24
|
||||||
#define MLX_VERSION_PATCH 2
|
#define MLX_VERSION_PATCH 1
|
||||||
#define MLX_VERSION_NUMERIC \
|
#define MLX_VERSION_NUMERIC \
|
||||||
(100000 * MLX_VERSION_MAJOR + 1000 * MLX_VERSION_MINOR + MLX_VERSION_PATCH)
|
(100000 * MLX_VERSION_MAJOR + 1000 * MLX_VERSION_MINOR + MLX_VERSION_PATCH)
|
||||||
|
|
||||||
|
|||||||
@@ -124,39 +124,7 @@ void init_fast(nb::module_& parent_module) {
|
|||||||
|
|
||||||
m.def(
|
m.def(
|
||||||
"scaled_dot_product_attention",
|
"scaled_dot_product_attention",
|
||||||
[](const mx::array& queries,
|
&mx::fast::scaled_dot_product_attention,
|
||||||
const mx::array& keys,
|
|
||||||
const mx::array& values,
|
|
||||||
const float scale,
|
|
||||||
const std::variant<std::monostate, std::string, mx::array>& mask,
|
|
||||||
mx::StreamOrDevice s) {
|
|
||||||
bool has_mask = !std::holds_alternative<std::monostate>(mask);
|
|
||||||
bool has_str_mask =
|
|
||||||
has_mask && std::holds_alternative<std::string>(mask);
|
|
||||||
bool has_arr_mask = has_mask && std::holds_alternative<mx::array>(mask);
|
|
||||||
|
|
||||||
if (has_mask) {
|
|
||||||
if (has_str_mask) {
|
|
||||||
auto mask_str = std::get<std::string>(mask);
|
|
||||||
if (mask_str != "causal") {
|
|
||||||
std::ostringstream msg;
|
|
||||||
msg << "[scaled_dot_product_attention] invalid mask option '"
|
|
||||||
<< mask_str << "'. Must be 'causal', or an array.";
|
|
||||||
throw std::invalid_argument(msg.str());
|
|
||||||
}
|
|
||||||
return mx::fast::scaled_dot_product_attention(
|
|
||||||
queries, keys, values, scale, mask_str, {}, s);
|
|
||||||
} else {
|
|
||||||
auto mask_arr = std::get<mx::array>(mask);
|
|
||||||
return mx::fast::scaled_dot_product_attention(
|
|
||||||
queries, keys, values, scale, "", {mask_arr}, s);
|
|
||||||
}
|
|
||||||
|
|
||||||
} else {
|
|
||||||
return mx::fast::scaled_dot_product_attention(
|
|
||||||
queries, keys, values, scale, "", {}, s);
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"q"_a,
|
"q"_a,
|
||||||
"k"_a,
|
"k"_a,
|
||||||
"v"_a,
|
"v"_a,
|
||||||
|
|||||||
@@ -707,11 +707,9 @@ class TestConv(mlx_tests.MLXTestCase):
|
|||||||
flip=flip,
|
flip=flip,
|
||||||
np_dtype=np_dtype,
|
np_dtype=np_dtype,
|
||||||
):
|
):
|
||||||
np.random.seed(0)
|
|
||||||
scale = 1.0 / math.sqrt(np.prod(wt_shape[1:]))
|
scale = 1.0 / math.sqrt(np.prod(wt_shape[1:]))
|
||||||
scale = min(0.3, scale)
|
in_np = np.random.normal(0.0, scale, in_shape).astype(np_dtype)
|
||||||
in_np = np.random.normal(0, scale, in_shape).astype(np_dtype)
|
wt_np = np.random.normal(0.0, scale, wt_shape).astype(np_dtype)
|
||||||
wt_np = np.random.normal(0, scale, wt_shape).astype(np_dtype)
|
|
||||||
|
|
||||||
in_mx, wt_mx = map(mx.array, (in_np, wt_np))
|
in_mx, wt_mx = map(mx.array, (in_np, wt_np))
|
||||||
|
|
||||||
@@ -1052,42 +1050,6 @@ class TestConv(mlx_tests.MLXTestCase):
|
|||||||
y2 = mx.conv2d(x, w, (1, 1), (1, 1), (1, 1), 1)
|
y2 = mx.conv2d(x, w, (1, 1), (1, 1), (1, 1), 1)
|
||||||
self.assertTrue(mx.allclose(y1, y2))
|
self.assertTrue(mx.allclose(y1, y2))
|
||||||
|
|
||||||
@unittest.skipIf(not has_torch, "requires Torch")
def test_torch_conv_depthwise(self):
    """Run the general conv check on depthwise configurations.

    Every case uses ``groups == C``, so each group sees one input channel.
    Each case is run for every dtype under test, with and without filter
    flipping, through ``__conv_general_test``.
    """

    # fmt: off
    cases = (
        # N,  H,  W,  C  kH, kW,  O, strides, padding, groups
        ( 2, 16, 16, 32,  1,  1, 32, (2, 2),  (1, 1),  32),
        ( 1, 16, 16, 32,  3,  3, 32, (2, 2),  (1, 1),  32),
        ( 1, 32, 32, 32,  7,  7, 32, (1, 1),  (3, 3),  32),
        ( 3, 32, 32, 32,  5,  5, 32, (1, 2),  (0, 0),  32),
        ( 1, 32, 32, 32,  7,  7, 32, (2, 1),  (1, 3),  32),
    )
    # fmt: on

    # Half precision is only exercised when the default device is the GPU.
    dtypes = [np.float32]
    if mx.default_device() == mx.gpu:
        dtypes += [np.float16]

    for N, H, W, C, kH, kW, O, strides, padding, groups in cases:
        in_shape = (N, H, W, C)
        # Weights are laid out as (O, kH, kW, C // groups).
        wt_shape = (O, kH, kW, C // groups)

        for dtype in dtypes:
            # Looser tolerance for anything that is not float32.
            tolerance = 2e-5 if dtype == np.float32 else 5e-4

            for flip in [False, True]:
                self.__conv_general_test(
                    in_shape,
                    wt_shape,
                    strides,
                    padding,
                    kernel_dilation=1,
                    input_dilation=1,
                    groups=groups,
                    flip=flip,
                    np_dtype=dtype,
                    atol=tolerance,
                )
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user