Mirror of https://github.com/ml-explore/mlx.git (synced 2025-06-24 09:21:16 +08:00)

Add gemv masked to JIT plus some fixes (#1310)

* add gemv masked to JIT plus some fixes
* some cleanup
* add utils
* fix
* fix 2
* more cleaning
* fix
* remove unused mps matmul support
* one more nit
* revert

This commit is contained in:
parent 635ccd9e25
commit 30bbea2f08
@@ -486,9 +486,8 @@ below.
   std::ostringstream kname;
   kname << "axpby_" << "general_" << type_to_name(out);
 
-  // Make sure the metal library is available and look for it
-  // in the same folder as this executable if needed
-  d.register_library("mlx_ext", metal::get_colocated_mtllib_path);
+  // Make sure the metal library is available
+  d.register_library("mlx_ext");
 
   // Make a kernel from this metal library
   auto kernel = d.get_kernel(kname.str(), "mlx_ext");
 
@@ -249,9 +249,8 @@ void Axpby::eval_gpu(
   kname << (contiguous_kernel ? "contiguous_" : "general_");
   kname << type_to_name(out);
 
-  // Make sure the metal library is available and look for it
-  // in the same folder as this executable if needed
-  d.register_library("mlx_ext", metal::get_colocated_mtllib_path);
+  // Make sure the metal library is available
+  d.register_library("mlx_ext");
 
   // Make a kernel from this metal library
   auto kernel = d.get_kernel(kname.str(), "mlx_ext");
 
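With the path-function argument gone, extension code only names its library; the colocated-metallib lookup now happens inside Device::register_library (see device.cpp below). A minimal sketch of the resulting call-site pattern, assuming the same "mlx_ext" library as in the hunks above:

  auto& d = mlx::core::metal::device(s.device);

  // The device resolves "mlx_ext" to a colocated "mlx_ext.metallib" on its
  // own; callers no longer pass metal::get_colocated_mtllib_path.
  d.register_library("mlx_ext");

  auto kernel = d.get_kernel(kname.str(), "mlx_ext");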
@@ -114,6 +114,7 @@ if (MLX_METAL_JIT)
   kernels/steel/conv/loaders/loader_general.h
 )
 make_jit_source(quantized)
+make_jit_source(gemv_masked)
 else()
 target_sources(
   mlx
@@ -149,6 +150,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/ternary.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/unary.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
 )
 
 if (NOT MLX_METAL_PATH)
@@ -14,7 +14,6 @@
 #include "mlx/backend/metal/device.h"
 #include "mlx/backend/metal/metal.h"
 #include "mlx/backend/metal/metal_impl.h"
-#include "mlx/backend/metal/mps/gemm.h"
 #include "mlx/backend/metal/utils.h"
 
 namespace fs = std::filesystem;
@@ -39,6 +38,20 @@ constexpr auto get_metal_version() {
 #endif
 }
 
+std::string get_colocated_mtllib_path(const std::string& lib_name) {
+  Dl_info info;
+  std::string mtllib_path;
+  std::string lib_ext = lib_name + ".metallib";
+
+  int success = dladdr((void*)get_colocated_mtllib_path, &info);
+  if (success) {
+    auto mtllib = fs::path(info.dli_fname).remove_filename() / lib_ext;
+    mtllib_path = mtllib.c_str();
+  }
+
+  return mtllib_path;
+}
+
 auto load_device() {
   auto devices = MTL::CopyAllDevices();
   auto device = static_cast<MTL::Device*>(devices->object(0))
@@ -126,6 +139,49 @@ MTL::Library* load_library(
 
 } // namespace
 
+CommandEncoder::CommandEncoder(MTL::CommandBuffer* cbuf) : cbuf(cbuf) {
+  enc = cbuf->computeCommandEncoder(MTL::DispatchTypeConcurrent);
+  enc->retain();
+}
+
+CommandEncoder::~CommandEncoder() {
+  enc->endEncoding();
+  enc->release();
+}
+
+void CommandEncoder::set_input_array(
+    const array& a,
+    int idx,
+    int64_t offset /* = 0 */) {
+  auto r_buf = static_cast<MTL::Resource*>(const_cast<void*>(a.buffer().ptr()));
+  if (auto it = outputs.find(r_buf); it != outputs.end()) {
+    // Insert a barrier
+    enc->memoryBarrier(&r_buf, 1);
+
+    // Remove the output
+    outputs.erase(it);
+  }
+  auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
+  auto base_offset = a.data<char>() -
+      static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
+  base_offset += offset;
+  enc->setBuffer(a_buf, base_offset, idx);
+}
+
+void CommandEncoder::set_output_array(
+    array& a,
+    int idx,
+    int64_t offset /* = 0 */) {
+  // Add barriers before adding the output to the output set
+  set_input_array(a, idx, offset);
+  auto buf = static_cast<MTL::Resource*>(a.buffer().ptr());
+  if (concurrent) {
+    concurrent_outputs.insert(buf);
+  } else {
+    outputs.insert(buf);
+  }
+}
+
 void CommandEncoder::dispatchThreadgroups(
     MTL::Size grid_dims,
     MTL::Size group_dims) {
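The set_input_array/set_output_array pair above closes write-then-read hazards within a single concurrent encoder: set_output_array records the destination buffer in `outputs`, and a later set_input_array that finds the same MTL::Resource there encodes a memory barrier before binding it. A minimal sketch of the hazard this guards against (kernels and sizes illustrative):

  // Both dispatches run under MTL::DispatchTypeConcurrent, so without the
  // barrier the second kernel could read `tmp` before the first wrote it.
  enc.set_output_array(tmp, 0);      // tmp recorded in `outputs`
  enc.dispatchThreads(grid, group);  // kernel A writes tmp

  enc.set_input_array(tmp, 0);       // tmp found in `outputs` -> memoryBarrier
  enc.set_output_array(out, 1);
  enc.dispatchThreads(grid, group);  // kernel B now safely reads tmp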
@@ -255,13 +311,9 @@ void Device::register_library(
   }
 }
 
-void Device::register_library(
-    const std::string& lib_name,
-    const std::function<std::string(const std::string&)>& lib_path_func) {
+void Device::register_library(const std::string& lib_name) {
   if (auto it = library_map_.find(lib_name); it == library_map_.end()) {
-    std::string new_lib_path = lib_path_func(lib_name);
-    auto new_lib = load_library(device_, lib_name, new_lib_path.c_str());
-    library_map_.insert({lib_name, new_lib});
+    register_library(lib_name, get_colocated_mtllib_path(lib_name));
   }
 }
 
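The name-only overload is now a thin wrapper that computes the colocated path and forwards to the (name, path) overload, leaving a single load path. The resulting call chain, with an illustrative install location:

  d.register_library("mlx_ext");
  //  -> register_library("mlx_ext", get_colocated_mtllib_path("mlx_ext"))
  //  -> load_library(device_, "mlx_ext", "<dir of loaded binary>/mlx_ext.metallib")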
@@ -271,7 +323,7 @@ MTL::Library* Device::get_library_cache_(const std::string& lib_name) {
   if (auto it = library_map_.find(lib_name); it != library_map_.end()) {
     mtl_lib = it->second;
   } else { // Look for metallib alongside library
-    register_library(lib_name);
+    register_library(lib_name, get_colocated_mtllib_path(lib_name));
     mtl_lib = library_map_[lib_name];
   }
 
@@ -9,38 +9,16 @@
 #include <unordered_map>
 #include <unordered_set>
 
-#include <dlfcn.h>
-#include <filesystem>
-
 #include "mlx/array.h"
 #include "mlx/device.h"
 
-namespace fs = std::filesystem;
-
 namespace mlx::core::metal {
 
-inline std::string get_colocated_mtllib_path(const std::string& lib_name) {
-  Dl_info info;
-  std::string mtllib_path;
-  std::string lib_ext = lib_name + ".metallib";
-
-  int success = dladdr((void*)get_colocated_mtllib_path, &info);
-  if (success) {
-    auto mtllib = fs::path(info.dli_fname).remove_filename() / lib_ext;
-    mtllib_path = mtllib.c_str();
-  }
-
-  return mtllib_path;
-}
-
 using MTLFCList =
     std::vector<std::tuple<const void*, MTL::DataType, NS::UInteger>>;
 
 struct CommandEncoder {
-  CommandEncoder(MTL::CommandBuffer* cbuf) : cbuf(cbuf) {
-    enc = cbuf->computeCommandEncoder(MTL::DispatchTypeConcurrent);
-    enc->retain();
-  };
+  CommandEncoder(MTL::CommandBuffer* cbuf);
   CommandEncoder(const CommandEncoder&) = delete;
   CommandEncoder& operator=(const CommandEncoder&) = delete;
 
@@ -63,34 +41,8 @@ struct CommandEncoder {
     return enc;
   }
 
-  void set_input_array(const array& a, int idx, int64_t offset = 0) {
-    auto r_buf =
-        static_cast<MTL::Resource*>(const_cast<void*>(a.buffer().ptr()));
-    if (auto it = outputs.find(r_buf); it != outputs.end()) {
-      // Insert a barrier
-      enc->memoryBarrier(&r_buf, 1);
-
-      // Remove the output
-      outputs.erase(it);
-    }
-    auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
-    auto base_offset = a.data<char>() -
-        static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
-    base_offset += offset;
-    enc->setBuffer(a_buf, base_offset, idx);
-  }
-
-  void set_output_array(array& a, int idx, int64_t offset = 0) {
-    // Add barriers before adding the output to the output set
-    set_input_array(a, idx, offset);
-    auto buf = static_cast<MTL::Resource*>(a.buffer().ptr());
-    if (concurrent) {
-      concurrent_outputs.insert(buf);
-    } else {
-      outputs.insert(buf);
-    }
-  }
-
+  void set_input_array(const array& a, int idx, int64_t offset = 0);
+  void set_output_array(array& a, int idx, int64_t offset = 0);
   void dispatchThreadgroups(MTL::Size grid_dims, MTL::Size group_dims);
   void dispatchThreads(MTL::Size grid_dims, MTL::Size group_dims);
 
@@ -98,10 +50,7 @@ struct CommandEncoder {
     return ConcurrentContext(*this);
   }
 
-  ~CommandEncoder() {
-    enc->endEncoding();
-    enc->release();
-  }
+  ~CommandEncoder();
 
  private:
  void maybe_split();
@@ -136,10 +85,8 @@ class Device {
   void register_library(
       const std::string& lib_name,
       const std::string& lib_path);
-  void register_library(
-      const std::string& lib_name,
-      const std::function<std::string(const std::string&)>& lib_path_func =
-          get_colocated_mtllib_path);
+  void register_library(const std::string& lib_name);
 
   MTL::Library* get_library(const std::string& name);
 
@@ -1,4 +1,4 @@
-// Copyright © 2023 Apple Inc.
+// Copyright © 2024 Apple Inc.
 #include <cassert>
 #include <complex>
 #include <map>
@@ -12,8 +12,6 @@
 #include "mlx/backend/metal/slicing.h"
 #include "mlx/backend/metal/unary.h"
 #include "mlx/backend/metal/utils.h"
-#include "mlx/mlx.h"
-#include "mlx/primitives.h"
 #include "mlx/utils.h"
 
 namespace mlx::core {
@@ -786,10 +784,9 @@ void nd_fft_op(
     fft_op(in_arr, out_arr, axis, inverse, step_real, inplace, s);
   }
 
-  std::vector<array> copies = {temp1, temp2};
   auto& d = metal::device(s.device);
   d.get_command_buffer(s.index)->addCompletedHandler(
-      [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
+      [temp_arrs](MTL::CommandBuffer*) mutable { temp_arrs.clear(); });
 }
 
 void FFT::eval_gpu(const std::vector<array>& inputs, array& out) {
mlx/backend/metal/jit/gemv_masked.h (new file, 25 lines)
@@ -0,0 +1,25 @@
// Copyright © 2024 Apple Inc.

constexpr std::string_view gemv_masked_kernel = R"(
template [[host_name("{name}")]] [[kernel]] void
gemv_{trans}masked<{itype}, {outm_t}, {opm_t}, {bm}, {bn}, {sm}, {sn}, {tm}, {tn}, {nc}>(
    const device {itype}* mat [[buffer(0)]],
    const device {itype}* in_vec [[buffer(1)]],
    device {itype}* out_vec [[buffer(3)]],
    const constant int& in_vec_size [[buffer(4)]],
    const constant int& out_vec_size [[buffer(5)]],
    const constant int& matrix_ld [[buffer(6)]],
    const constant int& batch_ndim [[buffer(9)]],
    const constant int* batch_shape [[buffer(10)]],
    const constant size_t* vector_batch_stride [[buffer(11)]],
    const constant size_t* matrix_batch_stride [[buffer(12)]],
    const device {outm_t}* out_mask [[buffer(20)]],
    const device {opm_t}* mat_mask [[buffer(21)]],
    const device {opm_t}* vec_mask [[buffer(22)]],
    const constant int* mask_strides [[buffer(23)]],
    const constant size_t* mask_batch_strides [[buffer(24)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]);
)";
@@ -33,5 +33,6 @@ const char* steel_gemm_splitk();
 const char* conv();
 const char* steel_conv();
 const char* steel_conv_general();
+const char* gemv_masked();
 
 } // namespace mlx::core::metal
 
@@ -4,6 +4,7 @@
 #include "mlx/backend/common/compiled.h"
 #include "mlx/backend/metal/jit/arange.h"
 #include "mlx/backend/metal/jit/copy.h"
+#include "mlx/backend/metal/jit/gemv_masked.h"
 #include "mlx/backend/metal/jit/includes.h"
 #include "mlx/backend/metal/jit/reduce.h"
 #include "mlx/backend/metal/jit/scan.h"
@@ -50,10 +51,12 @@ MTL::ComputePipelineState* get_unary_kernel(
     std::ostringstream kernel_source;
     auto u_def = get_template_definition(
         "v" + lib_name, "unary_v", get_type_string(out_type), op);
+    auto u2_def = get_template_definition(
+        "v2" + lib_name, "unary_v2", get_type_string(out_type), op);
     auto g_def = get_template_definition(
         "g" + lib_name, "unary_g", get_type_string(out_type), op);
     kernel_source << metal::utils() << metal::unary_ops() << metal::unary()
-                  << u_def << g_def;
+                  << u_def << u2_def << g_def;
     lib = d.get_library(lib_name, kernel_source.str());
   }
   return d.get_kernel(kernel_name, lib);
@@ -70,6 +73,9 @@ void add_binary_kernels(
     {"vs", "binary_vs"},
     {"sv", "binary_sv"},
     {"vv", "binary_vv"},
+    {"vs2", "binary_vs2"},
+    {"sv2", "binary_sv2"},
+    {"vv2", "binary_vv2"},
     {"g1", "binary_g_nd1"},
     {"g2", "binary_g_nd2"},
     {"g3", "binary_g_nd3"},
@@ -146,6 +152,7 @@ MTL::ComputePipelineState* get_ternary_kernel(
     std::ostringstream kernel_source;
     const std::map<std::string, std::string> kernel_types = {
         {"v", "ternary_v"},
+        {"v2", "ternary_v2"},
        {"g", "ternary_g"},
        {"g1", "ternary_g_nd1"},
        {"g2", "ternary_g_nd2"},
@@ -496,6 +503,49 @@ MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
   return d.get_kernel(kernel_name, lib);
 }
 
+MTL::ComputePipelineState* get_gemv_masked_kernel(
+    metal::Device& d,
+    const std::string& kernel_name,
+    const array& out,
+    const std::optional<array>& mask_out,
+    const std::optional<array>& mask_op,
+    bool transpose_mat,
+    int bm,
+    int bn,
+    int sm,
+    int sn,
+    int tm,
+    int tn,
+    bool contiguous) {
+  const auto& lib_name = kernel_name;
+  auto lib = d.get_library(lib_name);
+  if (lib == nullptr) {
+    std::ostringstream kernel_source;
+    auto out_mask_type = mask_out.has_value()
+        ? get_type_string((*mask_out).dtype())
+        : "nomask_t";
+    auto op_mask_type =
+        mask_op.has_value() ? get_type_string((*mask_op).dtype()) : "nomask_t";
+    kernel_source << metal::utils() << metal::gemv_masked()
+                  << fmt::format(
+                         gemv_masked_kernel,
+                         "name"_a = lib_name,
+                         "itype"_a = get_type_string(out.dtype()),
+                         "outm_t"_a = out_mask_type,
+                         "opm_t"_a = op_mask_type,
+                         "bm"_a = bm,
+                         "bn"_a = bn,
+                         "sm"_a = sm,
+                         "sn"_a = sn,
+                         "tm"_a = tm,
+                         "tn"_a = tn,
+                         "trans"_a = transpose_mat ? "t_" : "",
+                         "nc"_a = contiguous ? "0" : "1");
+    lib = d.get_library(lib_name, kernel_source.str());
+  }
+  return d.get_kernel(kernel_name, lib);
+}
+
 MTL::ComputePipelineState* get_steel_conv_kernel(
     metal::Device& d,
     const std::string& kernel_name,
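For a concrete picture of what the JIT emits: fmt::format fills the gemv_masked_kernel string from jit/gemv_masked.h with the arguments above, yielding one explicit instantiation per pipeline. A plausible expansion for a non-transposed float32 kernel with a boolean operand mask and no output mask (the kernel_name string is chosen by the caller and not shown in this diff; tile sizes illustrative):

  template [[host_name("<kernel_name>")]] [[kernel]] void
  gemv_masked<float, nomask_t, bool, 2, 2, 4, 8, 4, 4, 0>(...);

Here nomask_t substitutes {outm_t} because mask_out is empty, {trans} is empty because transpose_mat is false, and {nc} = 0 selects the contiguous batch path (kDoNCBatch = false in the kernel header below).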
@@ -151,6 +151,21 @@ MTL::ComputePipelineState* get_steel_conv_kernel(
     int n_channel_specialization,
     bool small_filter);
 
+MTL::ComputePipelineState* get_gemv_masked_kernel(
+    metal::Device& d,
+    const std::string& kernel_name,
+    const array& out,
+    const std::optional<array>& mask_out,
+    const std::optional<array>& mask_op,
+    bool transpose_mat,
+    int bm,
+    int bn,
+    int sm,
+    int sn,
+    int tm,
+    int tn,
+    bool contiguous);
+
 MTL::ComputePipelineState* get_steel_conv_general_kernel(
     metal::Device& d,
     const std::string& kernel_name,
@@ -38,7 +38,6 @@ endfunction(build_kernel)
 build_kernel(arg_reduce)
 build_kernel(conv steel/conv/params.h)
 build_kernel(gemv steel/utils.h)
-build_kernel(gemv_masked steel/utils.h)
 build_kernel(layer_norm)
 build_kernel(random)
 build_kernel(rms_norm)
@@ -121,6 +120,7 @@ build_kernel(
   steel/gemm/kernels/steel_gemm_splitk
   ${STEEL_HEADERS}
 )
+build_kernel(gemv_masked steel/utils.h)
 endif()
 
mlx/backend/metal/kernels/gemv_masked.h (new file, 819 lines)
@@ -0,0 +1,819 @@
// Copyright © 2023-2024 Apple Inc.

#include "mlx/backend/metal/kernels/steel/utils.h"

using namespace metal;

#define MLX_MTL_CONST static constant constexpr const
#define MLX_MTL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)")

struct _NoMask {
  char x;

  constexpr METAL_FUNC operator bool() {
    return true;
  }
  constexpr METAL_FUNC operator bool() const threadgroup {
    return true;
  }
  constexpr METAL_FUNC operator bool() const device {
    return true;
  }
  constexpr METAL_FUNC operator bool() const constant {
    return true;
  }
};

typedef struct _NoMask nomask_t;

template <typename OutT, typename InT = OutT>
struct ScaleOp {
  OutT scale;

  METAL_FUNC OutT apply(InT x) const {
    return static_cast<OutT>(x) * scale;
  }
};

template <
    typename T,
    typename out_mask_t,
    typename op_mask_t,
    const int BM, /* Threadgroup rows (in simdgroups) */
    const int BN, /* Threadgroup cols (in simdgroups) */
    const int SM, /* Simdgroup rows (in threads) */
    const int SN, /* Simdgroup cols (in threads) */
    const int TM, /* Thread rows (in elements) */
    const int TN> /* Thread cols (in elements) */
struct GEMVKernel {
  MLX_MTL_CONST int threadsM = BM * SM;
  MLX_MTL_CONST int threadsN = BN * SN;

  MLX_MTL_CONST int blockM = threadsM * TM;
  MLX_MTL_CONST int blockN = threadsN * TN;

  static_assert(SM * SN == 32, "simdgroup can only have 32 threads");

  static_assert(
      SN == 8 || SN == 16 || SN == 32,
      "gemv block must have a width of 8, 16, or 32");

  static_assert(blockN >= blockM, "Masked gemv must have blockN >= blockM");

  MLX_MTL_CONST bool has_operand_mask = !metal::is_same_v<op_mask_t, nomask_t>;
  MLX_MTL_CONST bool has_output_mask = !metal::is_same_v<out_mask_t, nomask_t>;

  MLX_MTL_CONST bool has_mul_operand_mask =
      has_operand_mask && !metal::is_same_v<op_mask_t, bool>;
  MLX_MTL_CONST bool has_mul_output_mask =
      has_output_mask && !metal::is_same_v<out_mask_t, bool>;

  // - The matrix of size (M = out_vec_size, K = in_vec_size) is divided up
  //   into blocks of (blockM, blockN) divided among threadgroups
  // - Every thread works on a block of (TM, TN)
  // - We assume each threadgroup has (threadsN, threadsM, 1) threads
  //
  // 1. A thread loads TN elements each from mat along TM rows
  //    and the corresponding scalar from the vector
  // 2. The thread then multiplies and adds to accumulate its local result for
  //    the block
  // 3. At the end, each thread has accumulated results over all blocks across
  //    the rows. These are then summed up across the threadgroup
  // 4. Each threadgroup writes its accumulated blockM outputs
  //
  // Edge case handling:
  // - The threadgroup with the largest tid has blocks that exceed the matrix
  //   * The blocks that start outside the matrix are never read (thread results
  //     remain zero)
  //   * The last thread that partially overlaps with the matrix is shifted
  //     inwards such that the thread block fits exactly in the matrix

  MLX_MTL_CONST short tgp_mem_size = BN > 1 ? BN * (blockM + TM) : 0;
  MLX_MTL_CONST bool needs_tgp_reduction = BN > 1;

  static METAL_FUNC void
  load_unsafe(const device T* src, thread T dst[TN], const int src_offset = 0) {
    MLX_MTL_PRAGMA_UNROLL
    for (int tn = 0; tn < TN; tn++) {
      dst[tn] = src[src_offset + tn];
    }
  }

  static METAL_FUNC void load_safe(
      const device T* src,
      thread T dst[TN],
      const int src_offset = 0,
      const int src_size = TN) {
    if (src_offset + TN <= src_size) {
      MLX_MTL_PRAGMA_UNROLL
      for (int tn = 0; tn < TN; tn++) {
        dst[tn] = src[src_offset + tn];
      }
    } else { // Edgecase
      MLX_MTL_PRAGMA_UNROLL
      for (int tn = 0; tn < TN; tn++) {
        dst[tn] = src_offset + tn < src_size ? src[src_offset + tn] : 0;
      }
    }
  }

  static METAL_FUNC void run(
      const device T* mat [[buffer(0)]],
      const device T* in_vec [[buffer(1)]],
      device T* out_vec [[buffer(3)]],
      const constant int& in_vec_size [[buffer(4)]],
      const constant int& out_vec_size [[buffer(5)]],
      const constant int& matrix_ld [[buffer(6)]],
      const device out_mask_t* out_mask [[buffer(20)]],
      const device op_mask_t* mat_mask [[buffer(21)]],
      const device op_mask_t* vec_mask [[buffer(22)]],
      const constant int* mask_strides [[buffer(23)]],
      threadgroup T* tgp_memory [[threadgroup(0)]],
      uint3 tid [[threadgroup_position_in_grid]],
      uint3 lid [[thread_position_in_threadgroup]],
      uint simd_gid [[simdgroup_index_in_threadgroup]],
      uint simd_lid [[thread_index_in_simdgroup]]) {
    // Appease compiler
    (void)lid;

    // Thread local accumulation results
    thread T result[TM] = {0};
    thread T inter[TN];
    thread T v_coeff[TN];

    const int thrM = SN != 32 ? simd_lid / SN : 0;
    const int thrN = SN != 32 ? simd_lid % SN : int(simd_lid);

    const int sgN = BN != 1 ? (simd_gid % BN) : 0;

    const int simdM = BN != 1 ? SM * (simd_gid / BN) : int(SM * simd_gid);
    const int simdN = BN != 1 ? SN * (simd_gid % BN) : 0;

    int bm = (simdM + thrM) * TM;
    int bn = (simdN + thrN) * TN;

    // Block position
    int out_row = tid.x * blockM + bm;

    // Exit simdgroup if rows out of bound
    if (out_row >= out_vec_size)
      return;

    // Adjust tail simdgroup to ensure in bound reads
    out_row = out_row + TM <= out_vec_size ? out_row : out_vec_size - TM;

    // Prepare mask offsets
    const constant int* out_mask_strides = mask_strides;
    const constant int* mat_mask_strides =
        mask_strides + (has_output_mask ? 2 : 0);
    const constant int* vec_mask_strides =
        mat_mask_strides + (has_operand_mask ? 2 : 0);

    const int m_block_idx = blockN > blockM ? out_row / blockN : int(tid.x);

    const int out_mask_offset =
        !has_output_mask ? 0 : m_block_idx * out_mask_strides[1];

    int mat_mask_offset =
        !has_operand_mask ? 0 : m_block_idx * mat_mask_strides[1];
    int vec_mask_offset = 0;
    const int mat_mask_step = !has_operand_mask ? 0 : mat_mask_strides[0];
    const int vec_mask_step = !has_operand_mask ? 0 : vec_mask_strides[1];

    T out_scale{1};

    // Check output mask
    if (has_output_mask) {
      auto mask_out = out_mask[out_mask_offset];

      // Write zeros and return if mask is 0
      if (!mask_out) {
        if (simdN == 0 && thrN == 0) {
          MLX_MTL_PRAGMA_UNROLL
          for (int tm = 0; tm < TM; tm++) {
            out_vec[out_row + tm] = T(0.);
          }
        }

        return;
      }

      // Store scalar if multiplicative mask
      if (has_mul_output_mask) {
        out_scale = T(mask_out);
      }
    }

    // Advance matrix
    mat += out_row * matrix_ld;

    // Prepare for loop
    constexpr const uniform<int> loop_stride = make_uniform(blockN);
    const uniform<int> in_size = make_uniform(in_vec_size);
    const uniform<int> n_iter = in_size / loop_stride;
    const uniform<int> last_iter = loop_stride * n_iter;
    const uniform<int> leftover = in_size - last_iter;

    // Loop over in_vec in blocks of blockN
    for (int i = 0; i < n_iter; ++i) {
      if (!has_operand_mask ||
          (bool(mat_mask[mat_mask_offset]) &&
           bool(vec_mask[vec_mask_offset]))) {
        T block_scale{1};
        if (has_mul_operand_mask) {
          block_scale =
              T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
        }

        load_unsafe(in_vec, v_coeff, bn);

        // Apply scale
        if (has_mul_operand_mask) {
          MLX_MTL_PRAGMA_UNROLL
          for (int tn = 0; tn < TN; tn++) {
            v_coeff[tn] *= block_scale;
          }
        }

        // Per thread work loop
        int mat_offset = 0;
        MLX_MTL_PRAGMA_UNROLL
        for (int tm = 0; tm < TM; tm++) {
          // Load for the row
          load_unsafe(mat, inter, mat_offset + bn);

          // Accumulate results
          MLX_MTL_PRAGMA_UNROLL
          for (int tn = 0; tn < TN; tn++) {
            result[tm] += inter[tn] * v_coeff[tn];
          }

          mat_offset += matrix_ld;
        }
      }

      bn += blockN;
      mat_mask_offset += mat_mask_step;
      vec_mask_offset += vec_mask_step;
    }

    if (leftover > 0 &&
        (!has_operand_mask ||
         (bool(mat_mask[mat_mask_offset]) &&
          bool(vec_mask[vec_mask_offset])))) {
      T block_scale{1};
      if (has_mul_operand_mask) {
        block_scale =
            T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
      }

      load_safe(in_vec, v_coeff, bn, in_size);

      // Apply scale
      if (has_mul_operand_mask) {
        MLX_MTL_PRAGMA_UNROLL
        for (int tn = 0; tn < TN; tn++) {
          v_coeff[tn] *= block_scale;
        }
      }

      // Per thread work loop
      MLX_MTL_PRAGMA_UNROLL
      for (int tm = 0; tm < TM; tm++) {
        // Load for the row
        load_safe(&mat[tm * matrix_ld], inter, bn, in_size);

        // Accumulate results
        MLX_MTL_PRAGMA_UNROLL
        for (int tn = 0; tn < TN; tn++) {
          result[tm] += inter[tn] * v_coeff[tn];
        }
      }
    }

    // Apply out scale
    if (has_mul_output_mask) {
      MLX_MTL_PRAGMA_UNROLL
      for (int tm = 0; tm < TM; tm++) {
        result[tm] *= out_scale;
      }
    }

    // Simdgroup accumulations
    MLX_MTL_PRAGMA_UNROLL
    for (int tm = 0; tm < TM; tm++) {
      MLX_MTL_PRAGMA_UNROLL
      for (ushort sn = (SN / 2); sn >= 1; sn >>= 1) {
        result[tm] += simd_shuffle_down(result[tm], sn);
      }
    }

    // Threadgroup accumulation results
    if (needs_tgp_reduction) {
      threadgroup T* tgp_results = tgp_memory + sgN * (blockM + TM) + bm;
      if (thrN == 0) {
        MLX_MTL_PRAGMA_UNROLL
        for (int tm = 0; tm < TM; tm++) {
          tgp_results[tm] = result[tm];
        }

        threadgroup_barrier(mem_flags::mem_none);

        if (sgN == 0) {
          MLX_MTL_PRAGMA_UNROLL
          for (int sgn = 1; sgn < BN; sgn++) {
            MLX_MTL_PRAGMA_UNROLL
            for (int tm = 0; tm < TM; tm++) {
              result[tm] += tgp_results[sgn * (blockM + TM) + tm];
            }
          }
        }
      }
    }

    // Write outputs
    if (simdN == 0 && thrN == 0) {
      MLX_MTL_PRAGMA_UNROLL
      for (int tm = 0; tm < TM; tm++) {
        out_vec[out_row + tm] = result[tm];
      }
    }
  }
};

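// To make the tiling arithmetic concrete, one hypothetical configuration
// (illustrative values, not from this diff; they satisfy the static_asserts):
//   BM = 2, BN = 2  -> BM * BN = 4 simdgroups = 128 threads per threadgroup
//   SM = 4, SN = 8  -> SM * SN == 32, and SN is a permitted width
//   TM = 4, TN = 4  -> each thread accumulates a 4 x 4 tile
// Derived sizes:
//   threadsM = BM * SM = 8,   threadsN = BN * SN = 16
//   blockM   = threadsM * TM = 32 rows of out_vec per threadgroup
//   blockN   = threadsN * TN = 64 columns consumed per main-loop iteration
// so blockN >= blockM holds, as the masked GEMVKernel requires.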
///////////////////////////////////////////////////////////////////////////////
/// Vector matrix multiplication
///////////////////////////////////////////////////////////////////////////////

template <
    typename T,
    typename out_mask_t,
    typename op_mask_t,
    const int BM, /* Threadgroup rows (in simdgroups) */
    const int BN, /* Threadgroup cols (in simdgroups) */
    const int SM, /* Simdgroup rows (in threads) */
    const int SN, /* Simdgroup cols (in threads) */
    const int TM, /* Thread rows (in elements) */
    const int TN> /* Thread cols (in elements) */
struct GEMVTKernel {
  MLX_MTL_CONST int threadsM = BM * SM;
  MLX_MTL_CONST int threadsN = BN * SN;

  MLX_MTL_CONST int blockM = threadsM * TM;
  MLX_MTL_CONST int blockN = threadsN * TN;

  static_assert(SM * SN == 32, "simdgroup can only have 32 threads");

  MLX_MTL_CONST bool has_operand_mask = !metal::is_same_v<op_mask_t, nomask_t>;
  MLX_MTL_CONST bool has_output_mask = !metal::is_same_v<out_mask_t, nomask_t>;

  MLX_MTL_CONST bool has_mul_operand_mask =
      has_operand_mask && !metal::is_same_v<op_mask_t, bool>;
  MLX_MTL_CONST bool has_mul_output_mask =
      has_output_mask && !metal::is_same_v<out_mask_t, bool>;

  // - The matrix of size (M = in_vec_size, N = out_vec_size) is divided up
  //   into blocks of (blockM, blockN) divided among threadgroups
  // - Every thread works on a block of (TM, TN)
  // - We assume each threadgroup has (threadsN, threadsM, 1) threads
  //
  // 1. A thread loads TN elements each from mat along TM contiguous rows
  //    and the corresponding scalar from the vector
  // 2. The thread then accumulates its local result for the block
  // 3. At the end, each thread has accumulated results over all blocks across
  //    the rows. These are then summed up across the threadgroup
  // 4. Each threadgroup writes its accumulated BN * TN outputs
  //
  // Edge case handling:
  // - The threadgroup with the largest tid has blocks that exceed the matrix
  //   * The blocks that start outside the matrix are never read (thread results
  //     remain zero)
  //   * The last thread that partially overlaps with the matrix is shifted
  //     inwards such that the thread block fits exactly in the matrix

  MLX_MTL_CONST short tgp_mem_size = BM > 1 ? BM * (blockN + TN) : 0;
  MLX_MTL_CONST bool needs_tgp_reduction = BM > 1;

  static METAL_FUNC void run(
      const device T* mat [[buffer(0)]],
      const device T* in_vec [[buffer(1)]],
      device T* out_vec [[buffer(3)]],
      const constant int& in_vec_size [[buffer(4)]],
      const constant int& out_vec_size [[buffer(5)]],
      const constant int& matrix_ld [[buffer(6)]],
      const device out_mask_t* out_mask [[buffer(20)]],
      const device op_mask_t* mat_mask [[buffer(21)]],
      const device op_mask_t* vec_mask [[buffer(22)]],
      const constant int* mask_strides [[buffer(23)]],
      threadgroup T* tgp_memory [[threadgroup(0)]],
      uint3 tid [[threadgroup_position_in_grid]],
      uint3 lid [[thread_position_in_threadgroup]],
      uint simd_gid [[simdgroup_index_in_threadgroup]],
      uint simd_lid [[thread_index_in_simdgroup]]) {
    // Appease compiler
    (void)lid;

    // Thread local accumulation results
    T result[TN] = {0};
    T inter[TN];
    T v_coeff[TM];

    const int thrM = SN != 32 ? simd_lid / SN : 0;
    const int thrN = SN != 32 ? simd_lid % SN : int(simd_lid);

    const int sgM = BN != 1 ? (simd_gid / BN) : int(simd_gid);
    const int sgN = BN != 1 ? (simd_gid % BN) : 0;

    const int simdM = SM * sgM;
    const int simdN = SN * sgN;

    int cm = (simdM + thrM);
    int cn = (simdN + thrN);

    int bm = cm * TM;
    int bn = cn * TN;

    int out_col = tid.x * blockN + bn;

    // Prepare mask offsets
    const constant int* out_mask_strides = mask_strides;
    const constant int* mat_mask_strides =
        out_mask_strides + (has_output_mask ? 2 : 0);
    const constant int* vec_mask_strides =
        mat_mask_strides + (has_operand_mask ? 2 : 0);

    const int n_block_idx = blockM > blockN ? out_col / blockM : int(tid.x);

    const int out_mask_offset =
        !has_output_mask ? 0 : n_block_idx; // * out_mask_strides[0];

    int mat_mask_offset =
        !has_operand_mask ? 0 : n_block_idx * mat_mask_strides[0];
    int vec_mask_offset = 0;
    const int mat_mask_step = !has_operand_mask ? 0 : mat_mask_strides[1];
    const int vec_mask_step = !has_operand_mask ? 0 : vec_mask_strides[0];

    T out_scale{1};

    // Check output mask
    if (has_output_mask) {
      auto mask_out = out_mask[out_mask_offset];

      // Write zeros and return if mask is 0
      if (!mask_out) {
        if (cm == 0 && out_col < out_vec_size) {
          if (out_col + TN <= out_vec_size) {
            MLX_MTL_PRAGMA_UNROLL
            for (int tn = 0; tn < TN; tn++) {
              out_vec[out_col + tn] = T(0.);
            }
          } else {
            for (int tn = 0; tn < TN && (out_col + tn) < out_vec_size; tn++) {
              out_vec[out_col + tn] = T(0.);
            }
          }
        }

        return;
      }

      // Store scalar if multiplicative mask
      if (has_mul_output_mask) {
        out_scale = T(mask_out);
      }
    }

    // Prepare for loop
    constexpr const uniform<int> loop_stride = make_uniform(blockM);
    const uniform<int> in_size = make_uniform(in_vec_size);
    const uniform<int> n_iter = in_size / loop_stride;
    const uniform<int> last_iter = loop_stride * n_iter;
    const uniform<int> leftover = in_size - last_iter;

    // Edgecase handling
    if (out_col < out_vec_size) {
      out_col = (out_col + TN) <= out_vec_size ? out_col : out_vec_size - TN;

      // Per thread accumulation main loop
      for (int i = 0; i < n_iter; ++i) {
        // Adding a threadgroup_barrier improves performance slightly,
        // possibly because it helps exploit the cache better
        threadgroup_barrier(mem_flags::mem_none);

        if (!has_operand_mask ||
            (bool(mat_mask[mat_mask_offset]) &&
             bool(vec_mask[vec_mask_offset]))) {
          T block_scale{1};
          if (has_mul_operand_mask) {
            block_scale =
                T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
          }

          MLX_MTL_PRAGMA_UNROLL
          for (int tm = 0; tm < TM; tm++) {
            v_coeff[tm] = in_vec[bm + tm];
          }

          // Apply scale
          if (has_mul_operand_mask) {
            MLX_MTL_PRAGMA_UNROLL
            for (int tm = 0; tm < TM; tm++) {
              v_coeff[tm] *= block_scale;
            }
          }

          MLX_MTL_PRAGMA_UNROLL
          for (int tm = 0; tm < TM; tm++) {
            for (int tn = 0; tn < TN; tn++) {
              inter[tn] = mat[(bm + tm) * matrix_ld + out_col + tn];
            }
            for (int tn = 0; tn < TN; tn++) {
              result[tn] += v_coeff[tm] * inter[tn];
            }
          }
        }

        bm += blockM;
        mat_mask_offset += mat_mask_step;
        vec_mask_offset += vec_mask_step;
      }

      if (leftover > 0 &&
          (!has_operand_mask ||
           (bool(mat_mask[mat_mask_offset]) &&
            bool(vec_mask[vec_mask_offset])))) {
        T block_scale{1};
        if (has_mul_operand_mask) {
          block_scale =
              T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
        }

        for (int tm = 0; tm < TM && bm + tm < in_vec_size; tm++) {
          v_coeff[tm] = in_vec[bm + tm];

          if (has_mul_operand_mask) {
            v_coeff[tm] *= block_scale;
          }

          MLX_MTL_PRAGMA_UNROLL
          for (int tn = 0; tn < TN; tn++) {
            inter[tn] = mat[(bm + tm) * matrix_ld + out_col + tn];
          }

          MLX_MTL_PRAGMA_UNROLL
          for (int tn = 0; tn < TN; tn++) {
            result[tn] += v_coeff[tm] * inter[tn];
          }
        }
      }
    }

    // Apply out scale
    if (has_mul_output_mask) {
      MLX_MTL_PRAGMA_UNROLL
      for (int tn = 0; tn < TN; tn++) {
        result[tn] *= out_scale;
      }
    }

    // Simdgroup accumulations
    MLX_MTL_PRAGMA_UNROLL
    for (int tn = 0; tn < TN; tn++) {
      MLX_MTL_PRAGMA_UNROLL
      for (ushort sm = (SM / 2); sm >= 1; sm >>= 1) {
        result[tn] += simd_shuffle_down(result[tn], SN * sm);
      }
    }

    // Threadgroup accumulation results
    if (needs_tgp_reduction) {
      threadgroup T* tgp_results = tgp_memory + sgM * (blockN + TN) + bn;
      if (thrM == 0) {
        MLX_MTL_PRAGMA_UNROLL
        for (int tn = 0; tn < TN; tn++) {
          tgp_results[tn] = result[tn];
        }

        threadgroup_barrier(mem_flags::mem_none);

        if (sgM == 0) {
          MLX_MTL_PRAGMA_UNROLL
          for (int sgm = 1; sgm < BM; sgm++) {
            MLX_MTL_PRAGMA_UNROLL
            for (int tn = 0; tn < TN; tn++) {
              result[tn] += tgp_results[sgm * (blockN + TN) + tn];
            }
          }
        }
      }
    }

    // Threadgroup accumulation and writing out results
    if (cm == 0 && out_col < out_vec_size) {
      MLX_MTL_PRAGMA_UNROLL
      for (int j = 0; j < TN; j++) {
        out_vec[out_col + j] = result[j];
      }
    }
  }
};

///////////////////////////////////////////////////////////////////////////////
/// Matrix vector multiplication
///////////////////////////////////////////////////////////////////////////////

template <
    typename T,
    typename out_mask_t,
    typename op_mask_t,
    const int BM, /* Threadgroup rows (in simdgroups) */
    const int BN, /* Threadgroup cols (in simdgroups) */
    const int SM, /* Simdgroup rows (in threads) */
    const int SN, /* Simdgroup cols (in threads) */
    const int TM, /* Thread rows (in elements) */
    const int TN, /* Thread cols (in elements) */
    const bool kDoNCBatch> /* Batch ndim > 1 */
[[kernel, max_total_threads_per_threadgroup(BM * BN * 32)]] void gemv_masked(
    const device T* mat [[buffer(0)]],
    const device T* in_vec [[buffer(1)]],
    device T* out_vec [[buffer(3)]],
    const constant int& in_vec_size [[buffer(4)]],
    const constant int& out_vec_size [[buffer(5)]],
    const constant int& matrix_ld [[buffer(6)]],
    const constant int& batch_ndim [[buffer(9)]],
    const constant int* batch_shape [[buffer(10)]],
    const constant size_t* vector_batch_stride [[buffer(11)]],
    const constant size_t* matrix_batch_stride [[buffer(12)]],
    const device out_mask_t* out_mask [[buffer(20)]],
    const device op_mask_t* mat_mask [[buffer(21)]],
    const device op_mask_t* vec_mask [[buffer(22)]],
    const constant int* mask_strides [[buffer(23)]],
    const constant size_t* mask_batch_strides [[buffer(24)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  using gemv_kernel =
      GEMVKernel<T, out_mask_t, op_mask_t, BM, BN, SM, SN, TM, TN>;
  threadgroup T tgp_memory
      [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];

  constexpr bool has_operand_mask = !metal::is_same_v<op_mask_t, nomask_t>;
  constexpr bool has_output_mask = !metal::is_same_v<out_mask_t, nomask_t>;

  // Update batch offsets
  if (kDoNCBatch) {
    in_vec += elem_to_loc(tid.z, batch_shape, vector_batch_stride, batch_ndim);
    mat += elem_to_loc(tid.z, batch_shape, matrix_batch_stride, batch_ndim);

    if (has_output_mask) {
      out_mask +=
          elem_to_loc(tid.z, batch_shape, mask_batch_strides, batch_ndim);
      mask_batch_strides += batch_ndim;
    }

    if (has_operand_mask) {
      const constant size_t* mask_strides_mat = mask_batch_strides;
      const constant size_t* mask_strides_vec = mask_strides_mat + batch_ndim;

      ulong2 batch_offsets = elem_to_loc_broadcast(
          tid.z, batch_shape, mask_strides_mat, mask_strides_vec, batch_ndim);

      mat_mask += batch_offsets.x;
      vec_mask += batch_offsets.y;
    }

  } else {
    in_vec += tid.z * vector_batch_stride[0];
    mat += tid.z * matrix_batch_stride[0];

    if (has_output_mask) {
      out_mask += tid.z * mask_batch_strides[0];
      mask_batch_strides += batch_ndim;
    }

    if (has_operand_mask) {
      mat_mask += tid.z * mask_batch_strides[0];
      vec_mask += tid.z * mask_batch_strides[batch_ndim];
    }
  }

  out_vec += tid.z * out_vec_size;

  gemv_kernel::run(
      mat,
      in_vec,
      out_vec,
      in_vec_size,
      out_vec_size,
      matrix_ld,
      out_mask,
      mat_mask,
      vec_mask,
      mask_strides,
      gemv_kernel::tgp_mem_size == 0 ? nullptr : tgp_memory,
      tid,
      lid,
      simd_gid,
      simd_lid);
}

///////////////////////////////////////////////////////////////////////////////
/// Vector matrix multiplication
///////////////////////////////////////////////////////////////////////////////

template <
    typename T,
    typename out_mask_t,
    typename op_mask_t,
    const int BM, /* Threadgroup rows (in simdgroups) */
    const int BN, /* Threadgroup cols (in simdgroups) */
    const int SM, /* Simdgroup rows (in threads) */
    const int SN, /* Simdgroup cols (in threads) */
    const int TM, /* Thread rows (in elements) */
    const int TN, /* Thread cols (in elements) */
    const bool kDoNCBatch> /* Batch ndim > 1 */
[[kernel, max_total_threads_per_threadgroup(BM * BN * 32)]] void gemv_t_masked(
    const device T* mat [[buffer(0)]],
    const device T* in_vec [[buffer(1)]],
    device T* out_vec [[buffer(3)]],
    const constant int& in_vec_size [[buffer(4)]],
    const constant int& out_vec_size [[buffer(5)]],
    const constant int& matrix_ld [[buffer(6)]],
    const constant int& batch_ndim [[buffer(9)]],
    const constant int* batch_shape [[buffer(10)]],
    const constant size_t* vector_batch_stride [[buffer(11)]],
    const constant size_t* matrix_batch_stride [[buffer(12)]],
    const device out_mask_t* out_mask [[buffer(20)]],
    const device op_mask_t* mat_mask [[buffer(21)]],
    const device op_mask_t* vec_mask [[buffer(22)]],
    const constant int* mask_strides [[buffer(23)]],
    const constant size_t* mask_batch_strides [[buffer(24)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  using gemv_kernel =
      GEMVTKernel<T, out_mask_t, op_mask_t, BM, BN, SM, SN, TM, TN>;
  threadgroup T tgp_memory
      [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];

  constexpr bool has_operand_mask = !metal::is_same_v<op_mask_t, nomask_t>;
  constexpr bool has_output_mask = !metal::is_same_v<out_mask_t, nomask_t>;

  // Update batch offsets
  if (kDoNCBatch) {
    in_vec += elem_to_loc(tid.z, batch_shape, vector_batch_stride, batch_ndim);
    mat += elem_to_loc(tid.z, batch_shape, matrix_batch_stride, batch_ndim);

    if (has_output_mask) {
      out_mask +=
          elem_to_loc(tid.z, batch_shape, mask_batch_strides, batch_ndim);
      mask_batch_strides += batch_ndim;
    }

    if (has_operand_mask) {
      const constant size_t* mask_strides_mat = mask_batch_strides;
      const constant size_t* mask_strides_vec = mask_strides_mat + batch_ndim;

      ulong2 batch_offsets = elem_to_loc_broadcast(
          tid.z, batch_shape, mask_strides_mat, mask_strides_vec, batch_ndim);

      mat_mask += batch_offsets.x;
      vec_mask += batch_offsets.y;
    }

  } else {
    in_vec += tid.z * vector_batch_stride[0];
    mat += tid.z * matrix_batch_stride[0];

    if (has_output_mask) {
      out_mask += tid.z * mask_batch_strides[0];
      mask_batch_strides += batch_ndim;
    }

    if (has_operand_mask) {
      mat_mask += tid.z * mask_batch_strides[0];
      vec_mask += tid.z * mask_batch_strides[batch_ndim];
    }
  }

  out_vec += tid.z * out_vec_size;

  gemv_kernel::run(
      mat,
      in_vec,
      out_vec,
      in_vec_size,
      out_vec_size,
      matrix_ld,
      out_mask,
      mat_mask,
      vec_mask,
      mask_strides,
      gemv_kernel::tgp_mem_size == 0 ? nullptr : tgp_memory,
      tid,
      lid,
      simd_gid,
      simd_lid);
}
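The host-side dispatch is not part of this diff (it lives with the matmul routing code), but the indexing above implies the launch geometry. A hedged C++ sketch of what a caller would compute (variable names assumed, not from this diff):

  // gemv_masked: tid.x walks row blocks of size blockM; tid.z walks the batch.
  int blockM = (BM * SM) * TM;  // rows of out_vec per threadgroup
  MTL::Size group_dims = MTL::Size(BN * SN, BM * SM, 1);  // (threadsN, threadsM, 1)
  MTL::Size grid_dims =
      MTL::Size((out_vec_size + blockM - 1) / blockM, 1, n_batch);
  compute_encoder.dispatchThreadgroups(grid_dims, group_dims);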
@@ -1,5 +1,6 @@
 // Copyright © 2023-2024 Apple Inc.
 
+// clang-format off
 #include <metal_simdgroup>
 #include <metal_stdlib>
 
@@ -7,726 +8,7 @@
 #include "mlx/backend/metal/kernels/defines.h"
 #include "mlx/backend/metal/kernels/utils.h"
 
#include "mlx/backend/metal/kernels/steel/utils.h"
|
||||
|
||||
using namespace metal;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
/// Matrix vector multiplication
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#define MLX_MTL_CONST static constant constexpr const
|
||||
|
||||
struct _NoMask {
|
||||
char x;
|
||||
|
||||
constexpr METAL_FUNC operator bool() {
|
||||
return true;
|
||||
}
|
||||
constexpr METAL_FUNC operator bool() const threadgroup {
|
||||
return true;
|
||||
}
|
||||
constexpr METAL_FUNC operator bool() const device {
|
||||
return true;
|
||||
}
|
||||
constexpr METAL_FUNC operator bool() const constant {
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
typedef struct _NoMask nomask_t;
|
||||
|
||||
template <typename OutT, typename InT = OutT>
|
||||
struct ScaleOp {
|
||||
OutT scale;
|
||||
|
||||
METAL_FUNC OutT apply(InT x) const {
|
||||
return static_cast<OutT>(x) * scale;
|
||||
}
|
||||
};
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename out_mask_t,
|
||||
typename op_mask_t,
|
||||
const int BM, /* Threadgroup rows (in simdgroups) */
|
||||
const int BN, /* Threadgroup cols (in simdgroups) */
|
||||
const int SM, /* Simdgroup rows (in threads) */
|
||||
const int SN, /* Simdgroup cols (in threads) */
|
||||
const int TM, /* Thread rows (in elements) */
|
||||
const int TN> /* Thread cols (in elements) */
|
||||
struct GEMVKernel {
|
||||
MLX_MTL_CONST int threadsM = BM * SM;
|
||||
MLX_MTL_CONST int threadsN = BN * SN;
|
||||
|
||||
MLX_MTL_CONST int blockM = threadsM * TM;
|
||||
MLX_MTL_CONST int blockN = threadsN * TN;
|
||||
|
||||
static_assert(SM * SN == 32, "simdgroup can only have 32 threads");
|
||||
|
||||
static_assert(
|
||||
SN == 8 || SN == 16 || SN == 32,
|
||||
"gemv block must have a width of 8, 16, or 32");
|
||||
|
||||
static_assert(blockN >= blockM, "Masked gemv must have blockN >= blockM");
|
||||
|
||||
MLX_MTL_CONST bool has_operand_mask = !metal::is_same_v<op_mask_t, nomask_t>;
|
||||
MLX_MTL_CONST bool has_output_mask = !metal::is_same_v<out_mask_t, nomask_t>;
|
||||
|
||||
MLX_MTL_CONST bool has_mul_operand_mask =
|
||||
has_operand_mask && !metal::is_same_v<op_mask_t, bool>;
|
||||
MLX_MTL_CONST bool has_mul_output_mask =
|
||||
has_output_mask && !metal::is_same_v<out_mask_t, bool>;
|
||||
|
||||
// - The matrix of size (M = out_vec_size, K = in_vec_size) is divided up
|
||||
// into blocks of (blockM, blockN) divided among threadgroups
|
||||
// - Every thread works on a block of (TM, TN)
|
||||
// - We assume each threadgroup has (threadsN, threadsM, 1) threads
|
||||
//
|
||||
// 1. A thread loads TN elements each from mat along TM rows
|
||||
// and the corresponding scalar from the vector
|
||||
// 2. The thread then multiplies and adds to accumulate its local result for
|
||||
// the block
|
||||
// 3. At the end, each thread has accumulated results over all blocks across
|
||||
// the rows. These are then summed up across the threadgroup
|
||||
// 4. Each threadgroup writes its accumulated blockM outputs
|
||||
//
|
||||
// Edge case handling:
|
||||
// - The threadgroup with the largest tid has blocks that exceed the matrix
|
||||
// * The blocks that start outside the matrix are never read (thread results
|
||||
// remain zero)
|
||||
// * The last thread that partially overlaps with the matrix is shifted
|
||||
// inwards such that the thread block fits exactly in the matrix
|
||||
|
||||
MLX_MTL_CONST short tgp_mem_size = BN > 1 ? BN*(blockM + TM) : 0;
|
||||
MLX_MTL_CONST bool needs_tgp_reduction = BN > 1;
|
||||
|
||||
static METAL_FUNC void
|
||||
load_unsafe(const device T* src, thread T dst[TN], const int src_offset = 0) {
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (int tn = 0; tn < TN; tn++) {
|
||||
dst[tn] = src[src_offset + tn];
|
||||
}
|
||||
}
|
||||
|
||||
static METAL_FUNC void load_safe(
|
||||
const device T* src,
|
||||
thread T dst[TN],
|
||||
const int src_offset = 0,
|
||||
const int src_size = TN) {
|
||||
if (src_offset + TN <= src_size) {
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (int tn = 0; tn < TN; tn++) {
|
||||
dst[tn] = src[src_offset + tn];
|
||||
}
|
||||
} else { // Edgecase
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (int tn = 0; tn < TN; tn++) {
|
||||
dst[tn] = src_offset + tn < src_size ? src[src_offset + tn] : 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static METAL_FUNC void run(
|
||||
const device T* mat [[buffer(0)]],
|
||||
const device T* in_vec [[buffer(1)]],
|
||||
device T* out_vec [[buffer(3)]],
|
||||
const constant int& in_vec_size [[buffer(4)]],
|
||||
const constant int& out_vec_size [[buffer(5)]],
|
||||
const constant int& matrix_ld [[buffer(6)]],
|
||||
const device out_mask_t* out_mask [[buffer(20)]],
|
||||
const device op_mask_t* mat_mask [[buffer(21)]],
|
||||
const device op_mask_t* vec_mask [[buffer(22)]],
|
||||
const constant int* mask_strides [[buffer(23)]],
|
||||
threadgroup T* tgp_memory [[threadgroup(0)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
// Appease compiler
|
||||
(void)lid;
|
||||
|
||||
// Thread local accumulation results
|
||||
thread T result[TM] = {0};
|
||||
thread T inter[TN];
|
||||
thread T v_coeff[TN];
|
||||
|
||||
const int thrM = SN != 32 ? simd_lid / SN : 0;
|
||||
const int thrN = SN != 32 ? simd_lid % SN : int(simd_lid);
|
||||
|
||||
const int sgN = BN != 1 ? (simd_gid % BN) : 0;
|
||||
|
||||
const int simdM = BN != 1 ? SM * (simd_gid / BN) : int(SM * simd_gid);
|
||||
const int simdN = BN != 1 ? SN * (simd_gid % BN) : 0;
|
||||
|
||||
int bm = (simdM + thrM) * TM;
|
||||
int bn = (simdN + thrN) * TN;
|
||||
|
||||
// Block position
|
||||
int out_row = tid.x * blockM + bm;
|
||||
|
||||
// Exit simdgroup if rows out of bound
|
||||
if (out_row >= out_vec_size)
|
||||
return;
|
||||
|
||||
// Adjust tail simdgroup to ensure in bound reads
|
||||
out_row = out_row + TM <= out_vec_size ? out_row : out_vec_size - TM;
|
||||
|
||||
// Prepare mask offsets
|
||||
const constant int* out_mask_strides = mask_strides;
|
||||
const constant int* mat_mask_strides =
|
||||
mask_strides + (has_output_mask ? 2 : 0);
|
||||
const constant int* vec_mask_strides =
|
||||
mat_mask_strides + (has_operand_mask ? 2 : 0);
|
||||
|
||||
const int m_block_idx = blockN > blockM ? out_row / blockN : int(tid.x);
|
||||
|
||||
const int out_mask_offset =
|
||||
!has_output_mask ? 0 : m_block_idx * out_mask_strides[1];
|
||||
|
||||
int mat_mask_offset =
|
||||
!has_operand_mask ? 0 : m_block_idx * mat_mask_strides[1];
|
||||
int vec_mask_offset = 0;
|
||||
const int mat_mask_step = !has_operand_mask ? 0 : mat_mask_strides[0];
|
||||
const int vec_mask_step = !has_operand_mask ? 0 : vec_mask_strides[1];
|
||||
|
||||
T out_scale{1};
|
||||
|
||||
// Check output mask
|
||||
if (has_output_mask) {
|
||||
auto mask_out = out_mask[out_mask_offset];
|
||||
|
||||
// Write zeros and return if mask is 0
|
||||
if (!mask_out) {
|
||||
if (simdN == 0 && thrN == 0) {
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (int tm = 0; tm < TM; tm++) {
|
||||
out_vec[out_row + tm] = T(0.);
|
||||
}
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
// Store scalar if multiplicative mask
|
||||
if (has_mul_output_mask) {
|
||||
out_scale = T(mask_out);
|
||||
}
|
||||
}
|
||||
|
||||
// Advance matrix
|
||||
mat += out_row * matrix_ld;
|
||||
|
||||
// Prepare for loop
|
||||
constexpr const uniform<int> loop_stride = make_uniform(blockN);
|
||||
const uniform<int> in_size = make_uniform(in_vec_size);
|
||||
const uniform<int> n_iter = in_size / loop_stride;
|
||||
const uniform<int> last_iter = loop_stride * n_iter;
|
||||
const uniform<int> leftover = in_size - last_iter;
|
||||
|
||||
// Loop over in_vec in blocks of blockN
|
||||
for (int i = 0; i < n_iter; ++i) {
|
||||
if (!has_operand_mask ||
|
||||
(bool(mat_mask[mat_mask_offset]) &&
|
||||
bool(vec_mask[vec_mask_offset]))) {
|
||||
T block_scale{1};
|
||||
if (has_mul_operand_mask) {
|
||||
block_scale =
|
||||
T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
|
||||
}
|
||||
|
||||
load_unsafe(in_vec, v_coeff, bn);
|
||||
|
||||
// Apply scale
|
||||
if (has_mul_operand_mask) {
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (int tn = 0; tn < TN; tn++) {
|
||||
v_coeff[tn] *= block_scale;
|
||||
}
|
||||
}
|
||||
|
||||
// Per thread work loop
|
||||
int mat_offset = 0;
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (int tm = 0; tm < TM; tm++) {
|
||||
// Load for the row
|
||||
load_unsafe(mat, inter, mat_offset + bn);
|
||||
|
||||
// Accumulate results
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (int tn = 0; tn < TN; tn++) {
|
||||
result[tm] += inter[tn] * v_coeff[tn];
|
||||
}
|
||||
|
||||
mat_offset += matrix_ld;
|
||||
}
|
||||
}
|
||||
|
||||
bn += blockN;
|
||||
mat_mask_offset += mat_mask_step;
|
||||
vec_mask_offset += vec_mask_step;
|
||||
}
|
||||
|
||||
if (leftover > 0 &&
|
||||
(!has_operand_mask ||
|
||||
(bool(mat_mask[mat_mask_offset]) &&
|
||||
bool(vec_mask[vec_mask_offset])))) {
|
||||
T block_scale{1};
|
||||
if (has_mul_operand_mask) {
|
||||
block_scale =
|
||||
T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
|
||||
}
|
||||
|
||||
load_safe(in_vec, v_coeff, bn, in_size);
|
||||
|
||||
// Apply scale
|
||||
if (has_mul_operand_mask) {
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (int tn = 0; tn < TN; tn++) {
|
||||
v_coeff[tn] *= block_scale;
|
||||
}
|
||||
}
|
||||
|
||||
// Per thread work loop
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (int tm = 0; tm < TM; tm++) {
|
||||
// Load for the row
|
||||
load_safe(&mat[tm * matrix_ld], inter, bn, in_size);
|
||||
|
||||
// Accumulate results
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (int tn = 0; tn < TN; tn++) {
|
||||
result[tm] += inter[tn] * v_coeff[tn];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Apply out scale
|
||||
if (has_mul_output_mask) {
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (int tm = 0; tm < TM; tm++) {
|
||||
result[tm] *= out_scale;
|
||||
}
|
||||
}
|
||||
|
||||
// Simdgroup accumulations
    MLX_MTL_PRAGMA_UNROLL
    for (int tm = 0; tm < TM; tm++) {
      MLX_MTL_PRAGMA_UNROLL
      for (ushort sn = (SN / 2); sn >= 1; sn >>= 1) {
        result[tm] += simd_shuffle_down(result[tm], sn);
      }
    }

    // Threadgroup accumulation results
    if (needs_tgp_reduction) {
      threadgroup T* tgp_results = tgp_memory + sgN * (blockM + TM) + bm;
      if (thrN == 0) {
        MLX_MTL_PRAGMA_UNROLL
        for (int tm = 0; tm < TM; tm++) {
          tgp_results[tm] = result[tm];
        }

        threadgroup_barrier(mem_flags::mem_none);

        if (sgN == 0) {
          MLX_MTL_PRAGMA_UNROLL
          for (int sgn = 1; sgn < BN; sgn++) {
            MLX_MTL_PRAGMA_UNROLL
            for (int tm = 0; tm < TM; tm++) {
              result[tm] += tgp_results[sgn * (blockM + TM) + tm];
            }
          }
        }
      }
    }

    // Write outputs
    if (simdN == 0 && thrN == 0) {
      MLX_MTL_PRAGMA_UNROLL
      for (int tm = 0; tm < TM; tm++) {
        out_vec[out_row + tm] = result[tm];
      }
    }
  }
};

///////////////////////////////////////////////////////////////////////////////
/// Vector matrix multiplication
///////////////////////////////////////////////////////////////////////////////

template <
    typename T,
    typename out_mask_t,
    typename op_mask_t,
    const int BM, /* Threadgroup rows (in simdgroups) */
    const int BN, /* Threadgroup cols (in simdgroups) */
    const int SM, /* Simdgroup rows (in threads) */
    const int SN, /* Simdgroup cols (in threads) */
    const int TM, /* Thread rows (in elements) */
    const int TN> /* Thread cols (in elements) */
struct GEMVTKernel {
  MLX_MTL_CONST int threadsM = BM * SM;
  MLX_MTL_CONST int threadsN = BN * SN;

  MLX_MTL_CONST int blockM = threadsM * TM;
  MLX_MTL_CONST int blockN = threadsN * TN;

  static_assert(SM * SN == 32, "simdgroup can only have 32 threads");

  MLX_MTL_CONST bool has_operand_mask = !metal::is_same_v<op_mask_t, nomask_t>;
  MLX_MTL_CONST bool has_output_mask = !metal::is_same_v<out_mask_t, nomask_t>;

  MLX_MTL_CONST bool has_mul_operand_mask =
      has_operand_mask && !metal::is_same_v<op_mask_t, bool>;
  MLX_MTL_CONST bool has_mul_output_mask =
      has_output_mask && !metal::is_same_v<out_mask_t, bool>;

  // - The matrix of size (M = in_vec_size, N = out_vec_size) is divided up
  //   into blocks of (blockM, blockN) divided among threadgroups
  // - Every thread works on a block of (TM, TN)
  // - We assume each threadgroup has (threadsN, threadsM, 1) threads
  //
  // 1. A thread loads TN elements each from mat along TM contiguous rows
  //    and the corresponding scalar from the vector
  // 2. The thread then accumulates its local result for the block
  // 3. At the end, each thread has accumulated results over all blocks across
  //    the rows. These are then summed up across the threadgroup
  // 4. Each threadgroup writes its accumulated BN * TN outputs
  //
  // Edge case handling:
  // - The threadgroup with the largest tid has blocks that exceed the matrix
  //   * The blocks that start outside the matrix are never read (thread
  //     results remain zero)
  //   * The last thread that partially overlaps with the matrix is shifted
  //     inwards such that the thread block fits exactly in the matrix
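  //
  // For example, with (BM, BN, SM, SN, TM, TN) = (1, 2, 8, 4, 4, 4), one of
  // the configurations instantiated below: each simdgroup spans 8 x 4
  // threads, the threadgroup spans threadsM x threadsN = 8 x 8 threads, and
  // each block covers blockM x blockN = 32 x 32 elements of the matrix.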

  MLX_MTL_CONST short tgp_mem_size = BM > 1 ? BM * (blockN + TN) : 0;
  MLX_MTL_CONST bool needs_tgp_reduction = BM > 1;

  static METAL_FUNC void run(
      const device T* mat [[buffer(0)]],
      const device T* in_vec [[buffer(1)]],
      device T* out_vec [[buffer(3)]],
      const constant int& in_vec_size [[buffer(4)]],
      const constant int& out_vec_size [[buffer(5)]],
      const constant int& matrix_ld [[buffer(6)]],
      const device out_mask_t* out_mask [[buffer(20)]],
      const device op_mask_t* mat_mask [[buffer(21)]],
      const device op_mask_t* vec_mask [[buffer(22)]],
      const constant int* mask_strides [[buffer(23)]],
      threadgroup T* tgp_memory [[threadgroup(0)]],
      uint3 tid [[threadgroup_position_in_grid]],
      uint3 lid [[thread_position_in_threadgroup]],
      uint simd_gid [[simdgroup_index_in_threadgroup]],
      uint simd_lid [[thread_index_in_simdgroup]]) {
    // Appease compiler
    (void)lid;

    // Thread local accumulation results
    T result[TN] = {0};
    T inter[TN];
    T v_coeff[TM];

    const int thrM = SN != 32 ? simd_lid / SN : 0;
    const int thrN = SN != 32 ? simd_lid % SN : int(simd_lid);

    const int sgM = BN != 1 ? (simd_gid / BN) : int(simd_gid);
    const int sgN = BN != 1 ? (simd_gid % BN) : 0;

    const int simdM = SM * sgM;
    const int simdN = SN * sgN;

    int cm = (simdM + thrM);
    int cn = (simdN + thrN);

    int bm = cm * TM;
    int bn = cn * TN;

    int out_col = tid.x * blockN + bn;

    // Prepare mask offsets
    const constant int* out_mask_strides = mask_strides;
    const constant int* mat_mask_strides =
        out_mask_strides + (has_output_mask ? 2 : 0);
    const constant int* vec_mask_strides =
        mat_mask_strides + (has_operand_mask ? 2 : 0);

    const int n_block_idx = blockM > blockN ? out_col / blockM : int(tid.x);

    const int out_mask_offset =
        !has_output_mask ? 0 : n_block_idx; // * out_mask_strides[0];

    int mat_mask_offset =
        !has_operand_mask ? 0 : n_block_idx * mat_mask_strides[0];
    int vec_mask_offset = 0;
    const int mat_mask_step = !has_operand_mask ? 0 : mat_mask_strides[1];
    const int vec_mask_step = !has_operand_mask ? 0 : vec_mask_strides[0];

    T out_scale{1};

    // Check output mask
    if (has_output_mask) {
      auto mask_out = out_mask[out_mask_offset];

      // Write zeros and return if mask is 0
      if (!mask_out) {
        if (cm == 0 && out_col < out_vec_size) {
          if (out_col + TN <= out_vec_size) {
            MLX_MTL_PRAGMA_UNROLL
            for (int tn = 0; tn < TN; tn++) {
              out_vec[out_col + tn] = T(0.);
            }
          } else {
            for (int tn = 0; tn < TN && (out_col + tn) < out_vec_size; tn++) {
              out_vec[out_col + tn] = T(0.);
            }
          }
        }

        return;
      }

      // Store scalar if multiplicative mask
      if (has_mul_output_mask) {
        out_scale = T(mask_out);
      }
    }

    // Prepare for loop
    constexpr const uniform<int> loop_stride = make_uniform(blockM);
    const uniform<int> in_size = make_uniform(in_vec_size);
    const uniform<int> n_iter = in_size / loop_stride;
    const uniform<int> last_iter = loop_stride * n_iter;
    const uniform<int> leftover = in_size - last_iter;

    // Edge case handling
    if (out_col < out_vec_size) {
      out_col = (out_col + TN) <= out_vec_size ? out_col : out_vec_size - TN;

      // Per thread accumulation main loop
      for (int i = 0; i < n_iter; ++i) {
        // Adding a threadgroup_barrier improves performance slightly;
        // this possibly helps exploit the cache better
        threadgroup_barrier(mem_flags::mem_none);

        if (!has_operand_mask ||
            (bool(mat_mask[mat_mask_offset]) &&
             bool(vec_mask[vec_mask_offset]))) {
          T block_scale{1};
          if (has_mul_operand_mask) {
            block_scale =
                T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
          }

          MLX_MTL_PRAGMA_UNROLL
          for (int tm = 0; tm < TM; tm++) {
            v_coeff[tm] = in_vec[bm + tm];
          }

          // Apply scale
          if (has_mul_operand_mask) {
            MLX_MTL_PRAGMA_UNROLL
            for (int tm = 0; tm < TM; tm++) {
              v_coeff[tm] *= block_scale;
            }
          }

          MLX_MTL_PRAGMA_UNROLL
          for (int tm = 0; tm < TM; tm++) {
            for (int tn = 0; tn < TN; tn++) {
              inter[tn] = mat[(bm + tm) * matrix_ld + out_col + tn];
            }
            for (int tn = 0; tn < TN; tn++) {
              result[tn] += v_coeff[tm] * inter[tn];
            }
          }
        }

        bm += blockM;
        mat_mask_offset += mat_mask_step;
        vec_mask_offset += vec_mask_step;
      }

      if (leftover > 0 &&
          (!has_operand_mask ||
           (bool(mat_mask[mat_mask_offset]) &&
            bool(vec_mask[vec_mask_offset])))) {
        T block_scale{1};
        if (has_mul_operand_mask) {
          block_scale =
              T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
        }

        for (int tm = 0; tm < TM && bm + tm < in_vec_size; tm++) {
          v_coeff[tm] = in_vec[bm + tm];

          if (has_mul_operand_mask) {
            v_coeff[tm] *= block_scale;
          }

          MLX_MTL_PRAGMA_UNROLL
          for (int tn = 0; tn < TN; tn++) {
            inter[tn] = mat[(bm + tm) * matrix_ld + out_col + tn];
          }

          MLX_MTL_PRAGMA_UNROLL
          for (int tn = 0; tn < TN; tn++) {
            result[tn] += v_coeff[tm] * inter[tn];
          }
        }
      }
    }

    // Apply out scale
    if (has_mul_output_mask) {
      MLX_MTL_PRAGMA_UNROLL
      for (int tn = 0; tn < TN; tn++) {
        result[tn] *= out_scale;
      }
    }

    // Simdgroup accumulations
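    // Butterfly reduction down the SM rows of each simdgroup: shuffling by
    // SN * sm pulls the partial sum from the thread sm rows below, so after
    // log2(SM) steps the top row of the simdgroup holds the column sums.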
    MLX_MTL_PRAGMA_UNROLL
    for (int tn = 0; tn < TN; tn++) {
      MLX_MTL_PRAGMA_UNROLL
      for (ushort sm = (SM / 2); sm >= 1; sm >>= 1) {
        result[tn] += simd_shuffle_down(result[tn], SN * sm);
      }
    }

    // Threadgroup accumulation results
    if (needs_tgp_reduction) {
      threadgroup T* tgp_results = tgp_memory + sgM * (blockN + TN) + bn;
      if (thrM == 0) {
        MLX_MTL_PRAGMA_UNROLL
        for (int tn = 0; tn < TN; tn++) {
          tgp_results[tn] = result[tn];
        }

        threadgroup_barrier(mem_flags::mem_none);

        if (sgM == 0) {
          MLX_MTL_PRAGMA_UNROLL
          for (int sgm = 1; sgm < BM; sgm++) {
            MLX_MTL_PRAGMA_UNROLL
            for (int tn = 0; tn < TN; tn++) {
              result[tn] += tgp_results[sgm * (blockN + TN) + tn];
            }
          }
        }
      }
    }

    // Threadgroup accumulation and writing out results
    if (cm == 0 && out_col < out_vec_size) {
      MLX_MTL_PRAGMA_UNROLL
      for (int j = 0; j < TN; j++) {
        out_vec[out_col + j] = result[j];
      }
    }
  }
};

///////////////////////////////////////////////////////////////////////////////
/// Matrix vector multiplication
///////////////////////////////////////////////////////////////////////////////

template <
    typename T,
    typename out_mask_t,
    typename op_mask_t,
    const int BM, /* Threadgroup rows (in simdgroups) */
    const int BN, /* Threadgroup cols (in simdgroups) */
    const int SM, /* Simdgroup rows (in threads) */
    const int SN, /* Simdgroup cols (in threads) */
    const int TM, /* Thread rows (in elements) */
    const int TN, /* Thread cols (in elements) */
    const bool kDoNCBatch> /* Batch ndim > 1 */
[[kernel, max_total_threads_per_threadgroup(BM* BN * 32)]] void gemv_masked(
    const device T* mat [[buffer(0)]],
    const device T* in_vec [[buffer(1)]],
    device T* out_vec [[buffer(3)]],
    const constant int& in_vec_size [[buffer(4)]],
    const constant int& out_vec_size [[buffer(5)]],
    const constant int& matrix_ld [[buffer(6)]],
    const constant int& batch_ndim [[buffer(9)]],
    const constant int* batch_shape [[buffer(10)]],
    const constant size_t* vector_batch_stride [[buffer(11)]],
    const constant size_t* matrix_batch_stride [[buffer(12)]],
    const device out_mask_t* out_mask [[buffer(20)]],
    const device op_mask_t* mat_mask [[buffer(21)]],
    const device op_mask_t* vec_mask [[buffer(22)]],
    const constant int* mask_strides [[buffer(23)]],
    const constant size_t* mask_batch_strides [[buffer(24)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  using gemv_kernel =
      GEMVKernel<T, out_mask_t, op_mask_t, BM, BN, SM, SN, TM, TN>;
  threadgroup T tgp_memory
      [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];

  constexpr bool has_operand_mask = !metal::is_same_v<op_mask_t, nomask_t>;
  constexpr bool has_output_mask = !metal::is_same_v<out_mask_t, nomask_t>;

  // Update batch offsets
  if (kDoNCBatch) {
    in_vec += elem_to_loc(tid.z, batch_shape, vector_batch_stride, batch_ndim);
    mat += elem_to_loc(tid.z, batch_shape, matrix_batch_stride, batch_ndim);

    if (has_output_mask) {
      out_mask +=
          elem_to_loc(tid.z, batch_shape, mask_batch_strides, batch_ndim);
      mask_batch_strides += batch_ndim;
    }

    if (has_operand_mask) {
      const constant size_t* mask_strides_mat = mask_batch_strides;
      const constant size_t* mask_strides_vec = mask_strides_mat + batch_ndim;

      ulong2 batch_offsets = elem_to_loc_broadcast(
          tid.z, batch_shape, mask_strides_mat, mask_strides_vec, batch_ndim);

      mat_mask += batch_offsets.x;
      vec_mask += batch_offsets.y;
    }

  } else {
    in_vec += tid.z * vector_batch_stride[0];
    mat += tid.z * matrix_batch_stride[0];

    if (has_output_mask) {
      out_mask += tid.z * mask_batch_strides[0];
      mask_batch_strides += batch_ndim;
    }

    if (has_operand_mask) {
      mat_mask += tid.z * mask_batch_strides[0];
      vec_mask += tid.z * mask_batch_strides[batch_ndim];
    }
  }

  out_vec += tid.z * out_vec_size;

  gemv_kernel::run(
      mat,
      in_vec,
      out_vec,
      in_vec_size,
      out_vec_size,
      matrix_ld,
      out_mask,
      mat_mask,
      vec_mask,
      mask_strides,
      gemv_kernel::tgp_mem_size == 0 ? nullptr : tgp_memory,
      tid,
      lid,
      simd_gid,
      simd_lid);
}

#include "mlx/backend/metal/kernels/gemv_masked.h"

#define instantiate_gemv_helper( \
    outm_n, outm_t, opm_n, opm_t, name, itype, bm, bn, sm, sn, tm, tn, nc) \
@ -754,7 +36,6 @@ template <
      uint simd_gid [[simdgroup_index_in_threadgroup]], \
      uint simd_lid [[thread_index_in_simdgroup]]);

// clang-format off
#define instantiate_gemv_base(name, itype, bm, bn, sm, sn, tm, tn, nc) \
  instantiate_gemv_helper(bool_, bool, bool_, bool, name, itype, bm, bn, sm, sn, tm, tn, nc) \
  instantiate_gemv_helper(name, itype, name, itype, name, itype, bm, bn, sm, sn, tm, tn, nc) \
@ -763,125 +44,23 @@ template <
  instantiate_gemv_helper(nomask, nomask_t, name, itype, name, itype, bm, bn, sm, sn, tm, tn, nc) \
  instantiate_gemv_helper(nomask, nomask_t, bool_, bool, name, itype, bm, bn, sm, sn, tm, tn, nc) \
  instantiate_gemv_helper(bool_, bool, nomask, nomask_t, name, itype, bm, bn, sm, sn, tm, tn, nc) \
  instantiate_gemv_helper(name, itype, nomask, nomask_t, name, itype, bm, bn, sm, sn, tm, tn, nc) // clang-format on
  instantiate_gemv_helper(name, itype, nomask, nomask_t, name, itype, bm, bn, sm, sn, tm, tn, nc)

// clang-format off
#define instantiate_gemv(name, itype, bm, bn, sm, sn, tm, tn) \
  instantiate_gemv_base(name, itype, bm, bn, sm, sn, tm, tn, 0) \
  instantiate_gemv_base(name, itype, bm, bn, sm, sn, tm, tn, 1) // clang-format on
  instantiate_gemv_base(name, itype, bm, bn, sm, sn, tm, tn, 1)

// clang-format off
#define instantiate_gemv_blocks(name, itype) \
  instantiate_gemv(name, itype, 2, 1, 4, 8, 1, 4) \
  instantiate_gemv(name, itype, 2, 1, 4, 8, 4, 4) \
  instantiate_gemv(name, itype, 2, 1, 2, 16, 1, 4) \
  instantiate_gemv(name, itype, 2, 1, 2, 16, 4, 4) \
  instantiate_gemv(name, itype, 4, 1, 2, 16, 4, 4) // clang-format on
  instantiate_gemv(name, itype, 4, 1, 2, 16, 4, 4)

instantiate_gemv_blocks(float32, float);
instantiate_gemv_blocks(float16, half);
instantiate_gemv_blocks(bfloat16, bfloat16_t);
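// Each instantiate_gemv_blocks line above expands, via instantiate_gemv and
// instantiate_gemv_base, into one host-visible kernel instantiation per tile
// configuration, per mask-type combination, and per batch variant
// (nc = 0 and nc = 1).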

///////////////////////////////////////////////////////////////////////////////
/// Vector matrix multiplication
///////////////////////////////////////////////////////////////////////////////

template <
    typename T,
    typename out_mask_t,
    typename op_mask_t,
    const int BM, /* Threadgroup rows (in simdgroups) */
    const int BN, /* Threadgroup cols (in simdgroups) */
    const int SM, /* Simdgroup rows (in threads) */
    const int SN, /* Simdgroup cols (in threads) */
    const int TM, /* Thread rows (in elements) */
    const int TN, /* Thread cols (in elements) */
    const bool kDoNCBatch> /* Batch ndim > 1 */
[[kernel, max_total_threads_per_threadgroup(BM* BN * 32)]] void gemv_t_masked(
    const device T* mat [[buffer(0)]],
    const device T* in_vec [[buffer(1)]],
    device T* out_vec [[buffer(3)]],
    const constant int& in_vec_size [[buffer(4)]],
    const constant int& out_vec_size [[buffer(5)]],
    const constant int& matrix_ld [[buffer(6)]],
    const constant int& batch_ndim [[buffer(9)]],
    const constant int* batch_shape [[buffer(10)]],
    const constant size_t* vector_batch_stride [[buffer(11)]],
    const constant size_t* matrix_batch_stride [[buffer(12)]],
    const device out_mask_t* out_mask [[buffer(20)]],
    const device op_mask_t* mat_mask [[buffer(21)]],
    const device op_mask_t* vec_mask [[buffer(22)]],
    const constant int* mask_strides [[buffer(23)]],
    const constant size_t* mask_batch_strides [[buffer(24)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  using gemv_kernel =
      GEMVTKernel<T, out_mask_t, op_mask_t, BM, BN, SM, SN, TM, TN>;
  threadgroup T tgp_memory
      [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];

  constexpr bool has_operand_mask = !metal::is_same_v<op_mask_t, nomask_t>;
  constexpr bool has_output_mask = !metal::is_same_v<out_mask_t, nomask_t>;

  // Update batch offsets
  if (kDoNCBatch) {
    in_vec += elem_to_loc(tid.z, batch_shape, vector_batch_stride, batch_ndim);
    mat += elem_to_loc(tid.z, batch_shape, matrix_batch_stride, batch_ndim);

    if (has_output_mask) {
      out_mask +=
          elem_to_loc(tid.z, batch_shape, mask_batch_strides, batch_ndim);
      mask_batch_strides += batch_ndim;
    }

    if (has_operand_mask) {
      const constant size_t* mask_strides_mat = mask_batch_strides;
      const constant size_t* mask_strides_vec = mask_strides_mat + batch_ndim;

      ulong2 batch_offsets = elem_to_loc_broadcast(
          tid.z, batch_shape, mask_strides_mat, mask_strides_vec, batch_ndim);

      mat_mask += batch_offsets.x;
      vec_mask += batch_offsets.y;
    }

  } else {
    in_vec += tid.z * vector_batch_stride[0];
    mat += tid.z * matrix_batch_stride[0];

    if (has_output_mask) {
      out_mask += tid.z * mask_batch_strides[0];
      mask_batch_strides += batch_ndim;
    }

    if (has_operand_mask) {
      mat_mask += tid.z * mask_batch_strides[0];
      vec_mask += tid.z * mask_batch_strides[batch_ndim];
    }
  }

  out_vec += tid.z * out_vec_size;

  gemv_kernel::run(
      mat,
      in_vec,
      out_vec,
      in_vec_size,
      out_vec_size,
      matrix_ld,
      out_mask,
      mat_mask,
      vec_mask,
      mask_strides,
      gemv_kernel::tgp_mem_size == 0 ? nullptr : tgp_memory,
      tid,
      lid,
      simd_gid,
      simd_lid);
}

#define instantiate_gemv_t_helper( \
    outm_n, outm_t, opm_n, opm_t, name, itype, bm, bn, sm, sn, tm, tn, nc) \
  template [[host_name("gemv_t_outmask_" #outm_n "_opmask_" #opm_n "_" #name \
@ -908,7 +87,6 @@ template <
      uint simd_gid [[simdgroup_index_in_threadgroup]], \
      uint simd_lid [[thread_index_in_simdgroup]]);

// clang-format off
#define instantiate_gemv_t_base(name, itype, bm, bn, sm, sn, tm, tn, nc) \
  instantiate_gemv_t_helper(bool_, bool, bool_, bool, name, itype, bm, bn, sm, sn, tm, tn, nc) \
  instantiate_gemv_t_helper(name, itype, name, itype, name, itype, bm, bn, sm, sn, tm, tn, nc) \
@ -917,23 +95,20 @@ template <
  instantiate_gemv_t_helper(nomask, nomask_t, name, itype, name, itype, bm, bn, sm, sn, tm, tn, nc) \
  instantiate_gemv_t_helper(nomask, nomask_t, bool_, bool, name, itype, bm, bn, sm, sn, tm, tn, nc) \
  instantiate_gemv_t_helper(bool_, bool, nomask, nomask_t, name, itype, bm, bn, sm, sn, tm, tn, nc) \
  instantiate_gemv_t_helper(name, itype, nomask, nomask_t, name, itype, bm, bn, sm, sn, tm, tn, nc) // clang-format on
  instantiate_gemv_t_helper(name, itype, nomask, nomask_t, name, itype, bm, bn, sm, sn, tm, tn, nc)

// clang-format off
#define instantiate_gemv_t(name, itype, bm, bn, sm, sn, tm, tn) \
  instantiate_gemv_t_base(name, itype, bm, bn, sm, sn, tm, tn, 0) \
  instantiate_gemv_t_base(name, itype, bm, bn, sm, sn, tm, tn, 1) // clang-format on
  instantiate_gemv_t_base(name, itype, bm, bn, sm, sn, tm, tn, 1)

// clang-format off
#define instantiate_gemv_t_blocks(name, itype) \
  instantiate_gemv_t(name, itype, 1, 1, 8, 4, 4, 1) \
  instantiate_gemv_t(name, itype, 1, 2, 8, 4, 4, 4) \
  instantiate_gemv_t(name, itype, 1, 1, 8, 4, 8, 1) \
  instantiate_gemv_t(name, itype, 1, 1, 8, 4, 8, 4) \
  instantiate_gemv_t(name, itype, 1, 2, 8, 4, 8, 4) \
  instantiate_gemv_t(name, itype, 1, 4, 8, 4, 8, 4) // clang-format on
  instantiate_gemv_t(name, itype, 1, 4, 8, 4, 8, 4)

// clang-format off
instantiate_gemv_t_blocks(float32, float);
instantiate_gemv_t_blocks(float16, half);
instantiate_gemv_t_blocks(bfloat16, bfloat16_t); // clang-format on

@ -11,187 +11,14 @@
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/steel/gemm/params.h"
#include "mlx/backend/metal/matmul.h"
#include "mlx/backend/metal/mps/gemm.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"

namespace mlx::core {

///////////////////////////////////////////////////////////////////////////////
// MPS Matmul fallback
///////////////////////////////////////////////////////////////////////////////

namespace {

bool use_mps() {
  auto get_val = []() {
    if (const char* buff_str = std::getenv("MLX_USE_MPS")) {
      return std::string(buff_str) != "OFF";
    } else {
      return false;
    }
  };
  static bool use_mps_ = get_val();
  return use_mps_;
}
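// Any value of MLX_USE_MPS other than "OFF" enables the fallback (e.g.
// MLX_USE_MPS=1); leaving the variable unset keeps it disabled.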

#define MAX_OPS_PER_BUFFER max_ops_per_buffer()

inline void mps_matmul(
    const Stream& s,
    metal::Device& d,
    const array& a,
    const array& b,
    array& out,
    int M,
    int N,
    int K,
    int batch_size_out,
    int lda,
    int ldb,
    bool transpose_a,
    bool transpose_b,
    std::vector<array>& copies,
    float alpha = 1.0f,
    float beta = 0.0f) {
  MPS::DataType mps_dtype = MPS::DataTypeFloat32;

  if (out.dtype() == float16) {
    mps_dtype = MPS::DataTypeFloat16;
  } else if (out.dtype() == bfloat16) {
    mps_dtype = MPS::DataTypeBFloat16;
  }

  // Use batched MPSMatrixMultiplication if batch_size_out > 1
  // We only accept the following cases:
  //  1. Both a, b have batch_size_out matrices worth of data
  //  2. Only one of a or b has batch_size_out matrices worth of data and
  //     the other has one matrix worth of data

  // The matrix dimensions of a and b are sure to be regularly strided
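  //
  // For case 2, e.g. a holding batch_size_out matrices and b holding a
  // single matrix, the single-matrix operand is broadcast by setting its
  // matrix stride to 0 below.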
  if (batch_size_out > 1) {
    // No broadcasting defaults
    auto batch_size_a = a.data_size() / (M * K);
    auto batch_size_b = b.data_size() / (K * N);

    auto matrix_stride_a = M * K;
    auto matrix_stride_b = K * N;
    auto matrix_stride_out = M * N;

    // At this point, batch_size_a, batch_size_b show the number of matrices
    // in data, no broadcasted strides considered
    if (batch_size_out == std::max(batch_size_a, batch_size_b)) {
      // Handle simple broadcasting
      if (std::min(batch_size_a, batch_size_b) == 1) {
        matrix_stride_a = (batch_size_a == 1) ? 0 : matrix_stride_a;
        matrix_stride_b = (batch_size_b == 1) ? 0 : matrix_stride_b;

        batch_size_a = batch_size_out;
        batch_size_b = batch_size_out;
      }

      // Only proceed if broadcasting between a and b is simple
      // At this point, batch_size_a, batch_size_b show the number of matrices
      // after broadcasting
      if (batch_size_a == batch_size_b) {
        auto a_desc = MPS::MatrixDescriptor::matrixDescriptor(
            (M * K) / lda,
            lda,
            batch_size_a,
            lda * a.itemsize(),
            (matrix_stride_a * a.itemsize()),
            mps_dtype);

        auto b_desc = MPS::MatrixDescriptor::matrixDescriptor(
            (K * N) / ldb,
            ldb,
            batch_size_b,
            ldb * b.itemsize(),
            (matrix_stride_b * b.itemsize()),
            mps_dtype);

        auto out_desc = MPS::MatrixDescriptor::matrixDescriptor(
            M,
            N,
            batch_size_out,
            N * out.itemsize(),
            matrix_stride_out * out.itemsize(),
            mps_dtype);

        auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
        auto a_mat = MPS::Matrix::alloc()->init(a_buf, a_desc);

        auto b_buf = static_cast<const MTL::Buffer*>(b.buffer().ptr());
        auto b_mat = MPS::Matrix::alloc()->init(b_buf, b_desc);

        auto out_buf = static_cast<MTL::Buffer*>(out.buffer().ptr());
        auto out_mat = MPS::Matrix::alloc()->init(out_buf, out_desc);

        auto kernel = MPS::MatrixMultiplication::alloc()->init(
            d.mtl_device(), transpose_a, transpose_b, M, N, K, alpha, beta);

        auto command_buffer = d.get_command_buffer(s.index);
        kernel->setBatchSize(batch_size_out);
        kernel->setBatchStart(0);
        kernel->encodeToCommandBuffer(command_buffer, a_mat, b_mat, out_mat);
        command_buffer->addCompletedHandler(
            [a_mat, b_mat, out_mat, kernel, copies](
                MTL::CommandBuffer*) mutable {
              a_mat->release();
              b_mat->release();
              out_mat->release();
              kernel->release();
              copies.clear();
            });

        return;
      }
    }
  }

  // Schedule as many calls to MPSMatrixMultiplication as needed otherwise
  auto a_desc = MPS::MatrixDescriptor::matrixDescriptor(
      a.data_size() / lda, lda, lda * a.itemsize(), mps_dtype);

  auto b_desc = MPS::MatrixDescriptor::matrixDescriptor(
      b.data_size() / ldb, ldb, ldb * b.itemsize(), mps_dtype);

  auto out_desc = MPS::MatrixDescriptor::matrixDescriptor(
      batch_size_out * M, N, N * out.itemsize(), mps_dtype);

  auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
  auto a_mat = MPS::Matrix::alloc()->init(a_buf, a_desc);

  auto b_buf = static_cast<const MTL::Buffer*>(b.buffer().ptr());
  auto b_mat = MPS::Matrix::alloc()->init(b_buf, b_desc);

  auto out_buf = static_cast<MTL::Buffer*>(out.buffer().ptr());
  auto out_mat = MPS::Matrix::alloc()->init(out_buf, out_desc);

  auto kernel = MPS::MatrixMultiplication::alloc()->init(
      d.mtl_device(), transpose_a, transpose_b, M, N, K, alpha, beta);

  auto command_buffer = d.get_command_buffer(s.index);
  for (int i = 0; i < batch_size_out; ++i) {
    auto a_row = elem_to_loc(M * K * i, a.shape(), a.strides()) / lda;
    auto b_row = elem_to_loc(K * N * i, b.shape(), b.strides()) / ldb;
    kernel->setLeftMatrixOrigin({a_row, 0, 0});
    kernel->setRightMatrixOrigin({b_row, 0, 0});
    kernel->setResultMatrixOrigin({i * static_cast<size_t>(M), 0, 0});
    kernel->encodeToCommandBuffer(command_buffer, a_mat, b_mat, out_mat);
  }

  command_buffer->addCompletedHandler(
      [a_mat, b_mat, out_mat, kernel, copies](MTL::CommandBuffer*) mutable {
        a_mat->release();
        b_mat->release();
        out_mat->release();
        kernel->release();
        copies.clear();
      });
}

inline auto collapse_batches(const array& a, const array& b) {
  // Get and check the shape for the batched dims
  std::vector<int> A_bshape{a.shape().begin(), a.shape().end() - 2};
@ -860,26 +687,6 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
  /////////////////////////////////////////////////////////////////////////////
  // Gemm specialization

  if (use_mps()) {
    d.end_encoding(s.index);

    return mps_matmul(
        s,
        d,
        a,
        b,
        out,
        M,
        N,
        K,
        batch_size_out,
        a_cols,
        b_cols,
        a_transposed,
        b_transposed,
        copies);
  }

  return steel_matmul(
      /* const Stream& s = */ s,
      /* metal::Device& d = */ d,
@ -1529,8 +1336,22 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
  kname << "_nc" << !contiguous_kernel;

  // Encode and dispatch kernel
  auto kernel = get_gemv_masked_kernel(
      d,
      kname.str(),
      out,
      has_out_mask ? std::optional<array>{inputs[2]} : std::nullopt,
      has_op_mask ? std::optional<array>{inputs.back()} : std::nullopt,
      transpose_mat,
      bm,
      bn,
      sm,
      sn,
      tm,
      tn,
      contiguous_kernel);

  auto& compute_encoder = d.get_command_encoder(s.index);
  auto kernel = d.get_kernel(kname.str());
  compute_encoder->setComputePipelineState(kernel);

  int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp;

@ -1,14 +1,6 @@
// Copyright © 2023 Apple Inc.

#include <algorithm>
#include <cassert>
#include <sstream>

#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/mps/gemm.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/utils.h"

namespace mlx::core {

@ -48,4 +40,4 @@ void steel_matmul(
    std::vector<size_t> A_batch_stride = {},
    std::vector<size_t> B_batch_stride = {});

} // namespace mlx::core

@ -1,370 +0,0 @@
// Copyright © 2023 Apple Inc.

#pragma once

#include <Metal/Metal.hpp>

#define _MPS_PRIVATE_CLS(symbol) (MTL::Private::Class::s_k##symbol)
#define _MPS_PRIVATE_SEL(accessor) (MTL::Private::Selector::s_k##accessor)

namespace MTL::Private::Class {
_MTL_PRIVATE_DEF_CLS(MPSMatrixDescriptor);
_MTL_PRIVATE_DEF_CLS(MPSMatrix);
_MTL_PRIVATE_DEF_CLS(MPSVectorDescriptor);
_MTL_PRIVATE_DEF_CLS(MPSVector);
_MTL_PRIVATE_DEF_CLS(MPSKernel);
_MTL_PRIVATE_DEF_CLS(MPSMatrixMultiplication);
_MTL_PRIVATE_DEF_CLS(MPSMatrixVectorMultiplication);
} // namespace MTL::Private::Class

namespace MTL::Private::Selector {
_MTL_PRIVATE_DEF_SEL(
    matrixDescriptorWithRows_columns_rowBytes_dataType,
    "matrixDescriptorWithRows:columns:rowBytes:dataType:");
_MTL_PRIVATE_DEF_SEL(
    matrixDescriptorWithRows_columns_matrices_rowBytes_matrixBytes_dataType,
    "matrixDescriptorWithRows:columns:matrices:rowBytes:matrixBytes:dataType:");
_MTL_PRIVATE_DEF_SEL(rows, "rows");
_MTL_PRIVATE_DEF_SEL(initWithBuffer_descriptor, "initWithBuffer:descriptor:");
_MTL_PRIVATE_DEF_SEL(
    initWithDevice_,
    "initWithDevice:transposeLeft:transposeRight:"
    "resultRows:resultColumns:interiorColumns:alpha:beta:");
_MTL_PRIVATE_DEF_SEL(
    encodeToCommandBuffer_leftMatrix_rightMatrix_resultMatrix,
    "encodeToCommandBuffer:leftMatrix:rightMatrix:resultMatrix:");
_MTL_PRIVATE_DEF_SEL(setLeftMatrixOrigin_, "setLeftMatrixOrigin:");
_MTL_PRIVATE_DEF_SEL(setRightMatrixOrigin_, "setRightMatrixOrigin:");
_MTL_PRIVATE_DEF_SEL(setResultMatrixOrigin_, "setResultMatrixOrigin:");
_MTL_PRIVATE_DEF_SEL(setBatchStart_, "setBatchStart:");
_MTL_PRIVATE_DEF_SEL(setBatchSize_, "setBatchSize:");
_MTL_PRIVATE_DEF_SEL(
    vectorDescriptorWithLength_dataType,
    "vectorDescriptorWithLength:dataType:");
_MTL_PRIVATE_DEF_SEL(
    vectorDescriptorWithLength_vectors_vectorBytes_dataType,
    "vectorDescriptorWithLength:vectors:vectorBytes:dataType:");
_MTL_PRIVATE_DEF_SEL(
    initWithDevice_transpose_rows_columns_alpha_beta,
    "initWithDevice:transpose:rows:columns:alpha:beta:");
_MTL_PRIVATE_DEF_SEL(
    encodeToCommandBuffer_inputMatrix_inputVector_resultVector,
    "encodeToCommandBuffer:inputMatrix:inputVector:resultVector:");
} // namespace MTL::Private::Selector

namespace MPS {

typedef enum DataType : uint32_t {
  DataTypeFloatBit = 0x10000000,
  DataTypeAlternateEncodingBit = 0x80000000,
  DataTypeFloat16 = DataTypeFloatBit | 16,
  DataTypeFloat32 = DataTypeFloatBit | 32,
  DataTypeBFloat16 = DataTypeAlternateEncodingBit | DataTypeFloat16
} DataType;

class MatrixDescriptor : public NS::Copying<MatrixDescriptor> {
 public:
  static class MatrixDescriptor* matrixDescriptor(
      NS::UInteger rows,
      NS::UInteger columns,
      NS::UInteger rowBytes,
      NS::UInteger dataType);
  static class MatrixDescriptor* matrixDescriptor(
      NS::UInteger rows,
      NS::UInteger columns,
      NS::UInteger matrices,
      NS::UInteger rowBytes,
      NS::UInteger matrixBytes,
      NS::UInteger dataType);
  NS::UInteger rows() const;
};

class Matrix : public NS::Referencing<Matrix> {
 public:
  static class Matrix* alloc();
  Matrix* init(MTL::Buffer* buffer, MatrixDescriptor* descriptor);
  Matrix* init(const MTL::Buffer* buffer, MatrixDescriptor* descriptor);
};

class Kernel : public NS::Referencing<Kernel> {
 public:
  NS::String* label() const;
  MTL::Device* device() const;
};

class MatrixMultiplication
    : public NS::Referencing<MatrixMultiplication, Kernel> {
 public:
  static class MatrixMultiplication* alloc();

  MatrixMultiplication* init(
      MTL::Device* device,
      bool transposeLeft,
      bool transposeRight,
      NS::UInteger resultRows,
      NS::UInteger resultColumns,
      NS::UInteger interiorColumns,
      double alpha,
      double beta);

  void encodeToCommandBuffer(
      MTL::CommandBuffer* commandBuffer,
      Matrix* leftMatrix,
      Matrix* rightMatrix,
      Matrix* resultMatrix);

  void setLeftMatrixOrigin(MTL::Origin origin);
  void setRightMatrixOrigin(MTL::Origin origin);
  void setResultMatrixOrigin(MTL::Origin origin);
  void setBatchStart(NS::UInteger batchStart);
  void setBatchSize(NS::UInteger batchSize);
};

class VectorDescriptor : public NS::Copying<VectorDescriptor> {
 public:
  static class VectorDescriptor* vectorDescriptor(
      NS::UInteger length,
      NS::UInteger dataType);
  static class VectorDescriptor* vectorDescriptor(
      NS::UInteger length,
      NS::UInteger vectors,
      NS::UInteger vectorBytes,
      NS::UInteger dataType);
};

class Vector : public NS::Referencing<Vector> {
 public:
  static class Vector* alloc();
  Vector* init(MTL::Buffer* buffer, VectorDescriptor* descriptor);
  Vector* init(const MTL::Buffer* buffer, VectorDescriptor* descriptor);
};

class MatrixVectorMultiplication
    : public NS::Referencing<MatrixVectorMultiplication, Kernel> {
 public:
  static class MatrixVectorMultiplication* alloc();

  MatrixVectorMultiplication* init(
      MTL::Device* device,
      bool transpose,
      NS::UInteger rows,
      NS::UInteger columns,
      double alpha,
      double beta);

  void encodeToCommandBuffer(
      MTL::CommandBuffer* commandBuffer,
      Matrix* inputMatrix,
      Vector* inputVector,
      Vector* resultVector);
};

_MTL_INLINE MatrixDescriptor* MatrixDescriptor::matrixDescriptor(
    NS::UInteger rows,
    NS::UInteger columns,
    NS::UInteger rowBytes,
    NS::UInteger dataType) {
  return Object::sendMessage<MatrixDescriptor*>(
      _MPS_PRIVATE_CLS(MPSMatrixDescriptor),
      _MPS_PRIVATE_SEL(matrixDescriptorWithRows_columns_rowBytes_dataType),
      rows,
      columns,
      rowBytes,
      dataType);
}

_MTL_INLINE MatrixDescriptor* MatrixDescriptor::matrixDescriptor(
    NS::UInteger rows,
    NS::UInteger columns,
    NS::UInteger matrices,
    NS::UInteger rowBytes,
    NS::UInteger matrixBytes,
    NS::UInteger dataType) {
  return Object::sendMessage<MatrixDescriptor*>(
      _MPS_PRIVATE_CLS(MPSMatrixDescriptor),
      _MPS_PRIVATE_SEL(
          matrixDescriptorWithRows_columns_matrices_rowBytes_matrixBytes_dataType),
      rows,
      columns,
      matrices,
      rowBytes,
      matrixBytes,
      dataType);
}

_MTL_INLINE NS::UInteger MatrixDescriptor::rows() const {
  return Object::sendMessage<NS::UInteger>(this, _MPS_PRIVATE_SEL(rows));
}

_MTL_INLINE Matrix* Matrix::alloc() {
  return NS::Object::alloc<Matrix>(_MPS_PRIVATE_CLS(MPSMatrix));
}

_MTL_INLINE Matrix* Matrix::init(
    MTL::Buffer* buffer,
    MatrixDescriptor* descriptor) {
  return Object::sendMessage<Matrix*>(
      this, _MPS_PRIVATE_SEL(initWithBuffer_descriptor), buffer, descriptor);
}

_MTL_INLINE Matrix* Matrix::init(
    const MTL::Buffer* buffer,
    MatrixDescriptor* descriptor) {
  return init(const_cast<MTL::Buffer*>(buffer), descriptor);
}

_MTL_INLINE NS::String* Kernel::label() const {
  return Object::sendMessage<NS::String*>(this, _MPS_PRIVATE_SEL(label));
}

_MTL_INLINE MTL::Device* Kernel::device() const {
  return Object::sendMessage<MTL::Device*>(this, _MPS_PRIVATE_SEL(device));
}

_MTL_INLINE MatrixMultiplication* MatrixMultiplication::alloc() {
  return NS::Object::alloc<MatrixMultiplication>(
      _MPS_PRIVATE_CLS(MPSMatrixMultiplication));
}

_MTL_INLINE MatrixMultiplication* MatrixMultiplication::init(
    MTL::Device* device,
    bool transposeLeft,
    bool transposeRight,
    NS::UInteger resultRows,
    NS::UInteger resultColumns,
    NS::UInteger interiorColumns,
    double alpha,
    double beta) {
  return Object::sendMessage<MatrixMultiplication*>(
      this,
      _MPS_PRIVATE_SEL(initWithDevice_),
      device,
      transposeLeft,
      transposeRight,
      resultRows,
      resultColumns,
      interiorColumns,
      alpha,
      beta);
}

_MTL_INLINE void MatrixMultiplication::encodeToCommandBuffer(
    MTL::CommandBuffer* commandBuffer,
    Matrix* leftMatrix,
    Matrix* rightMatrix,
    Matrix* resultMatrix) {
  return Object::sendMessage<void>(
      this,
      _MPS_PRIVATE_SEL(
          encodeToCommandBuffer_leftMatrix_rightMatrix_resultMatrix),
      commandBuffer,
      leftMatrix,
      rightMatrix,
      resultMatrix);
}

_MTL_INLINE void MatrixMultiplication::setLeftMatrixOrigin(MTL::Origin origin) {
  Object::sendMessage<void>(
      this, _MPS_PRIVATE_SEL(setLeftMatrixOrigin_), origin);
}

_MTL_INLINE void MatrixMultiplication::setRightMatrixOrigin(
    MTL::Origin origin) {
  Object::sendMessage<void>(
      this, _MPS_PRIVATE_SEL(setRightMatrixOrigin_), origin);
}

_MTL_INLINE void MatrixMultiplication::setResultMatrixOrigin(
    MTL::Origin origin) {
  Object::sendMessage<void>(
      this, _MPS_PRIVATE_SEL(setResultMatrixOrigin_), origin);
}

_MTL_INLINE void MatrixMultiplication::setBatchStart(NS::UInteger batchStart) {
  Object::sendMessage<void>(this, _MPS_PRIVATE_SEL(setBatchStart_), batchStart);
}

_MTL_INLINE void MatrixMultiplication::setBatchSize(NS::UInteger batchSize) {
  Object::sendMessage<void>(this, _MPS_PRIVATE_SEL(setBatchSize_), batchSize);
}

_MTL_INLINE VectorDescriptor* VectorDescriptor::vectorDescriptor(
    NS::UInteger length,
    NS::UInteger dataType) {
  return Object::sendMessage<VectorDescriptor*>(
      _MPS_PRIVATE_CLS(MPSVectorDescriptor),
      _MPS_PRIVATE_SEL(vectorDescriptorWithLength_dataType),
      length,
      dataType);
}

_MTL_INLINE VectorDescriptor* VectorDescriptor::vectorDescriptor(
    NS::UInteger length,
    NS::UInteger vectors,
    NS::UInteger vectorBytes,
    NS::UInteger dataType) {
  return Object::sendMessage<VectorDescriptor*>(
      _MPS_PRIVATE_CLS(MPSVectorDescriptor),
      _MPS_PRIVATE_SEL(vectorDescriptorWithLength_vectors_vectorBytes_dataType),
      length,
      vectors,
      vectorBytes,
      dataType);
}

_MTL_INLINE Vector* Vector::alloc() {
  return NS::Object::alloc<Vector>(_MPS_PRIVATE_CLS(MPSVector));
}

_MTL_INLINE Vector* Vector::init(
    MTL::Buffer* buffer,
    VectorDescriptor* descriptor) {
  return Object::sendMessage<Vector*>(
      this, _MPS_PRIVATE_SEL(initWithBuffer_descriptor), buffer, descriptor);
}

_MTL_INLINE Vector* Vector::init(
    const MTL::Buffer* buffer,
    VectorDescriptor* descriptor) {
  return init(const_cast<MTL::Buffer*>(buffer), descriptor);
}

_MTL_INLINE MatrixVectorMultiplication* MatrixVectorMultiplication::alloc() {
  return NS::Object::alloc<MatrixVectorMultiplication>(
      _MPS_PRIVATE_CLS(MPSMatrixVectorMultiplication));
}

_MTL_INLINE MatrixVectorMultiplication* MatrixVectorMultiplication::init(
    MTL::Device* device,
    bool transpose,
    NS::UInteger rows,
    NS::UInteger columns,
    double alpha,
    double beta) {
  return Object::sendMessage<MatrixVectorMultiplication*>(
      this,
      _MPS_PRIVATE_SEL(initWithDevice_transpose_rows_columns_alpha_beta),
      device,
      transpose,
      rows,
      columns,
      alpha,
      beta);
}

_MTL_INLINE void MatrixVectorMultiplication::encodeToCommandBuffer(
    MTL::CommandBuffer* commandBuffer,
    Matrix* inputMatrix,
    Vector* inputVector,
    Vector* resultVector) {
  return Object::sendMessage<void>(
      this,
      _MPS_PRIVATE_SEL(
          encodeToCommandBuffer_inputMatrix_inputVector_resultVector),
      commandBuffer,
      inputMatrix,
      inputVector,
      resultVector);
}

} // namespace MPS

@ -169,6 +169,23 @@ MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
  return d.get_kernel(kernel_name);
}

MTL::ComputePipelineState* get_gemv_masked_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array&,
    const std::optional<array>&,
    const std::optional<array>&,
    bool,
    int,
    int,
    int,
    int,
    int,
    int,
    bool) {
  return d.get_kernel(kernel_name);
}

MTL::ComputePipelineState* get_steel_conv_kernel(
    metal::Device& d,
    const std::string& kernel_name,

@ -32,7 +32,7 @@ void ternary_op_gpu_inplace(
  auto& strides_c = strides[2];
  auto& strides_out = strides[3];

  bool use_2d = out.data_size();
  bool use_2d = out.data_size() > UINT_MAX;
  std::string kernel_name;
  {
    std::ostringstream kname;

116 mlx/backend/metal/utils.cpp Normal file
@ -0,0 +1,116 @@
// Copyright © 2023-2024 Apple Inc.

#include "mlx/backend/metal/utils.h"

using namespace mlx;

namespace mlx::core {

std::string type_to_name(const array& a) {
  std::string tname;
  switch (a.dtype()) {
    case bool_:
      tname = "bool_";
      break;
    case uint8:
      tname = "uint8";
      break;
    case uint16:
      tname = "uint16";
      break;
    case uint32:
      tname = "uint32";
      break;
    case uint64:
      tname = "uint64";
      break;
    case int8:
      tname = "int8";
      break;
    case int16:
      tname = "int16";
      break;
    case int32:
      tname = "int32";
      break;
    case int64:
      tname = "int64";
      break;
    case float16:
      tname = "float16";
      break;
    case float32:
      tname = "float32";
      break;
    case bfloat16:
      tname = "bfloat16";
      break;
    case complex64:
      tname = "complex64";
      break;
  }
  return tname;
}

MTL::Size get_block_dims(int dim0, int dim1, int dim2) {
  int pows[3] = {0, 0, 0};
  int sum = 0;
  while (true) {
    int presum = sum;
    // Check all the pows
    if (dim0 >= (1 << (pows[0] + 1))) {
      pows[0]++;
      sum++;
    }
    if (sum == 10) {
      break;
    }
    if (dim1 >= (1 << (pows[1] + 1))) {
      pows[1]++;
      sum++;
    }
    if (sum == 10) {
      break;
    }
    if (dim2 >= (1 << (pows[2] + 1))) {
      pows[2]++;
      sum++;
    }
    if (sum == presum || sum == 10) {
      break;
    }
  }
  return MTL::Size{1ul << pows[0], 1ul << pows[1], 1ul << pows[2]};
}
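
// For example, get_block_dims(100, 3, 1) returns {64, 2, 1}: each dimension
// is rounded down to a power of two no larger than itself, with the product
// capped at 2^10 = 1024 threads.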

MTL::Size get_2d_grid_dims(
    const std::vector<int>& shape,
    const std::vector<size_t>& strides) {
  // Dims with strides of 0 are ignored as they
  // correspond to broadcasted dimensions
  size_t grid_x = 1;
  size_t grid_y = 1;
  for (int i = 0; i < shape.size(); ++i) {
    if (strides[i] == 0) {
      continue;
    }
    if (grid_x * shape[i] < UINT32_MAX) {
      grid_x *= shape[i];
    } else {
      grid_y *= shape[i];
    }
  }
  if (grid_y > UINT32_MAX || grid_x > UINT32_MAX) {
    throw std::runtime_error("Unable to safely factor shape.");
  }
  return MTL::Size(
      static_cast<uint32_t>(grid_x), static_cast<uint32_t>(grid_y), 1);
}
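
// For example, a contiguous shape {128, 64} yields MTL::Size(8192, 1, 1);
// a dimension only spills into grid_y once multiplying it into grid_x would
// reach UINT32_MAX.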

std::string get_primitive_string(Primitive* primitive) {
  std::ostringstream op_t;
  primitive->print(op_t);
  return op_t.str();
}

} // namespace mlx::core

@ -8,8 +8,6 @@

namespace mlx::core {

namespace {

using metal::CommandEncoder;

template <typename T>
@ -27,82 +25,13 @@ set_vector_bytes(CommandEncoder& enc, const std::vector<T>& vec, int idx) {
  return set_vector_bytes(enc, vec, vec.size(), idx);
}

std::string type_to_name(const array& a) {
  std::string tname;
  switch (a.dtype()) {
    case bool_:
      tname = "bool_";
      break;
    case uint8:
      tname = "uint8";
      break;
    case uint16:
      tname = "uint16";
      break;
    case uint32:
      tname = "uint32";
      break;
    case uint64:
      tname = "uint64";
      break;
    case int8:
      tname = "int8";
      break;
    case int16:
      tname = "int16";
      break;
    case int32:
      tname = "int32";
      break;
    case int64:
      tname = "int64";
      break;
    case float16:
      tname = "float16";
      break;
    case float32:
      tname = "float32";
      break;
    case bfloat16:
      tname = "bfloat16";
      break;
    case complex64:
      tname = "complex64";
      break;
  }
  return tname;
}
std::string type_to_name(const array& a);

MTL::Size get_block_dims(int dim0, int dim1, int dim2) {
  int pows[3] = {0, 0, 0};
  int sum = 0;
  while (true) {
    int presum = sum;
    // Check all the pows
    if (dim0 >= (1 << (pows[0] + 1))) {
      pows[0]++;
      sum++;
    }
    if (sum == 10) {
      break;
    }
    if (dim1 >= (1 << (pows[1] + 1))) {
      pows[1]++;
      sum++;
    }
    if (sum == 10) {
      break;
    }
    if (dim2 >= (1 << (pows[2] + 1))) {
      pows[2]++;
      sum++;
    }
    if (sum == presum || sum == 10) {
      break;
    }
  }
  return MTL::Size{1ul << pows[0], 1ul << pows[1], 1ul << pows[2]};
}
// Compute the thread block dimensions which fit the given
// input dimensions.
// - The thread block dimensions will be powers of two
// - The thread block size will be less than 1024
MTL::Size get_block_dims(int dim0, int dim1, int dim2);

// Computes a 2D grid where each element is < UINT_MAX
// Assumes:
@ -111,27 +40,7 @@ MTL::Size get_block_dims(int dim0, int dim1, int dim2) {
//   possibly broadcasted array
MTL::Size get_2d_grid_dims(
    const std::vector<int>& shape,
    const std::vector<size_t>& strides) {
  // Dims with strides of 0 are ignored as they
  // correspond to broadcasted dimensions
  size_t grid_x = 1;
  size_t grid_y = 1;
  for (int i = 0; i < shape.size(); ++i) {
    if (strides[i] == 0) {
      continue;
    }
    if (grid_x * shape[i] < UINT32_MAX) {
      grid_x *= shape[i];
    } else {
      grid_y *= shape[i];
    }
  }
  if (grid_y > UINT32_MAX || grid_x > UINT32_MAX) {
    throw std::runtime_error("Unable to safely factor shape.");
  }
  return MTL::Size(
      static_cast<uint32_t>(grid_x), static_cast<uint32_t>(grid_y), 1);
}
    const std::vector<size_t>& strides);

inline NS::String* make_string(std::ostringstream& os) {
  std::string string = os.str();
@ -159,12 +68,6 @@ inline void debug_set_primitive_buffer_label(
#endif
}

std::string get_primitive_string(Primitive* primitive) {
  std::ostringstream op_t;
  primitive->print(op_t);
  return op_t.str();
}

} // namespace
std::string get_primitive_string(Primitive* primitive);

} // namespace mlx::core

@ -1,11 +1,8 @@
// Copyright © 2023-2024 Apple Inc.

#include <cstdint>
#include <sstream>
#include <vector>

#include "mlx/dtype.h"
#include "mlx/utils.h"

namespace mlx::core {

@ -178,67 +175,4 @@ bool issubdtype(const Dtype::Category& a, const Dtype::Category& b) {
      [static_cast<uint32_t>(b)];
}

// Array protocol typestring for Dtype
std::string dtype_to_array_protocol(const Dtype& t) {
  std::ostringstream r;
  if (size_of(t) > 1)
    r << (is_big_endian() ? ">" : "<");
  else
    r << "|";
  r << kindof(t) << (int)size_of(t);
  return r.str();
}

// Dtype from array protocol type string
Dtype dtype_from_array_protocol(std::string_view t) {
  if (t.length() == 2 || t.length() == 3) {
    std::string_view r = t.length() == 3 ? t.substr(1, 2) : t;

    if (r == "V2") {
      return bfloat16;
    }

    uint8_t size = r[1] - '0';

    switch (r[0]) {
      case 'b': {
        if (size == 1)
          return bool_;
      }
      case 'i': {
        if (size == 1)
          return int8;
        else if (size == 2)
          return int16;
        else if (size == 4)
          return int32;
        else if (size == 8)
          return int64;
      }
      case 'u': {
        if (size == 1)
          return uint8;
        else if (size == 2)
          return uint16;
        else if (size == 4)
          return uint32;
        else if (size == 8)
          return uint64;
      }
      case 'f': {
        if (size == 2)
          return float16;
        else if (size == 4)
          return float32;
      }
      case 'c': {
        return complex64;
      }
    }
  }

  throw std::invalid_argument(
      "[from_str] Invalid array protocol type-string: " + std::string(t));
}

} // namespace mlx::core

@ -4,8 +4,6 @@

#include <complex>
#include <cstdint>
#include <ostream>
#include <string>

#include "mlx/types/complex.h"
#include "mlx/types/half_types.h"
@ -103,9 +101,4 @@ struct TypeToDtype {
  operator Dtype();
};

// Array protocol typestring for Dtype
std::string dtype_to_array_protocol(const Dtype& t);
// Dtype from array protocol type string
Dtype dtype_from_array_protocol(std::string_view t);

} // namespace mlx::core

@ -26,6 +26,80 @@ constexpr uint8_t MAGIC[] = {
    0x59,
};

inline bool is_big_endian() {
  union ByteOrder {
    int32_t i;
    uint8_t c[4];
  };
  ByteOrder b = {0x01234567};

  return b.c[0] == 0x01;
}
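
// On a little-endian machine (e.g. Apple silicon) c[0] holds the least
// significant byte, 0x67, so this returns false.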

// Array protocol typestring for Dtype
std::string dtype_to_array_protocol(const Dtype& t) {
  std::ostringstream r;
  if (size_of(t) > 1) {
    r << (is_big_endian() ? ">" : "<");
  } else {
    r << "|";
  }
  r << kindof(t) << (int)size_of(t);
  return r.str();
}
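
// For example, float32 maps to "<f4" on a little-endian machine, while the
// single-byte bool_ maps to "|b1" (no byte-order mark).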

// Dtype from array protocol type string
Dtype dtype_from_array_protocol(std::string_view t) {
  if (t.length() == 2 || t.length() == 3) {
    std::string_view r = t.length() == 3 ? t.substr(1, 2) : t;

    if (r == "V2") {
      return bfloat16;
    }

    uint8_t size = r[1] - '0';

    switch (r[0]) {
      case 'b': {
        if (size == 1)
          return bool_;
      }
      case 'i': {
        if (size == 1)
          return int8;
        else if (size == 2)
          return int16;
        else if (size == 4)
          return int32;
        else if (size == 8)
          return int64;
      }
      case 'u': {
        if (size == 1)
          return uint8;
        else if (size == 2)
          return uint16;
        else if (size == 4)
          return uint32;
        else if (size == 8)
          return uint64;
      }
      case 'f': {
        if (size == 2)
          return float16;
        else if (size == 4)
          return float32;
      }
      case 'c': {
        return complex64;
      }
    }
  }

  throw std::invalid_argument(
      "[from_str] Invalid array protocol type-string: " + std::string(t));
}

} // namespace

/** Save array to out stream in .npy format */

10 mlx/utils.h
@ -84,16 +84,6 @@ int check_shape_dim(const T dim) {
  return static_cast<int>(dim);
}

inline bool is_big_endian() {
  union ByteOrder {
    int32_t i;
    uint8_t c[4];
  };
  ByteOrder b = {0x01234567};

  return b.c[0] == 0x01;
}

/**
 * Returns the axis normalized to be in the range [0, ndim).
 * Based on numpy's normalize_axis_index. See

@ -161,6 +161,8 @@ TEST_CASE("test array types") {
  // bfloat16
  { basic_dtype_test(bfloat16_t, bfloat16); }

#undef basic_dtype_test

  // uint32
  {
    uint32_t val = UINT_MAX;
@ -233,31 +235,6 @@ TEST_CASE("test array types") {
    CHECK_EQ(x.dtype(), complex64);
    CHECK_EQ(x.item<complex64_t>(), v);
  }

#undef basic_dtype_test

#define basic_dtype_str_test(s, dtype) \
  CHECK_EQ(s, dtype_to_array_protocol(dtype)); \
  CHECK_EQ(dtype_from_array_protocol(s), dtype);

  // To and from str
  {
    basic_dtype_str_test("|b1", bool_);
    basic_dtype_str_test("|u1", uint8);
    basic_dtype_str_test("<u2", uint16);
    basic_dtype_str_test("<u4", uint32);
    basic_dtype_str_test("<u8", uint64);
    basic_dtype_str_test("|i1", int8);
    basic_dtype_str_test("<i2", int16);
    basic_dtype_str_test("<i4", int32);
    basic_dtype_str_test("<i8", int64);
    basic_dtype_str_test("<f2", float16);
    basic_dtype_str_test("<f4", float32);
    basic_dtype_str_test("<V2", bfloat16);
    basic_dtype_str_test("<c8", complex64);
  }

#undef basic_dtype_str_test
}

TEST_CASE("test array metadata") {