mirror of https://github.com/ml-explore/mlx.git (synced 2025-06-25 01:41:17 +08:00)

commit 30bbea2f08 (parent 635ccd9e25)

Add gemv masked to JIT plus some fixes (#1310)

* add gemv masked to JIT plus some fixes
* some cleanup
* add utils
* fix
* fix 2
* more cleaning
* fix
* remove unused mps matmul support
* one more nit
* revert
@@ -486,9 +486,8 @@ below.
   std::ostringstream kname;
   kname << "axpby_" << "general_" << type_to_name(out);

-  // Make sure the metal library is available and look for it
-  // in the same folder as this executable if needed
-  d.register_library("mlx_ext", metal::get_colocated_mtllib_path);
+  // Make sure the metal library is available
+  d.register_library("mlx_ext");

   // Make a kernel from this metal library
   auto kernel = d.get_kernel(kname.str(), "mlx_ext");
@@ -249,9 +249,8 @@ void Axpby::eval_gpu(
   kname << (contiguous_kernel ? "contiguous_" : "general_");
   kname << type_to_name(out);

-  // Make sure the metal library is available and look for it
-  // in the same folder as this executable if needed
-  d.register_library("mlx_ext", metal::get_colocated_mtllib_path);
+  // Make sure the metal library is available
+  d.register_library("mlx_ext");

   // Make a kernel from this metal library
   auto kernel = d.get_kernel(kname.str(), "mlx_ext");
@@ -114,6 +114,7 @@ if (MLX_METAL_JIT)
     kernels/steel/conv/loaders/loader_general.h
   )
   make_jit_source(quantized)
+  make_jit_source(gemv_masked)
 else()
   target_sources(
     mlx
@@ -149,6 +150,7 @@ target_sources(
     ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/ternary.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/unary.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
   )

   if (NOT MLX_METAL_PATH)
@@ -14,7 +14,6 @@
 #include "mlx/backend/metal/device.h"
 #include "mlx/backend/metal/metal.h"
 #include "mlx/backend/metal/metal_impl.h"
-#include "mlx/backend/metal/mps/gemm.h"
 #include "mlx/backend/metal/utils.h"

 namespace fs = std::filesystem;
@@ -39,6 +38,20 @@ constexpr auto get_metal_version() {
 #endif
 }

+std::string get_colocated_mtllib_path(const std::string& lib_name) {
+  Dl_info info;
+  std::string mtllib_path;
+  std::string lib_ext = lib_name + ".metallib";
+
+  int success = dladdr((void*)get_colocated_mtllib_path, &info);
+  if (success) {
+    auto mtllib = fs::path(info.dli_fname).remove_filename() / lib_ext;
+    mtllib_path = mtllib.c_str();
+  }
+
+  return mtllib_path;
+}
+
 auto load_device() {
   auto devices = MTL::CopyAllDevices();
   auto device = static_cast<MTL::Device*>(devices->object(0))
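The function moved into device.cpp above resolves the metallib that sits next to the binary (or shared library) containing the code itself. A standalone sketch of the same mechanism, runnable outside of MLX:

    #include <dlfcn.h>
    #include <filesystem>
    #include <iostream>
    #include <string>

    // Sketch of the "colocated metallib" lookup that register_library now
    // applies by default: ask dladdr which file contains this function,
    // strip the filename, and append <lib_name>.metallib.
    std::string colocated_mtllib_path(const std::string& lib_name) {
      Dl_info info;
      std::string path;
      if (dladdr((void*)colocated_mtllib_path, &info)) {
        auto p = std::filesystem::path(info.dli_fname).remove_filename() /
            (lib_name + ".metallib");
        path = p.string();
      }
      return path;
    }

    int main() {
      std::cout << colocated_mtllib_path("mlx_ext") << "\n";
    }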
@@ -126,6 +139,49 @@ MTL::Library* load_library(

 } // namespace

+CommandEncoder::CommandEncoder(MTL::CommandBuffer* cbuf) : cbuf(cbuf) {
+  enc = cbuf->computeCommandEncoder(MTL::DispatchTypeConcurrent);
+  enc->retain();
+}
+
+CommandEncoder::~CommandEncoder() {
+  enc->endEncoding();
+  enc->release();
+}
+
+void CommandEncoder::set_input_array(
+    const array& a,
+    int idx,
+    int64_t offset /* = 0 */) {
+  auto r_buf = static_cast<MTL::Resource*>(const_cast<void*>(a.buffer().ptr()));
+  if (auto it = outputs.find(r_buf); it != outputs.end()) {
+    // Insert a barrier
+    enc->memoryBarrier(&r_buf, 1);
+
+    // Remove the output
+    outputs.erase(it);
+  }
+  auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
+  auto base_offset = a.data<char>() -
+      static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
+  base_offset += offset;
+  enc->setBuffer(a_buf, base_offset, idx);
+}
+
+void CommandEncoder::set_output_array(
+    array& a,
+    int idx,
+    int64_t offset /* = 0 */) {
+  // Add barriers before adding the output to the output set
+  set_input_array(a, idx, offset);
+  auto buf = static_cast<MTL::Resource*>(a.buffer().ptr());
+  if (concurrent) {
+    concurrent_outputs.insert(buf);
+  } else {
+    outputs.insert(buf);
+  }
+}
+
 void CommandEncoder::dispatchThreadgroups(
     MTL::Size grid_dims,
     MTL::Size group_dims) {
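The set_input_array/set_output_array pair moved above out of the header implements a small hazard-tracking scheme: buffers written by earlier dispatches are remembered in `outputs`, and the first later read of such a buffer triggers a memory barrier and retires the record. A stripped-down sketch of just that idea, using plain pointers instead of the MLX/Metal types:

    #include <iostream>
    #include <unordered_set>

    // Illustrative hazard tracker: a barrier is needed only the first time a
    // read follows a tracked write of the same buffer.
    struct HazardTracker {
      std::unordered_set<const void*> outputs; // buffers written earlier
      void note_write(const void* buf) {
        outputs.insert(buf);
      }
      bool read_needs_barrier(const void* buf) {
        if (auto it = outputs.find(buf); it != outputs.end()) {
          outputs.erase(it); // one barrier covers the pending write
          return true;
        }
        return false;
      }
    };

    int main() {
      int buf;
      HazardTracker t;
      t.note_write(&buf);
      // Prints "10": the first read needs a barrier, the second does not.
      std::cout << t.read_needs_barrier(&buf) << t.read_needs_barrier(&buf) << "\n";
    }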
@@ -255,13 +311,9 @@ void Device::register_library(
   }
 }

-void Device::register_library(
-    const std::string& lib_name,
-    const std::function<std::string(const std::string&)>& lib_path_func) {
+void Device::register_library(const std::string& lib_name) {
   if (auto it = library_map_.find(lib_name); it == library_map_.end()) {
-    std::string new_lib_path = lib_path_func(lib_name);
-    auto new_lib = load_library(device_, lib_name, new_lib_path.c_str());
-    library_map_.insert({lib_name, new_lib});
+    register_library(lib_name, get_colocated_mtllib_path(lib_name));
   }
 }
@@ -271,7 +323,7 @@ MTL::Library* Device::get_library_cache_(const std::string& lib_name) {
   if (auto it = library_map_.find(lib_name); it != library_map_.end()) {
     mtl_lib = it->second;
   } else { // Look for metallib alongside library
-    register_library(lib_name);
+    register_library(lib_name, get_colocated_mtllib_path(lib_name));
     mtl_lib = library_map_[lib_name];
   }
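Both hunks above lean on the same pattern: library_map_ acts as a cache keyed by library name, and registration only loads when the name is absent. A minimal sketch of that lookup-or-register pattern, with std::string standing in for MTL::Library* and a lambda standing in for load_library over a metallib path:

    #include <functional>
    #include <iostream>
    #include <string>
    #include <unordered_map>

    // Sketch of Device's lookup-or-register pattern (illustrative types only).
    template <typename LibT>
    LibT& get_or_register(
        std::unordered_map<std::string, LibT>& cache,
        const std::string& name,
        const std::function<LibT(const std::string&)>& load) {
      if (auto it = cache.find(name); it == cache.end()) {
        cache.emplace(name, load(name)); // register once; later calls hit the map
      }
      return cache[name];
    }

    int main() {
      std::unordered_map<std::string, std::string> libs;
      auto load = [](const std::string& n) { return "loaded:" + n + ".metallib"; };
      std::cout << get_or_register<std::string>(libs, "mlx_ext", load) << "\n";
    }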
@@ -9,38 +9,16 @@
 #include <unordered_map>
 #include <unordered_set>

-#include <dlfcn.h>
-#include <filesystem>
-
 #include "mlx/array.h"
 #include "mlx/device.h"

-namespace fs = std::filesystem;
-
 namespace mlx::core::metal {

-inline std::string get_colocated_mtllib_path(const std::string& lib_name) {
-  Dl_info info;
-  std::string mtllib_path;
-  std::string lib_ext = lib_name + ".metallib";
-
-  int success = dladdr((void*)get_colocated_mtllib_path, &info);
-  if (success) {
-    auto mtllib = fs::path(info.dli_fname).remove_filename() / lib_ext;
-    mtllib_path = mtllib.c_str();
-  }
-
-  return mtllib_path;
-}
-
 using MTLFCList =
     std::vector<std::tuple<const void*, MTL::DataType, NS::UInteger>>;

 struct CommandEncoder {
-  CommandEncoder(MTL::CommandBuffer* cbuf) : cbuf(cbuf) {
-    enc = cbuf->computeCommandEncoder(MTL::DispatchTypeConcurrent);
-    enc->retain();
-  };
+  CommandEncoder(MTL::CommandBuffer* cbuf);
   CommandEncoder(const CommandEncoder&) = delete;
   CommandEncoder& operator=(const CommandEncoder&) = delete;
@@ -63,34 +41,8 @@ struct CommandEncoder {
     return enc;
   }

-  void set_input_array(const array& a, int idx, int64_t offset = 0) {
-    auto r_buf =
-        static_cast<MTL::Resource*>(const_cast<void*>(a.buffer().ptr()));
-    if (auto it = outputs.find(r_buf); it != outputs.end()) {
-      // Insert a barrier
-      enc->memoryBarrier(&r_buf, 1);
-
-      // Remove the output
-      outputs.erase(it);
-    }
-    auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
-    auto base_offset = a.data<char>() -
-        static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
-    base_offset += offset;
-    enc->setBuffer(a_buf, base_offset, idx);
-  }
-
-  void set_output_array(array& a, int idx, int64_t offset = 0) {
-    // Add barriers before adding the output to the output set
-    set_input_array(a, idx, offset);
-    auto buf = static_cast<MTL::Resource*>(a.buffer().ptr());
-    if (concurrent) {
-      concurrent_outputs.insert(buf);
-    } else {
-      outputs.insert(buf);
-    }
-  }
+  void set_input_array(const array& a, int idx, int64_t offset = 0);
+  void set_output_array(array& a, int idx, int64_t offset = 0);

   void dispatchThreadgroups(MTL::Size grid_dims, MTL::Size group_dims);
   void dispatchThreads(MTL::Size grid_dims, MTL::Size group_dims);
@@ -98,10 +50,7 @@ struct CommandEncoder {
     return ConcurrentContext(*this);
   }

-  ~CommandEncoder() {
-    enc->endEncoding();
-    enc->release();
-  }
+  ~CommandEncoder();

  private:
   void maybe_split();
@@ -136,10 +85,8 @@ class Device {
   void register_library(
       const std::string& lib_name,
       const std::string& lib_path);
-  void register_library(
-      const std::string& lib_name,
-      const std::function<std::string(const std::string&)>& lib_path_func =
-          get_colocated_mtllib_path);
+  void register_library(const std::string& lib_name);

   MTL::Library* get_library(const std::string& name);
@@ -1,4 +1,4 @@
-// Copyright © 2023 Apple Inc.
+// Copyright © 2024 Apple Inc.
 #include <cassert>
 #include <complex>
 #include <map>
@@ -12,8 +12,6 @@
 #include "mlx/backend/metal/slicing.h"
 #include "mlx/backend/metal/unary.h"
 #include "mlx/backend/metal/utils.h"
-#include "mlx/mlx.h"
-#include "mlx/primitives.h"
 #include "mlx/utils.h"

 namespace mlx::core {
@@ -786,10 +784,9 @@ void nd_fft_op(
     fft_op(in_arr, out_arr, axis, inverse, step_real, inplace, s);
   }

-  std::vector<array> copies = {temp1, temp2};
   auto& d = metal::device(s.device);
   d.get_command_buffer(s.index)->addCompletedHandler(
-      [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
+      [temp_arrs](MTL::CommandBuffer*) mutable { temp_arrs.clear(); });
 }

 void FFT::eval_gpu(const std::vector<array>& inputs, array& out) {
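The completed handler above captures the temporaries by value so they stay allocated until the GPU work finishes; the hunk drops a redundant local `copies` vector and captures `temp_arrs` (presumably declared earlier in nd_fft_op, outside this hunk) directly. The same lifetime-extension idiom, sketched with std::function and shared_ptr standing in for MLX arrays and the Metal command buffer:

    #include <functional>
    #include <memory>
    #include <vector>

    using Buffer = std::shared_ptr<std::vector<float>>;

    // temp_arrs is copied into the closure; clearing it on completion drops
    // the last references, mirroring the addCompletedHandler lambda above.
    std::function<void()> retain_until_done(std::vector<Buffer> temp_arrs) {
      return [temp_arrs]() mutable { temp_arrs.clear(); };
    }

    int main() {
      auto handler = retain_until_done({std::make_shared<std::vector<float>>(16)});
      handler(); // runs when the "command buffer" completes
    }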
mlx/backend/metal/jit/gemv_masked.h (new file, 25 lines)
@@ -0,0 +1,25 @@
+// Copyright © 2024 Apple Inc.
+
+constexpr std::string_view gemv_masked_kernel = R"(
+template [[host_name("{name}")]] [[kernel]] void
+gemv_{trans}masked<{itype}, {outm_t}, {opm_t}, {bm}, {bn}, {sm}, {sn}, {tm}, {tn}, {nc}>(
+    const device {itype}* mat [[buffer(0)]],
+    const device {itype}* in_vec [[buffer(1)]],
+    device {itype}* out_vec [[buffer(3)]],
+    const constant int& in_vec_size [[buffer(4)]],
+    const constant int& out_vec_size [[buffer(5)]],
+    const constant int& marix_ld [[buffer(6)]],
+    const constant int& batch_ndim [[buffer(9)]],
+    const constant int* batch_shape [[buffer(10)]],
+    const constant size_t* vector_batch_stride [[buffer(11)]],
+    const constant size_t* matrix_batch_stride [[buffer(12)]],
+    const device {outm_t}* out_mask [[buffer(20)]],
+    const device {opm_t}* mat_mask [[buffer(21)]],
+    const device {opm_t}* vec_mask [[buffer(22)]],
+    const constant int* mask_strides [[buffer(23)]],
+    const constant size_t* mask_batch_strides [[buffer(24)]],
+    uint3 tid [[threadgroup_position_in_grid]],
+    uint3 lid [[thread_position_in_threadgroup]],
+    uint simd_gid [[simdgroup_index_in_threadgroup]],
+    uint simd_lid [[thread_index_in_simdgroup]]);
+)";
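The `{name}`, `{itype}`, `{bm}`, ... placeholders in this template are bound at JIT time with fmt named arguments, which is how get_gemv_masked_kernel later in this diff fills them in. A toy substitution showing the mechanism (the parameter values here are illustrative, not what MLX necessarily selects; it assumes the fmt library is available):

    #include <fmt/format.h>
    #include <iostream>

    int main() {
      using namespace fmt::literals;
      // Same named-argument substitution as the JIT path uses.
      std::cout << fmt::format(
                       "gemv_{trans}masked<{itype}, {bm}, {bn}, {nc}>",
                       "trans"_a = "t_",
                       "itype"_a = "float",
                       "bm"_a = 4,
                       "bn"_a = 4,
                       "nc"_a = "0")
                << "\n"; // -> gemv_t_masked<float, 4, 4, 0>
    }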
@@ -33,5 +33,6 @@ const char* steel_gemm_splitk();
 const char* conv();
 const char* steel_conv();
 const char* steel_conv_general();
+const char* gemv_masked();

 } // namespace mlx::core::metal
@@ -4,6 +4,7 @@
 #include "mlx/backend/common/compiled.h"
 #include "mlx/backend/metal/jit/arange.h"
 #include "mlx/backend/metal/jit/copy.h"
+#include "mlx/backend/metal/jit/gemv_masked.h"
 #include "mlx/backend/metal/jit/includes.h"
 #include "mlx/backend/metal/jit/reduce.h"
 #include "mlx/backend/metal/jit/scan.h"
@@ -50,10 +51,12 @@ MTL::ComputePipelineState* get_unary_kernel(
     std::ostringstream kernel_source;
     auto u_def = get_template_definition(
         "v" + lib_name, "unary_v", get_type_string(out_type), op);
+    auto u2_def = get_template_definition(
+        "v2" + lib_name, "unary_v2", get_type_string(out_type), op);
     auto g_def = get_template_definition(
         "g" + lib_name, "unary_g", get_type_string(out_type), op);
     kernel_source << metal::utils() << metal::unary_ops() << metal::unary()
-                  << u_def << g_def;
+                  << u_def << u2_def << g_def;
     lib = d.get_library(lib_name, kernel_source.str());
   }
   return d.get_kernel(kernel_name, lib);
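get_template_definition itself is not shown in this diff (it lives in the common compiled backend), but it emits the explicit-instantiation strings that these JIT sources are built from. A sketch of the kind of string such a helper plausibly produces, under that assumption:

    #include <iostream>
    #include <sstream>
    #include <string>

    // Hypothetical reimplementation for illustration; MLX's real helper may
    // differ in detail. It names one concrete instantiation per host_name.
    std::string template_definition(
        const std::string& name,
        const std::string& func,
        const std::string& type,
        const std::string& op) {
      std::ostringstream s;
      s << "template [[host_name(\"" << name << "\")]] [[kernel]] decltype("
        << func << "<" << type << ", " << op << ">) " << func << "<" << type
        << ", " << op << ">;\n";
      return s.str();
    }

    int main() {
      std::cout << template_definition("v2abs", "unary_v2", "float", "Abs");
    }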
@@ -70,6 +73,9 @@ void add_binary_kernels(
     {"vs", "binary_vs"},
     {"sv", "binary_sv"},
     {"vv", "binary_vv"},
+    {"vs2", "binary_vs2"},
+    {"sv2", "binary_sv2"},
+    {"vv2", "binary_vv2"},
     {"g1", "binary_g_nd1"},
     {"g2", "binary_g_nd2"},
     {"g3", "binary_g_nd3"},
@@ -146,6 +152,7 @@ MTL::ComputePipelineState* get_ternary_kernel(
     std::ostringstream kernel_source;
     const std::map<std::string, std::string> kernel_types = {
         {"v", "ternary_v"},
+        {"v2", "ternary_v2"},
         {"g", "ternary_g"},
         {"g1", "ternary_g_nd1"},
         {"g2", "ternary_g_nd2"},
@@ -496,6 +503,49 @@ MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
   return d.get_kernel(kernel_name, lib);
 }

+MTL::ComputePipelineState* get_gemv_masked_kernel(
+    metal::Device& d,
+    const std::string& kernel_name,
+    const array& out,
+    const std::optional<array>& mask_out,
+    const std::optional<array>& mask_op,
+    bool transpose_mat,
+    int bm,
+    int bn,
+    int sm,
+    int sn,
+    int tm,
+    int tn,
+    bool contiguous) {
+  const auto& lib_name = kernel_name;
+  auto lib = d.get_library(lib_name);
+  if (lib == nullptr) {
+    std::ostringstream kernel_source;
+    auto out_mask_type = mask_out.has_value()
+        ? get_type_string((*mask_out).dtype())
+        : "nomask_t";
+    auto op_mask_type =
+        mask_op.has_value() ? get_type_string((*mask_op).dtype()) : "nomask_t";
+    kernel_source << metal::utils() << metal::gemv_masked()
+                  << fmt::format(
+                         gemv_masked_kernel,
+                         "name"_a = lib_name,
+                         "itype"_a = get_type_string(out.dtype()),
+                         "outm_t"_a = out_mask_type,
+                         "opm_t"_a = op_mask_type,
+                         "bm"_a = bm,
+                         "bn"_a = bn,
+                         "sm"_a = sm,
+                         "sn"_a = sn,
+                         "tm"_a = tm,
+                         "tn"_a = tn,
+                         "trans"_a = transpose_mat ? "t_" : "",
+                         "nc"_a = contiguous ? "0" : "1");
+    lib = d.get_library(lib_name, kernel_source.str());
+  }
+  return d.get_kernel(kernel_name, lib);
+}
+
 MTL::ComputePipelineState* get_steel_conv_kernel(
     metal::Device& d,
     const std::string& kernel_name,
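Note that get_gemv_masked_kernel uses kernel_name itself as the library cache key, so the caller's name must encode every template parameter to get one fully specialized pipeline per configuration. A hypothetical name builder showing that constraint (the real call site lives in matmul code outside this diff and may format the name differently):

    #include <iostream>
    #include <sstream>
    #include <string>

    // Hypothetical cache-key assembly; illustrative only.
    std::string gemv_masked_name(
        bool transpose, const std::string& itype, const std::string& outm,
        const std::string& opm, int bm, int bn, int sm, int sn, int tm, int tn,
        bool contiguous) {
      std::ostringstream k;
      k << "gemv_" << (transpose ? "t_" : "") << "masked_" << itype
        << "_outmask_" << outm << "_opmask_" << opm << "_bm" << bm << "_bn" << bn
        << "_sm" << sm << "_sn" << sn << "_tm" << tm << "_tn" << tn << "_nc"
        << (contiguous ? 0 : 1);
      return k.str();
    }

    int main() {
      // SM * SN must be 32 to satisfy the kernel's static_assert.
      std::cout << gemv_masked_name(false, "float32", "bool_", "nomask_t",
                                    2, 2, 8, 4, 4, 4, true)
                << "\n";
    }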
@@ -151,6 +151,21 @@ MTL::ComputePipelineState* get_steel_conv_kernel(
     int n_channel_specialization,
     bool small_filter);

+MTL::ComputePipelineState* get_gemv_masked_kernel(
+    metal::Device& d,
+    const std::string& kernel_name,
+    const array& out,
+    const std::optional<array>& mask_out,
+    const std::optional<array>& mask_op,
+    bool transpose_mat,
+    int bm,
+    int bn,
+    int sm,
+    int sn,
+    int tm,
+    int tn,
+    bool contiguous);
+
 MTL::ComputePipelineState* get_steel_conv_general_kernel(
     metal::Device& d,
     const std::string& kernel_name,
@@ -38,7 +38,6 @@ endfunction(build_kernel)
 build_kernel(arg_reduce)
 build_kernel(conv steel/conv/params.h)
 build_kernel(gemv steel/utils.h)
-build_kernel(gemv_masked steel/utils.h)
 build_kernel(layer_norm)
 build_kernel(random)
 build_kernel(rms_norm)
@@ -121,6 +120,7 @@ build_kernel(
   steel/gemm/kernels/steel_gemm_splitk
   ${STEEL_HEADERS}
 )
+build_kernel(gemv_masked steel/utils.h)
 endif()

mlx/backend/metal/kernels/gemv_masked.h (new file, 819 lines)
@@ -0,0 +1,819 @@
+// Copyright © 2023-2024 Apple Inc.
+
+#include "mlx/backend/metal/kernels/steel/utils.h"
+
+using namespace metal;
+
+#define MLX_MTL_CONST static constant constexpr const
+#define MLX_MTL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)")
+
+struct _NoMask {
+  char x;
+
+  constexpr METAL_FUNC operator bool() {
+    return true;
+  }
+  constexpr METAL_FUNC operator bool() const threadgroup {
+    return true;
+  }
+  constexpr METAL_FUNC operator bool() const device {
+    return true;
+  }
+  constexpr METAL_FUNC operator bool() const constant {
+    return true;
+  }
+};
+
+typedef struct _NoMask nomask_t;
+
+template <typename OutT, typename InT = OutT>
+struct ScaleOp {
+  OutT scale;
+
+  METAL_FUNC OutT apply(InT x) const {
+    return static_cast<OutT>(x) * scale;
+  }
+};
+
+template <
+    typename T,
+    typename out_mask_t,
+    typename op_mask_t,
+    const int BM, /* Threadgroup rows (in simdgroups) */
+    const int BN, /* Threadgroup cols (in simdgroups) */
+    const int SM, /* Simdgroup rows (in threads) */
+    const int SN, /* Simdgroup cols (in threads) */
+    const int TM, /* Thread rows (in elements) */
+    const int TN> /* Thread cols (in elements) */
+struct GEMVKernel {
+  MLX_MTL_CONST int threadsM = BM * SM;
+  MLX_MTL_CONST int threadsN = BN * SN;
+
+  MLX_MTL_CONST int blockM = threadsM * TM;
+  MLX_MTL_CONST int blockN = threadsN * TN;
+
+  static_assert(SM * SN == 32, "simdgroup can only have 32 threads");
+
+  static_assert(
+      SN == 8 || SN == 16 || SN == 32,
+      "gemv block must have a width of 8, 16, or 32");
+
+  static_assert(blockN >= blockM, "Masked gemv must have blockN >= blockM");
+
+  MLX_MTL_CONST bool has_operand_mask = !metal::is_same_v<op_mask_t, nomask_t>;
+  MLX_MTL_CONST bool has_output_mask = !metal::is_same_v<out_mask_t, nomask_t>;
+
+  MLX_MTL_CONST bool has_mul_operand_mask =
+      has_operand_mask && !metal::is_same_v<op_mask_t, bool>;
+  MLX_MTL_CONST bool has_mul_output_mask =
+      has_output_mask && !metal::is_same_v<out_mask_t, bool>;
+
+  // - The matrix of size (M = out_vec_size, K = in_vec_size) is divided up
+  //   into blocks of (blockM, blockN) divided among threadgroups
+  // - Every thread works on a block of (TM, TN)
+  // - We assume each threadgroup has (threadsN, threadsM, 1) threads
+  //
+  // 1. A thread loads TN elements each from mat along TM rows
+  //    and the corresponding scalar from the vector
+  // 2. The thread then multiplies and adds to accumulate its local result for
+  //    the block
+  // 3. At the end, each thread has accumulated results over all blocks across
+  //    the rows. These are then summed up across the threadgroup
+  // 4. Each threadgroup writes its accumulated blockM outputs
+  //
+  // Edge case handling:
+  // - The threadgroup with the largest tid has blocks that exceed the matrix
+  //   * The blocks that start outside the matrix are never read (thread results
+  //     remain zero)
+  //   * The last thread that partially overlaps with the matrix is shifted
+  //     inwards such that the thread block fits exactly in the matrix
+
+  MLX_MTL_CONST short tgp_mem_size = BN > 1 ? BN*(blockM + TM) : 0;
+  MLX_MTL_CONST bool needs_tgp_reduction = BN > 1;
+
+  static METAL_FUNC void
+  load_unsafe(const device T* src, thread T dst[TN], const int src_offset = 0) {
+    MLX_MTL_PRAGMA_UNROLL
+    for (int tn = 0; tn < TN; tn++) {
+      dst[tn] = src[src_offset + tn];
+    }
+  }
+
+  static METAL_FUNC void load_safe(
+      const device T* src,
+      thread T dst[TN],
+      const int src_offset = 0,
+      const int src_size = TN) {
+    if (src_offset + TN <= src_size) {
+      MLX_MTL_PRAGMA_UNROLL
+      for (int tn = 0; tn < TN; tn++) {
+        dst[tn] = src[src_offset + tn];
+      }
+    } else { // Edgecase
+      MLX_MTL_PRAGMA_UNROLL
+      for (int tn = 0; tn < TN; tn++) {
+        dst[tn] = src_offset + tn < src_size ? src[src_offset + tn] : 0;
+      }
+    }
+  }
+
+  static METAL_FUNC void run(
+      const device T* mat [[buffer(0)]],
+      const device T* in_vec [[buffer(1)]],
+      device T* out_vec [[buffer(3)]],
+      const constant int& in_vec_size [[buffer(4)]],
+      const constant int& out_vec_size [[buffer(5)]],
+      const constant int& matrix_ld [[buffer(6)]],
+      const device out_mask_t* out_mask [[buffer(20)]],
+      const device op_mask_t* mat_mask [[buffer(21)]],
+      const device op_mask_t* vec_mask [[buffer(22)]],
+      const constant int* mask_strides [[buffer(23)]],
+      threadgroup T* tgp_memory [[threadgroup(0)]],
+      uint3 tid [[threadgroup_position_in_grid]],
+      uint3 lid [[thread_position_in_threadgroup]],
+      uint simd_gid [[simdgroup_index_in_threadgroup]],
+      uint simd_lid [[thread_index_in_simdgroup]]) {
+    // Appease compiler
+    (void)lid;
+
+    // Thread local accumulation results
+    thread T result[TM] = {0};
+    thread T inter[TN];
+    thread T v_coeff[TN];
+
+    const int thrM = SN != 32 ? simd_lid / SN : 0;
+    const int thrN = SN != 32 ? simd_lid % SN : int(simd_lid);
+
+    const int sgN = BN != 1 ? (simd_gid % BN) : 0;
+
+    const int simdM = BN != 1 ? SM * (simd_gid / BN) : int(SM * simd_gid);
+    const int simdN = BN != 1 ? SN * (simd_gid % BN) : 0;
+
+    int bm = (simdM + thrM) * TM;
+    int bn = (simdN + thrN) * TN;
+
+    // Block position
+    int out_row = tid.x * blockM + bm;
+
+    // Exit simdgroup if rows out of bound
+    if (out_row >= out_vec_size)
+      return;
+
+    // Adjust tail simdgroup to ensure in bound reads
+    out_row = out_row + TM <= out_vec_size ? out_row : out_vec_size - TM;
+
+    // Prepare mask offsets
+    const constant int* out_mask_strides = mask_strides;
+    const constant int* mat_mask_strides =
+        mask_strides + (has_output_mask ? 2 : 0);
+    const constant int* vec_mask_strides =
+        mat_mask_strides + (has_operand_mask ? 2 : 0);
+
+    const int m_block_idx = blockN > blockM ? out_row / blockN : int(tid.x);
+
+    const int out_mask_offset =
+        !has_output_mask ? 0 : m_block_idx * out_mask_strides[1];
+
+    int mat_mask_offset =
+        !has_operand_mask ? 0 : m_block_idx * mat_mask_strides[1];
+    int vec_mask_offset = 0;
+    const int mat_mask_step = !has_operand_mask ? 0 : mat_mask_strides[0];
+    const int vec_mask_step = !has_operand_mask ? 0 : vec_mask_strides[1];
+
+    T out_scale{1};
+
+    // Check output mask
+    if (has_output_mask) {
+      auto mask_out = out_mask[out_mask_offset];
+
+      // Write zeros and return if mask is 0
+      if (!mask_out) {
+        if (simdN == 0 && thrN == 0) {
+          MLX_MTL_PRAGMA_UNROLL
+          for (int tm = 0; tm < TM; tm++) {
+            out_vec[out_row + tm] = T(0.);
+          }
+        }
+
+        return;
+      }
+
+      // Store scalar if multiplicative mask
+      if (has_mul_output_mask) {
+        out_scale = T(mask_out);
+      }
+    }
+
+    // Advance matrix
+    mat += out_row * matrix_ld;
+
+    // Prepare for loop
+    constexpr const uniform<int> loop_stride = make_uniform(blockN);
+    const uniform<int> in_size = make_uniform(in_vec_size);
+    const uniform<int> n_iter = in_size / loop_stride;
+    const uniform<int> last_iter = loop_stride * n_iter;
+    const uniform<int> leftover = in_size - last_iter;
+
+    // Loop over in_vec in blocks of blockN
+    for (int i = 0; i < n_iter; ++i) {
+      if (!has_operand_mask ||
+          (bool(mat_mask[mat_mask_offset]) &&
+           bool(vec_mask[vec_mask_offset]))) {
+        T block_scale{1};
+        if (has_mul_operand_mask) {
+          block_scale =
+              T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
+        }
+
+        load_unsafe(in_vec, v_coeff, bn);
+
+        // Apply scale
+        if (has_mul_operand_mask) {
+          MLX_MTL_PRAGMA_UNROLL
+          for (int tn = 0; tn < TN; tn++) {
+            v_coeff[tn] *= block_scale;
+          }
+        }
+
+        // Per thread work loop
+        int mat_offset = 0;
+        MLX_MTL_PRAGMA_UNROLL
+        for (int tm = 0; tm < TM; tm++) {
+          // Load for the row
+          load_unsafe(mat, inter, mat_offset + bn);
+
+          // Accumulate results
+          MLX_MTL_PRAGMA_UNROLL
+          for (int tn = 0; tn < TN; tn++) {
+            result[tm] += inter[tn] * v_coeff[tn];
+          }
+
+          mat_offset += matrix_ld;
+        }
+      }
+
+      bn += blockN;
+      mat_mask_offset += mat_mask_step;
+      vec_mask_offset += vec_mask_step;
+    }
+
+    if (leftover > 0 &&
+        (!has_operand_mask ||
+         (bool(mat_mask[mat_mask_offset]) &&
+          bool(vec_mask[vec_mask_offset])))) {
+      T block_scale{1};
+      if (has_mul_operand_mask) {
+        block_scale =
+            T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
+      }
+
+      load_safe(in_vec, v_coeff, bn, in_size);
+
+      // Apply scale
+      if (has_mul_operand_mask) {
+        MLX_MTL_PRAGMA_UNROLL
+        for (int tn = 0; tn < TN; tn++) {
+          v_coeff[tn] *= block_scale;
+        }
+      }
+
+      // Per thread work loop
+      MLX_MTL_PRAGMA_UNROLL
+      for (int tm = 0; tm < TM; tm++) {
+        // Load for the row
+        load_safe(&mat[tm * matrix_ld], inter, bn, in_size);
+
+        // Accumulate results
+        MLX_MTL_PRAGMA_UNROLL
+        for (int tn = 0; tn < TN; tn++) {
+          result[tm] += inter[tn] * v_coeff[tn];
+        }
+      }
+    }
+
+    // Apply out scale
+    if (has_mul_output_mask) {
+      MLX_MTL_PRAGMA_UNROLL
+      for (int tm = 0; tm < TM; tm++) {
+        result[tm] *= out_scale;
+      }
+    }
+
+    // Simdgroup accumulations
+    MLX_MTL_PRAGMA_UNROLL
+    for (int tm = 0; tm < TM; tm++) {
+      MLX_MTL_PRAGMA_UNROLL
+      for (ushort sn = (SN / 2); sn >= 1; sn >>= 1) {
+        result[tm] += simd_shuffle_down(result[tm], sn);
+      }
+    }
+
+    // Threadgroup accumulation results
+    if (needs_tgp_reduction) {
+      threadgroup T* tgp_results = tgp_memory + sgN * (blockM + TM) + bm;
+      if (thrN == 0) {
+        MLX_MTL_PRAGMA_UNROLL
+        for (int tm = 0; tm < TM; tm++) {
+          tgp_results[tm] = result[tm];
+        }
+
+        threadgroup_barrier(mem_flags::mem_none);
+
+        if (sgN == 0) {
+          MLX_MTL_PRAGMA_UNROLL
+          for (int sgn = 1; sgn < BN; sgn++) {
+            MLX_MTL_PRAGMA_UNROLL
+            for (int tm = 0; tm < TM; tm++) {
+              result[tm] += tgp_results[sgn * (blockM + TM) + tm];
+            }
+          }
+        }
+      }
+    }
+
+    // Write outputs
+    if (simdN == 0 && thrN == 0) {
+      MLX_MTL_PRAGMA_UNROLL
+      for (int tm = 0; tm < TM; tm++) {
+        out_vec[out_row + tm] = result[tm];
+      }
+    }
+  }
+};
+
+///////////////////////////////////////////////////////////////////////////////
+/// Vector matrix multiplication
+///////////////////////////////////////////////////////////////////////////////
+
+template <
+    typename T,
+    typename out_mask_t,
+    typename op_mask_t,
+    const int BM, /* Threadgroup rows (in simdgroups) */
+    const int BN, /* Threadgroup cols (in simdgroups) */
+    const int SM, /* Simdgroup rows (in threads) */
+    const int SN, /* Simdgroup cols (in threads) */
+    const int TM, /* Thread rows (in elements) */
+    const int TN> /* Thread cols (in elements) */
+struct GEMVTKernel {
+  MLX_MTL_CONST int threadsM = BM * SM;
+  MLX_MTL_CONST int threadsN = BN * SN;
+
+  MLX_MTL_CONST int blockM = threadsM * TM;
+  MLX_MTL_CONST int blockN = threadsN * TN;
+
+  static_assert(SM * SN == 32, "simdgroup can only have 32 threads");
+
+  MLX_MTL_CONST bool has_operand_mask = !metal::is_same_v<op_mask_t, nomask_t>;
+  MLX_MTL_CONST bool has_output_mask = !metal::is_same_v<out_mask_t, nomask_t>;
+
+  MLX_MTL_CONST bool has_mul_operand_mask =
+      has_operand_mask && !metal::is_same_v<op_mask_t, bool>;
+  MLX_MTL_CONST bool has_mul_output_mask =
+      has_output_mask && !metal::is_same_v<out_mask_t, bool>;
+
+  // - The matrix of size (M = in_vec_size, N = out_vec_size) is divided up
+  //   into blocks of (blockM, blockN) divided among threadgroups
+  // - Every thread works on a block of (TM, TN)
+  // - We assume each threadgroup has (threadsN, threadsM, 1) threads
+  //
+  // 1. A thread loads TN elements each from mat along TM contiguous rows
+  //    and the corresponding scalar from the vector
+  // 2. The thread then accumulates its local result for the block
+  // 3. At the end, each thread has accumulated results over all blocks across
+  //    the rows. These are then summed up across the threadgroup
+  // 4. Each threadgroup writes its accumulated BN * TN outputs
+  //
+  // Edge case handling:
+  // - The threadgroup with the largest tid has blocks that exceed the matrix
+  //   * The blocks that start outside the matrix are never read (thread results
+  //     remain zero)
+  //   * The last thread that partially overlaps with the matrix is shifted
+  //     inwards such that the thread block fits exactly in the matrix
+
+  MLX_MTL_CONST short tgp_mem_size = BM > 1 ? BM*(blockN + TN) : 0;
+  MLX_MTL_CONST bool needs_tgp_reduction = BM > 1;
+
+  static METAL_FUNC void run(
+      const device T* mat [[buffer(0)]],
+      const device T* in_vec [[buffer(1)]],
+      device T* out_vec [[buffer(3)]],
+      const constant int& in_vec_size [[buffer(4)]],
+      const constant int& out_vec_size [[buffer(5)]],
+      const constant int& marix_ld [[buffer(6)]],
+      const device out_mask_t* out_mask [[buffer(20)]],
+      const device op_mask_t* mat_mask [[buffer(21)]],
+      const device op_mask_t* vec_mask [[buffer(22)]],
+      const constant int* mask_strides [[buffer(23)]],
+      threadgroup T* tgp_memory [[threadgroup(0)]],
+      uint3 tid [[threadgroup_position_in_grid]],
+      uint3 lid [[thread_position_in_threadgroup]],
+      uint simd_gid [[simdgroup_index_in_threadgroup]],
+      uint simd_lid [[thread_index_in_simdgroup]]) {
+    // Appease compiler
+    (void)lid;
+
+    // Thread local accumulation results
+    T result[TN] = {0};
+    T inter[TN];
+    T v_coeff[TM];
+
+    const int thrM = SN != 32 ? simd_lid / SN : 0;
+    const int thrN = SN != 32 ? simd_lid % SN : int(simd_lid);
+
+    const int sgM = BN != 1 ? (simd_gid / BN) : int(simd_gid);
+    const int sgN = BN != 1 ? (simd_gid % BN) : 0;
+
+    const int simdM = SM * sgM;
+    const int simdN = SN * sgN;
+
+    int cm = (simdM + thrM);
+    int cn = (simdN + thrN);
+
+    int bm = cm * TM;
+    int bn = cn * TN;
+
+    int out_col = tid.x * blockN + bn;
+
+    // Prepare mask offsets
+    const constant int* out_mask_strides = mask_strides;
+    const constant int* mat_mask_strides =
+        out_mask_strides + (has_output_mask ? 2 : 0);
+    const constant int* vec_mask_strides =
+        mat_mask_strides + (has_operand_mask ? 2 : 0);
+
+    const int n_block_idx = blockM > blockN ? out_col / blockM : int(tid.x);
+
+    const int out_mask_offset =
+        !has_output_mask ? 0 : n_block_idx; // * out_mask_strides[0];
+
+    int mat_mask_offset =
+        !has_operand_mask ? 0 : n_block_idx * mat_mask_strides[0];
+    int vec_mask_offset = 0;
+    const int mat_mask_step = !has_operand_mask ? 0 : mat_mask_strides[1];
+    const int vec_mask_step = !has_operand_mask ? 0 : vec_mask_strides[0];
+
+    T out_scale{1};
+
+    // Check output mask
+    if (has_output_mask) {
+      auto mask_out = out_mask[out_mask_offset];
+
+      // Write zeros and return if mask is 0
+      if (!mask_out) {
+        if (cm == 0 && out_col < out_vec_size) {
+          if (out_col + TN <= out_vec_size) {
+            MLX_MTL_PRAGMA_UNROLL
+            for (int tn = 0; tn < TN; tn++) {
+              out_vec[out_col + tn] = T(0.);
+            }
+          } else {
+            for (int tn = 0; tn < TN && (out_col + tn) < out_vec_size; tn++) {
+              out_vec[out_col + tn] = T(0.);
+            }
+          }
+        }
+
+        return;
+      }
+
+      // Store scalar if multiplicative mask
+      if (has_mul_output_mask) {
+        out_scale = T(mask_out);
+      }
+    }
+
+    // Prepare for loop
+    constexpr const uniform<int> loop_stride = make_uniform(blockM);
+    const uniform<int> in_size = make_uniform(in_vec_size);
+    const uniform<int> n_iter = in_size / loop_stride;
+    const uniform<int> last_iter = loop_stride * n_iter;
+    const uniform<int> leftover = in_size - last_iter;
+
+    // Edgecase handling
+    if (out_col < out_vec_size) {
+      out_col = (out_col + TN) <= out_vec_size ? out_col : out_vec_size - TN;
+
+      // Per thread accumulation main loop
+      for (int i = 0; i < n_iter; ++i) {
+        // Adding a threadgroup_barrier improves performance slightly
+        // This is possibly it may help exploit cache better
+        threadgroup_barrier(mem_flags::mem_none);
+
+        if (!has_operand_mask ||
+            (bool(mat_mask[mat_mask_offset]) &&
+             bool(vec_mask[vec_mask_offset]))) {
+          T block_scale{1};
+          if (has_mul_operand_mask) {
+            block_scale =
+                T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
+          }
+
+          MLX_MTL_PRAGMA_UNROLL
+          for (int tm = 0; tm < TM; tm++) {
+            v_coeff[tm] = in_vec[bm + tm];
+          }
+
+          // Apply scale
+          if (has_mul_operand_mask) {
+            MLX_MTL_PRAGMA_UNROLL
+            for (int tm = 0; tm < TM; tm++) {
+              v_coeff[tm] *= block_scale;
+            }
+          }
+
+          MLX_MTL_PRAGMA_UNROLL
+          for (int tm = 0; tm < TM; tm++) {
+            for (int tn = 0; tn < TN; tn++) {
+              inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn];
+            }
+            for (int tn = 0; tn < TN; tn++) {
+              result[tn] += v_coeff[tm] * inter[tn];
+            }
+          }
+        }
+
+        bm += blockM;
+        mat_mask_offset += mat_mask_step;
+        vec_mask_offset += vec_mask_step;
+      }
+
+      if (leftover > 0 &&
+          (!has_operand_mask ||
+           (bool(mat_mask[mat_mask_offset]) &&
+            bool(vec_mask[vec_mask_offset])))) {
+        T block_scale{1};
+        if (has_mul_operand_mask) {
+          block_scale =
+              T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
+        }
+
+        for (int tm = 0; tm < TM && bm + tm < in_vec_size; tm++) {
+          v_coeff[tm] = in_vec[bm + tm];
+
+          if (has_mul_operand_mask) {
+            v_coeff[tm] *= block_scale;
+          }
+
+          MLX_MTL_PRAGMA_UNROLL
+          for (int tn = 0; tn < TN; tn++) {
+            inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn];
+          }
+
+          MLX_MTL_PRAGMA_UNROLL
+          for (int tn = 0; tn < TN; tn++) {
+            result[tn] += v_coeff[tm] * inter[tn];
+          }
+        }
+      }
+    }
+
+    // Apply out scale
+    if (has_mul_output_mask) {
+      MLX_MTL_PRAGMA_UNROLL
+      for (int tn = 0; tn < TN; tn++) {
+        result[tn] *= out_scale;
+      }
+    }
+
+    // Simdgroup accumulations
+    MLX_MTL_PRAGMA_UNROLL
+    for (int tn = 0; tn < TN; tn++) {
+      MLX_MTL_PRAGMA_UNROLL
+      for (ushort sm = (SM / 2); sm >= 1; sm >>= 1) {
+        result[tn] += simd_shuffle_down(result[tn], SN * sm);
+      }
+    }
+
+    // Threadgroup accumulation results
+    if (needs_tgp_reduction) {
+      threadgroup T* tgp_results = tgp_memory + sgM * (blockN + TN) + bn;
+      if (thrM == 0) {
+        MLX_MTL_PRAGMA_UNROLL
+        for (int tn = 0; tn < TN; tn++) {
+          tgp_results[tn] = result[tn];
+        }
+
+        threadgroup_barrier(mem_flags::mem_none);
+
+        if (sgM == 0) {
+          MLX_MTL_PRAGMA_UNROLL
+          for (int sgm = 1; sgm < BM; sgm++) {
+            MLX_MTL_PRAGMA_UNROLL
+            for (int tn = 0; tn < TN; tn++) {
+              result[tn] += tgp_results[sgm * (blockN + TN) + tn];
+            }
+          }
+        }
+      }
+    }
+
+    // Threadgroup accumulation and writing out results
+    if (cm == 0 && out_col < out_vec_size) {
+      MLX_MTL_PRAGMA_UNROLL
+      for (int j = 0; j < TN; j++) {
+        out_vec[out_col + j] = result[j];
+      }
+    }
+  }
+};
+
+///////////////////////////////////////////////////////////////////////////////
+/// Matrix vector multiplication
+///////////////////////////////////////////////////////////////////////////////
+
+template <
+    typename T,
+    typename out_mask_t,
+    typename op_mask_t,
+    const int BM, /* Threadgroup rows (in simdgroups) */
+    const int BN, /* Threadgroup cols (in simdgroups) */
+    const int SM, /* Simdgroup rows (in threads) */
+    const int SN, /* Simdgroup cols (in threads) */
+    const int TM, /* Thread rows (in elements) */
+    const int TN, /* Thread cols (in elements) */
+    const bool kDoNCBatch> /* Batch ndim > 1 */
+[[kernel, max_total_threads_per_threadgroup(BM* BN * 32)]] void gemv_masked(
+    const device T* mat [[buffer(0)]],
+    const device T* in_vec [[buffer(1)]],
+    device T* out_vec [[buffer(3)]],
+    const constant int& in_vec_size [[buffer(4)]],
+    const constant int& out_vec_size [[buffer(5)]],
+    const constant int& marix_ld [[buffer(6)]],
+    const constant int& batch_ndim [[buffer(9)]],
+    const constant int* batch_shape [[buffer(10)]],
+    const constant size_t* vector_batch_stride [[buffer(11)]],
+    const constant size_t* matrix_batch_stride [[buffer(12)]],
+    const device out_mask_t* out_mask [[buffer(20)]],
+    const device op_mask_t* mat_mask [[buffer(21)]],
+    const device op_mask_t* vec_mask [[buffer(22)]],
+    const constant int* mask_strides [[buffer(23)]],
+    const constant size_t* mask_batch_strides [[buffer(24)]],
+    uint3 tid [[threadgroup_position_in_grid]],
+    uint3 lid [[thread_position_in_threadgroup]],
+    uint simd_gid [[simdgroup_index_in_threadgroup]],
+    uint simd_lid [[thread_index_in_simdgroup]]) {
+  using gemv_kernel =
+      GEMVKernel<T, out_mask_t, op_mask_t, BM, BN, SM, SN, TM, TN>;
+  threadgroup T tgp_memory
+      [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];
+
+  constexpr bool has_operand_mask = !metal::is_same_v<op_mask_t, nomask_t>;
+  constexpr bool has_output_mask = !metal::is_same_v<out_mask_t, nomask_t>;
+
+  // Update batch offsets
+  if (kDoNCBatch) {
+    in_vec += elem_to_loc(tid.z, batch_shape, vector_batch_stride, batch_ndim);
+    mat += elem_to_loc(tid.z, batch_shape, matrix_batch_stride, batch_ndim);
+
+    if (has_output_mask) {
+      out_mask +=
+          elem_to_loc(tid.z, batch_shape, mask_batch_strides, batch_ndim);
+      mask_batch_strides += batch_ndim;
+    }
+
+    if (has_operand_mask) {
+      const constant size_t* mask_strides_mat = mask_batch_strides;
+      const constant size_t* mask_strides_vec = mask_strides_mat + batch_ndim;
+
+      ulong2 batch_offsets = elem_to_loc_broadcast(
+          tid.z, batch_shape, mask_strides_mat, mask_strides_vec, batch_ndim);
+
+      mat_mask += batch_offsets.x;
+      vec_mask += batch_offsets.y;
+    }
+
+  } else {
+    in_vec += tid.z * vector_batch_stride[0];
+    mat += tid.z * matrix_batch_stride[0];
+
+    if (has_output_mask) {
+      out_mask += tid.z * mask_batch_strides[0];
+      mask_batch_strides += batch_ndim;
+    }
+
+    if (has_operand_mask) {
+      mat_mask += tid.z * mask_batch_strides[0];
+      vec_mask += tid.z * mask_batch_strides[batch_ndim];
+    }
+  }
+
+  out_vec += tid.z * out_vec_size;
+
+  gemv_kernel::run(
+      mat,
+      in_vec,
+      out_vec,
+      in_vec_size,
+      out_vec_size,
+      marix_ld,
+      out_mask,
+      mat_mask,
+      vec_mask,
+      mask_strides,
+      gemv_kernel::tgp_mem_size == 0 ? nullptr : tgp_memory,
+      tid,
+      lid,
+      simd_gid,
+      simd_lid);
+}
+
+///////////////////////////////////////////////////////////////////////////////
+/// Vector matrix multiplication
+///////////////////////////////////////////////////////////////////////////////
+
+template <
+    typename T,
+    typename out_mask_t,
+    typename op_mask_t,
+    const int BM, /* Threadgroup rows (in simdgroups) */
+    const int BN, /* Threadgroup cols (in simdgroups) */
+    const int SM, /* Simdgroup rows (in threads) */
+    const int SN, /* Simdgroup cols (in threads) */
+    const int TM, /* Thread rows (in elements) */
+    const int TN, /* Thread cols (in elements) */
+    const bool kDoNCBatch> /* Batch ndim > 1 */
+[[kernel, max_total_threads_per_threadgroup(BM* BN * 32)]] void gemv_t_masked(
+    const device T* mat [[buffer(0)]],
+    const device T* in_vec [[buffer(1)]],
+    device T* out_vec [[buffer(3)]],
+    const constant int& in_vec_size [[buffer(4)]],
+    const constant int& out_vec_size [[buffer(5)]],
+    const constant int& marix_ld [[buffer(6)]],
+    const constant int& batch_ndim [[buffer(9)]],
+    const constant int* batch_shape [[buffer(10)]],
+    const constant size_t* vector_batch_stride [[buffer(11)]],
+    const constant size_t* matrix_batch_stride [[buffer(12)]],
+    const device out_mask_t* out_mask [[buffer(20)]],
+    const device op_mask_t* mat_mask [[buffer(21)]],
+    const device op_mask_t* vec_mask [[buffer(22)]],
+    const constant int* mask_strides [[buffer(23)]],
+    const constant size_t* mask_batch_strides [[buffer(24)]],
+    uint3 tid [[threadgroup_position_in_grid]],
+    uint3 lid [[thread_position_in_threadgroup]],
+    uint simd_gid [[simdgroup_index_in_threadgroup]],
+    uint simd_lid [[thread_index_in_simdgroup]]) {
+  using gemv_kernel =
+      GEMVTKernel<T, out_mask_t, op_mask_t, BM, BN, SM, SN, TM, TN>;
+  threadgroup T tgp_memory
+      [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];
+
+  constexpr bool has_operand_mask = !metal::is_same_v<op_mask_t, nomask_t>;
+  constexpr bool has_output_mask = !metal::is_same_v<out_mask_t, nomask_t>;
+
+  // Update batch offsets
+  if (kDoNCBatch) {
+    in_vec += elem_to_loc(tid.z, batch_shape, vector_batch_stride, batch_ndim);
+    mat += elem_to_loc(tid.z, batch_shape, matrix_batch_stride, batch_ndim);
+
+    if (has_output_mask) {
+      out_mask +=
+          elem_to_loc(tid.z, batch_shape, mask_batch_strides, batch_ndim);
+      mask_batch_strides += batch_ndim;
+    }
+
+    if (has_operand_mask) {
+      const constant size_t* mask_strides_mat = mask_batch_strides;
+      const constant size_t* mask_strides_vec = mask_strides_mat + batch_ndim;
+
+      ulong2 batch_offsets = elem_to_loc_broadcast(
+          tid.z, batch_shape, mask_strides_mat, mask_strides_vec, batch_ndim);
+
+      mat_mask += batch_offsets.x;
+      vec_mask += batch_offsets.y;
+    }
+
+  } else {
+    in_vec += tid.z * vector_batch_stride[0];
+    mat += tid.z * matrix_batch_stride[0];
+
+    if (has_output_mask) {
+      out_mask += tid.z * mask_batch_strides[0];
+      mask_batch_strides += batch_ndim;
+    }
+
+    if (has_operand_mask) {
+      mat_mask += tid.z * mask_batch_strides[0];
+      vec_mask += tid.z * mask_batch_strides[batch_ndim];
+    }
+  }
+
+  out_vec += tid.z * out_vec_size;
+
+  gemv_kernel::run(
+      mat,
+      in_vec,
+      out_vec,
+      in_vec_size,
+      out_vec_size,
+      marix_ld,
+      out_mask,
+      mat_mask,
+      vec_mask,
+      mask_strides,
+      gemv_kernel::tgp_mem_size == 0 ? nullptr : tgp_memory,
+      tid,
+      lid,
+      simd_gid,
+      simd_lid);
+}
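The tiling in this header implies the host-side dispatch geometry: each threadgroup runs BM * BN simdgroups of 32 threads (hence the max_total_threads_per_threadgroup attribute) and, in the non-transposed case, covers blockM = BM * SM * TM output rows. A sketch of the resulting grid computation with illustrative tile values, not the exact MLX launch code:

    #include <cstdio>

    int main() {
      const int BM = 2, BN = 2, SM = 4, SN = 8, TM = 4, TN = 4;
      static_assert(SM * SN == 32, "simdgroup can only have 32 threads");
      const int out_vec_size = 1000;
      const int blockM = BM * SM * TM; // output rows per threadgroup
      const int n_tgp = (out_vec_size + blockM - 1) / blockM; // ceil-division
      // Prints: threads per tg: 128, threadgroups: 32
      std::printf("threads per tg: %d, threadgroups: %d\n", BM * BN * 32, n_tgp);
    }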
@ -1,5 +1,6 @@
|
|||||||
// Copyright © 2023-2024 Apple Inc.
|
// Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
#include <metal_simdgroup>
|
#include <metal_simdgroup>
|
||||||
#include <metal_stdlib>
|
#include <metal_stdlib>
|
||||||
|
|
||||||
@ -7,726 +8,7 @@
|
|||||||
#include "mlx/backend/metal/kernels/defines.h"
|
#include "mlx/backend/metal/kernels/defines.h"
|
||||||
#include "mlx/backend/metal/kernels/utils.h"
|
#include "mlx/backend/metal/kernels/utils.h"
|
||||||
|
|
||||||
#include "mlx/backend/metal/kernels/steel/utils.h"
|
#include "mlx/backend/metal/kernels/gemv_masked.h"
|
||||||
|
|
||||||
-
-using namespace metal;
-
-///////////////////////////////////////////////////////////////////////////////
-/// Matrix vector multiplication
-///////////////////////////////////////////////////////////////////////////////
-
-#define MLX_MTL_CONST static constant constexpr const
-
-struct _NoMask {
-  char x;
-
-  constexpr METAL_FUNC operator bool() {
-    return true;
-  }
-  constexpr METAL_FUNC operator bool() const threadgroup {
-    return true;
-  }
-  constexpr METAL_FUNC operator bool() const device {
-    return true;
-  }
-  constexpr METAL_FUNC operator bool() const constant {
-    return true;
-  }
-};
-
-typedef struct _NoMask nomask_t;
-
-template <typename OutT, typename InT = OutT>
-struct ScaleOp {
-  OutT scale;
-
-  METAL_FUNC OutT apply(InT x) const {
-    return static_cast<OutT>(x) * scale;
-  }
-};
-
-template <
-    typename T,
-    typename out_mask_t,
-    typename op_mask_t,
-    const int BM, /* Threadgroup rows (in simdgroups) */
-    const int BN, /* Threadgroup cols (in simdgroups) */
-    const int SM, /* Simdgroup rows (in threads) */
-    const int SN, /* Simdgroup cols (in threads) */
-    const int TM, /* Thread rows (in elements) */
-    const int TN> /* Thread cols (in elements) */
-struct GEMVKernel {
-  MLX_MTL_CONST int threadsM = BM * SM;
-  MLX_MTL_CONST int threadsN = BN * SN;
-
-  MLX_MTL_CONST int blockM = threadsM * TM;
-  MLX_MTL_CONST int blockN = threadsN * TN;
-
-  static_assert(SM * SN == 32, "simdgroup can only have 32 threads");
-
-  static_assert(
-      SN == 8 || SN == 16 || SN == 32,
-      "gemv block must have a width of 8, 16, or 32");
-
-  static_assert(blockN >= blockM, "Masked gemv must have blockN >= blockM");
-
-  MLX_MTL_CONST bool has_operand_mask = !metal::is_same_v<op_mask_t, nomask_t>;
-  MLX_MTL_CONST bool has_output_mask = !metal::is_same_v<out_mask_t, nomask_t>;
-
-  MLX_MTL_CONST bool has_mul_operand_mask =
-      has_operand_mask && !metal::is_same_v<op_mask_t, bool>;
-  MLX_MTL_CONST bool has_mul_output_mask =
-      has_output_mask && !metal::is_same_v<out_mask_t, bool>;
-
-  // - The matrix of size (M = out_vec_size, K = in_vec_size) is divided up
-  //   into blocks of (blockM, blockN) divided among threadgroups
-  // - Every thread works on a block of (TM, TN)
-  // - We assume each threadgroup has (threadsN, threadsM, 1) threads
-  //
-  // 1. A thread loads TN elements each from mat along TM rows
-  //    and the corresponding scalar from the vector
-  // 2. The thread then multiplies and adds to accumulate its local result for
-  //    the block
-  // 3. At the end, each thread has accumulated results over all blocks across
-  //    the rows. These are then summed up across the threadgroup
-  // 4. Each threadgroup writes its accumulated blockM outputs
-  //
-  // Edge case handling:
-  // - The threadgroup with the largest tid has blocks that exceed the matrix
-  //   * The blocks that start outside the matrix are never read (thread results
-  //     remain zero)
-  //   * The last thread that partially overlaps with the matrix is shifted
-  //     inwards such that the thread block fits exactly in the matrix
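A concrete sizing may help here; this is an illustrative note, not part of the diff, using the instantiated configuration BM=2, BN=1, SM=4, SN=8, TM=4, TN=4 from the macros further down:

// threadsM = BM * SM = 8        threadsN = BN * SN = 8
// blockM   = threadsM * TM = 32   blockN = threadsN * TN = 32
// Each 64-thread threadgroup (BM * BN * 32) therefore owns 32 output rows
// and walks the reduction dimension in chunks of 32 elements. With BN == 1
// there is a single simdgroup column, so tgp_mem_size below is 0 and no
// threadgroup reduction is needed.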
-  MLX_MTL_CONST short tgp_mem_size = BN > 1 ? BN*(blockM + TM) : 0;
-  MLX_MTL_CONST bool needs_tgp_reduction = BN > 1;
-
-  static METAL_FUNC void
-  load_unsafe(const device T* src, thread T dst[TN], const int src_offset = 0) {
-    MLX_MTL_PRAGMA_UNROLL
-    for (int tn = 0; tn < TN; tn++) {
-      dst[tn] = src[src_offset + tn];
-    }
-  }
-
-  static METAL_FUNC void load_safe(
-      const device T* src,
-      thread T dst[TN],
-      const int src_offset = 0,
-      const int src_size = TN) {
-    if (src_offset + TN <= src_size) {
-      MLX_MTL_PRAGMA_UNROLL
-      for (int tn = 0; tn < TN; tn++) {
-        dst[tn] = src[src_offset + tn];
-      }
-    } else { // Edgecase
-      MLX_MTL_PRAGMA_UNROLL
-      for (int tn = 0; tn < TN; tn++) {
-        dst[tn] = src_offset + tn < src_size ? src[src_offset + tn] : 0;
-      }
-    }
-  }
-
-  static METAL_FUNC void run(
-      const device T* mat [[buffer(0)]],
-      const device T* in_vec [[buffer(1)]],
-      device T* out_vec [[buffer(3)]],
-      const constant int& in_vec_size [[buffer(4)]],
-      const constant int& out_vec_size [[buffer(5)]],
-      const constant int& matrix_ld [[buffer(6)]],
-      const device out_mask_t* out_mask [[buffer(20)]],
-      const device op_mask_t* mat_mask [[buffer(21)]],
-      const device op_mask_t* vec_mask [[buffer(22)]],
-      const constant int* mask_strides [[buffer(23)]],
-      threadgroup T* tgp_memory [[threadgroup(0)]],
-      uint3 tid [[threadgroup_position_in_grid]],
-      uint3 lid [[thread_position_in_threadgroup]],
-      uint simd_gid [[simdgroup_index_in_threadgroup]],
-      uint simd_lid [[thread_index_in_simdgroup]]) {
-    // Appease compiler
-    (void)lid;
-
-    // Thread local accumulation results
-    thread T result[TM] = {0};
-    thread T inter[TN];
-    thread T v_coeff[TN];
-
-    const int thrM = SN != 32 ? simd_lid / SN : 0;
-    const int thrN = SN != 32 ? simd_lid % SN : int(simd_lid);
-
-    const int sgN = BN != 1 ? (simd_gid % BN) : 0;
-
-    const int simdM = BN != 1 ? SM * (simd_gid / BN) : int(SM * simd_gid);
-    const int simdN = BN != 1 ? SN * (simd_gid % BN) : 0;
-
-    int bm = (simdM + thrM) * TM;
-    int bn = (simdN + thrN) * TN;
-
-    // Block position
-    int out_row = tid.x * blockM + bm;
-
-    // Exit simdgroup if rows out of bound
-    if (out_row >= out_vec_size)
-      return;
-
-    // Adjust tail simdgroup to ensure in bound reads
-    out_row = out_row + TM <= out_vec_size ? out_row : out_vec_size - TM;
-
-    // Prepare mask offsets
-    const constant int* out_mask_strides = mask_strides;
-    const constant int* mat_mask_strides =
-        mask_strides + (has_output_mask ? 2 : 0);
-    const constant int* vec_mask_strides =
-        mat_mask_strides + (has_operand_mask ? 2 : 0);
-
-    const int m_block_idx = blockN > blockM ? out_row / blockN : int(tid.x);
-
-    const int out_mask_offset =
-        !has_output_mask ? 0 : m_block_idx * out_mask_strides[1];
-
-    int mat_mask_offset =
-        !has_operand_mask ? 0 : m_block_idx * mat_mask_strides[1];
-    int vec_mask_offset = 0;
-    const int mat_mask_step = !has_operand_mask ? 0 : mat_mask_strides[0];
-    const int vec_mask_step = !has_operand_mask ? 0 : vec_mask_strides[1];
-
-    T out_scale{1};
-
-    // Check output mask
-    if (has_output_mask) {
-      auto mask_out = out_mask[out_mask_offset];
-
-      // Write zeros and return if mask is 0
-      if (!mask_out) {
-        if (simdN == 0 && thrN == 0) {
-          MLX_MTL_PRAGMA_UNROLL
-          for (int tm = 0; tm < TM; tm++) {
-            out_vec[out_row + tm] = T(0.);
-          }
-        }
-
-        return;
-      }
-
-      // Store scalar if multiplicative mask
-      if (has_mul_output_mask) {
-        out_scale = T(mask_out);
-      }
-    }
-
-    // Advance matrix
-    mat += out_row * matrix_ld;
-
-    // Prepare for loop
-    constexpr const uniform<int> loop_stride = make_uniform(blockN);
-    const uniform<int> in_size = make_uniform(in_vec_size);
-    const uniform<int> n_iter = in_size / loop_stride;
-    const uniform<int> last_iter = loop_stride * n_iter;
-    const uniform<int> leftover = in_size - last_iter;
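Worked numbers for the loop split above (an illustrative note, assuming blockN = 32 and in_vec_size = 100):

// n_iter    = 100 / 32 = 3    full blocks with unchecked loads
// last_iter = 32 * 3   = 96   elements covered by the main loop below
// leftover  = 100 - 96 = 4    elements handled afterwards via load_safe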
-
-    // Loop over in_vec in blocks of blockN
-    for (int i = 0; i < n_iter; ++i) {
-      if (!has_operand_mask ||
-          (bool(mat_mask[mat_mask_offset]) &&
-           bool(vec_mask[vec_mask_offset]))) {
-        T block_scale{1};
-        if (has_mul_operand_mask) {
-          block_scale =
-              T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
-        }
-
-        load_unsafe(in_vec, v_coeff, bn);
-
-        // Apply scale
-        if (has_mul_operand_mask) {
-          MLX_MTL_PRAGMA_UNROLL
-          for (int tn = 0; tn < TN; tn++) {
-            v_coeff[tn] *= block_scale;
-          }
-        }
-
-        // Per thread work loop
-        int mat_offset = 0;
-        MLX_MTL_PRAGMA_UNROLL
-        for (int tm = 0; tm < TM; tm++) {
-          // Load for the row
-          load_unsafe(mat, inter, mat_offset + bn);
-
-          // Accumulate results
-          MLX_MTL_PRAGMA_UNROLL
-          for (int tn = 0; tn < TN; tn++) {
-            result[tm] += inter[tn] * v_coeff[tn];
-          }
-
-          mat_offset += matrix_ld;
-        }
-      }
-
-      bn += blockN;
-      mat_mask_offset += mat_mask_step;
-      vec_mask_offset += vec_mask_step;
-    }
-
-    if (leftover > 0 &&
-        (!has_operand_mask ||
-         (bool(mat_mask[mat_mask_offset]) &&
-          bool(vec_mask[vec_mask_offset])))) {
-      T block_scale{1};
-      if (has_mul_operand_mask) {
-        block_scale =
-            T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
-      }
-
-      load_safe(in_vec, v_coeff, bn, in_size);
-
-      // Apply scale
-      if (has_mul_operand_mask) {
-        MLX_MTL_PRAGMA_UNROLL
-        for (int tn = 0; tn < TN; tn++) {
-          v_coeff[tn] *= block_scale;
-        }
-      }
-
-      // Per thread work loop
-      MLX_MTL_PRAGMA_UNROLL
-      for (int tm = 0; tm < TM; tm++) {
-        // Load for the row
-        load_safe(&mat[tm * matrix_ld], inter, bn, in_size);
-
-        // Accumulate results
-        MLX_MTL_PRAGMA_UNROLL
-        for (int tn = 0; tn < TN; tn++) {
-          result[tm] += inter[tn] * v_coeff[tn];
-        }
-      }
-    }
-
-    // Apply out scale
-    if (has_mul_output_mask) {
-      MLX_MTL_PRAGMA_UNROLL
-      for (int tm = 0; tm < TM; tm++) {
-        result[tm] *= out_scale;
-      }
-    }
-
-    // Simdgroup accumulations
-    MLX_MTL_PRAGMA_UNROLL
-    for (int tm = 0; tm < TM; tm++) {
-      MLX_MTL_PRAGMA_UNROLL
-      for (ushort sn = (SN / 2); sn >= 1; sn >>= 1) {
-        result[tm] += simd_shuffle_down(result[tm], sn);
-      }
-    }
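The shuffle loop above is a standard in-register tree reduction; an illustrative trace for SN = 8:

// offsets 4, 2, 1: lane n adds lane n+4, then lane n+2, then lane n+1.
// After the loop, the lane with thrN == 0 in each simdgroup row holds the
// sum of all SN partial results for its output rows; only that lane's
// result is consumed when writing out (or when staging to threadgroup
// memory below).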
-
-    // Threadgroup accumulation results
-    if (needs_tgp_reduction) {
-      threadgroup T* tgp_results = tgp_memory + sgN * (blockM + TM) + bm;
-      if (thrN == 0) {
-        MLX_MTL_PRAGMA_UNROLL
-        for (int tm = 0; tm < TM; tm++) {
-          tgp_results[tm] = result[tm];
-        }
-
-        threadgroup_barrier(mem_flags::mem_none);
-
-        if (sgN == 0) {
-          MLX_MTL_PRAGMA_UNROLL
-          for (int sgn = 1; sgn < BN; sgn++) {
-            MLX_MTL_PRAGMA_UNROLL
-            for (int tm = 0; tm < TM; tm++) {
-              result[tm] += tgp_results[sgn * (blockM + TM) + tm];
-            }
-          }
-        }
-      }
-    }
-
-    // Write outputs
-    if (simdN == 0 && thrN == 0) {
-      MLX_MTL_PRAGMA_UNROLL
-      for (int tm = 0; tm < TM; tm++) {
-        out_vec[out_row + tm] = result[tm];
-      }
-    }
-  }
-};
-
-///////////////////////////////////////////////////////////////////////////////
-/// Vector matrix multiplication
-///////////////////////////////////////////////////////////////////////////////
-
-template <
-    typename T,
-    typename out_mask_t,
-    typename op_mask_t,
-    const int BM, /* Threadgroup rows (in simdgroups) */
-    const int BN, /* Threadgroup cols (in simdgroups) */
-    const int SM, /* Simdgroup rows (in threads) */
-    const int SN, /* Simdgroup cols (in threads) */
-    const int TM, /* Thread rows (in elements) */
-    const int TN> /* Thread cols (in elements) */
-struct GEMVTKernel {
-  MLX_MTL_CONST int threadsM = BM * SM;
-  MLX_MTL_CONST int threadsN = BN * SN;
-
-  MLX_MTL_CONST int blockM = threadsM * TM;
-  MLX_MTL_CONST int blockN = threadsN * TN;
-
-  static_assert(SM * SN == 32, "simdgroup can only have 32 threads");
-
-  MLX_MTL_CONST bool has_operand_mask = !metal::is_same_v<op_mask_t, nomask_t>;
-  MLX_MTL_CONST bool has_output_mask = !metal::is_same_v<out_mask_t, nomask_t>;
-
-  MLX_MTL_CONST bool has_mul_operand_mask =
-      has_operand_mask && !metal::is_same_v<op_mask_t, bool>;
-  MLX_MTL_CONST bool has_mul_output_mask =
-      has_output_mask && !metal::is_same_v<out_mask_t, bool>;
-
-  // - The matrix of size (M = in_vec_size, N = out_vec_size) is divided up
-  //   into blocks of (blockM, blockN) divided among threadgroups
-  // - Every thread works on a block of (TM, TN)
-  // - We assume each threadgroup has (threadsN, threadsM, 1) threads
-  //
-  // 1. A thread loads TN elements each from mat along TM contiguous rows
-  //    and the corresponding scalar from the vector
-  // 2. The thread then accumulates its local result for the block
-  // 3. At the end, each thread has accumulated results over all blocks across
-  //    the rows. These are then summed up across the threadgroup
-  // 4. Each threadgroup writes its accumulated BN * TN outputs
-  //
-  // Edge case handling:
-  // - The threadgroup with the largest tid has blocks that exceed the matrix
-  //   * The blocks that start outside the matrix are never read (thread results
-  //     remain zero)
-  //   * The last thread that partially overlaps with the matrix is shifted
-  //     inwards such that the thread block fits exactly in the matrix
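For comparison with GEMVKernel, an illustrative sizing (not part of the diff) using the instantiated transposed configuration BM=1, BN=2, SM=8, SN=4, TM=4, TN=4:

// threadsM = 8, threadsN = 8, blockM = 32, blockN = 32.
// Each 64-thread threadgroup owns 32 output columns and walks down the
// matrix in chunks of blockM = 32 rows; since BM == 1 in this config,
// needs_tgp_reduction is false and tgp_mem_size below is 0.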
-  MLX_MTL_CONST short tgp_mem_size = BM > 1 ? BM*(blockN + TN) : 0;
-  MLX_MTL_CONST bool needs_tgp_reduction = BM > 1;
-
-  static METAL_FUNC void run(
-      const device T* mat [[buffer(0)]],
-      const device T* in_vec [[buffer(1)]],
-      device T* out_vec [[buffer(3)]],
-      const constant int& in_vec_size [[buffer(4)]],
-      const constant int& out_vec_size [[buffer(5)]],
-      const constant int& marix_ld [[buffer(6)]],
-      const device out_mask_t* out_mask [[buffer(20)]],
-      const device op_mask_t* mat_mask [[buffer(21)]],
-      const device op_mask_t* vec_mask [[buffer(22)]],
-      const constant int* mask_strides [[buffer(23)]],
-      threadgroup T* tgp_memory [[threadgroup(0)]],
-      uint3 tid [[threadgroup_position_in_grid]],
-      uint3 lid [[thread_position_in_threadgroup]],
-      uint simd_gid [[simdgroup_index_in_threadgroup]],
-      uint simd_lid [[thread_index_in_simdgroup]]) {
-    // Appease compiler
-    (void)lid;
-
-    // Thread local accumulation results
-    T result[TN] = {0};
-    T inter[TN];
-    T v_coeff[TM];
-
-    const int thrM = SN != 32 ? simd_lid / SN : 0;
-    const int thrN = SN != 32 ? simd_lid % SN : int(simd_lid);
-
-    const int sgM = BN != 1 ? (simd_gid / BN) : int(simd_gid);
-    const int sgN = BN != 1 ? (simd_gid % BN) : 0;
-
-    const int simdM = SM * sgM;
-    const int simdN = SN * sgN;
-
-    int cm = (simdM + thrM);
-    int cn = (simdN + thrN);
-
-    int bm = cm * TM;
-    int bn = cn * TN;
-
-    int out_col = tid.x * blockN + bn;
-
-    // Prepare mask offsets
-    const constant int* out_mask_strides = mask_strides;
-    const constant int* mat_mask_strides =
-        out_mask_strides + (has_output_mask ? 2 : 0);
-    const constant int* vec_mask_strides =
-        mat_mask_strides + (has_operand_mask ? 2 : 0);
-
-    const int n_block_idx = blockM > blockN ? out_col / blockM : int(tid.x);
-
-    const int out_mask_offset =
-        !has_output_mask ? 0 : n_block_idx; // * out_mask_strides[0];
-
-    int mat_mask_offset =
-        !has_operand_mask ? 0 : n_block_idx * mat_mask_strides[0];
-    int vec_mask_offset = 0;
-    const int mat_mask_step = !has_operand_mask ? 0 : mat_mask_strides[1];
-    const int vec_mask_step = !has_operand_mask ? 0 : vec_mask_strides[0];
-
-    T out_scale{1};
-
-    // Check output mask
-    if (has_output_mask) {
-      auto mask_out = out_mask[out_mask_offset];
-
-      // Write zeros and return if mask is 0
-      if (!mask_out) {
-        if (cm == 0 && out_col < out_vec_size) {
-          if (out_col + TN <= out_vec_size) {
-            MLX_MTL_PRAGMA_UNROLL
-            for (int tn = 0; tn < TN; tn++) {
-              out_vec[out_col + tn] = T(0.);
-            }
-          } else {
-            for (int tn = 0; tn < TN && (out_col + tn) < out_vec_size; tn++) {
-              out_vec[out_col + tn] = T(0.);
-            }
-          }
-        }
-
-        return;
-      }
-
-      // Store scalar if multiplicative mask
-      if (has_mul_output_mask) {
-        out_scale = T(mask_out);
-      }
-    }
-
-    // Prepare for loop
-    constexpr const uniform<int> loop_stride = make_uniform(blockM);
-    const uniform<int> in_size = make_uniform(in_vec_size);
-    const uniform<int> n_iter = in_size / loop_stride;
-    const uniform<int> last_iter = loop_stride * n_iter;
-    const uniform<int> leftover = in_size - last_iter;
-
-    // Edgecase handling
-    if (out_col < out_vec_size) {
-      out_col = (out_col + TN) <= out_vec_size ? out_col : out_vec_size - TN;
-
-      // Per thread accumulation main loop
-      for (int i = 0; i < n_iter; ++i) {
-        // Adding a threadgroup_barrier improves performance slightly
-        // This is possibly because it may help exploit the cache better
-        threadgroup_barrier(mem_flags::mem_none);
-
-        if (!has_operand_mask ||
-            (bool(mat_mask[mat_mask_offset]) &&
-             bool(vec_mask[vec_mask_offset]))) {
-          T block_scale{1};
-          if (has_mul_operand_mask) {
-            block_scale =
-                T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
-          }
-
-          MLX_MTL_PRAGMA_UNROLL
-          for (int tm = 0; tm < TM; tm++) {
-            v_coeff[tm] = in_vec[bm + tm];
-          }
-
-          // Apply scale
-          if (has_mul_operand_mask) {
-            MLX_MTL_PRAGMA_UNROLL
-            for (int tm = 0; tm < TM; tm++) {
-              v_coeff[tm] *= block_scale;
-            }
-          }
-
-          MLX_MTL_PRAGMA_UNROLL
-          for (int tm = 0; tm < TM; tm++) {
-            for (int tn = 0; tn < TN; tn++) {
-              inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn];
-            }
-            for (int tn = 0; tn < TN; tn++) {
-              result[tn] += v_coeff[tm] * inter[tn];
-            }
-          }
-        }
-
-        bm += blockM;
-        mat_mask_offset += mat_mask_step;
-        vec_mask_offset += vec_mask_step;
-      }
-
-      if (leftover > 0 &&
-          (!has_operand_mask ||
-           (bool(mat_mask[mat_mask_offset]) &&
-            bool(vec_mask[vec_mask_offset])))) {
-        T block_scale{1};
-        if (has_mul_operand_mask) {
-          block_scale =
-              T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
-        }
-
-        for (int tm = 0; tm < TM && bm + tm < in_vec_size; tm++) {
-          v_coeff[tm] = in_vec[bm + tm];
-
-          if (has_mul_operand_mask) {
-            v_coeff[tm] *= block_scale;
-          }
-
-          MLX_MTL_PRAGMA_UNROLL
-          for (int tn = 0; tn < TN; tn++) {
-            inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn];
-          }
-
-          MLX_MTL_PRAGMA_UNROLL
-          for (int tn = 0; tn < TN; tn++) {
-            result[tn] += v_coeff[tm] * inter[tn];
-          }
-        }
-      }
-    }
-
-    // Apply out scale
-    if (has_mul_output_mask) {
-      MLX_MTL_PRAGMA_UNROLL
-      for (int tn = 0; tn < TN; tn++) {
-        result[tn] *= out_scale;
-      }
-    }
-
-    // Simdgroup accumulations
-    MLX_MTL_PRAGMA_UNROLL
-    for (int tn = 0; tn < TN; tn++) {
-      MLX_MTL_PRAGMA_UNROLL
-      for (ushort sm = (SM / 2); sm >= 1; sm >>= 1) {
-        result[tn] += simd_shuffle_down(result[tn], SN * sm);
-      }
-    }
-
-    // Threadgroup accumulation results
-    if (needs_tgp_reduction) {
-      threadgroup T* tgp_results = tgp_memory + sgM * (blockN + TN) + bn;
-      if (thrM == 0) {
-        MLX_MTL_PRAGMA_UNROLL
-        for (int tn = 0; tn < TN; tn++) {
-          tgp_results[tn] = result[tn];
-        }
-
-        threadgroup_barrier(mem_flags::mem_none);
-
-        if (sgM == 0) {
-          MLX_MTL_PRAGMA_UNROLL
-          for (int sgm = 1; sgm < BM; sgm++) {
-            MLX_MTL_PRAGMA_UNROLL
-            for (int tn = 0; tn < TN; tn++) {
-              result[tn] += tgp_results[sgm * (blockN + TN) + tn];
-            }
-          }
-        }
-      }
-    }
-
-    // Threadgroup accumulation and writing out results
-    if (cm == 0 && out_col < out_vec_size) {
-      MLX_MTL_PRAGMA_UNROLL
-      for (int j = 0; j < TN; j++) {
-        out_vec[out_col + j] = result[j];
-      }
-    }
-  }
-};
-
-///////////////////////////////////////////////////////////////////////////////
-/// Matrix vector multiplication
-///////////////////////////////////////////////////////////////////////////////
-
-template <
-    typename T,
-    typename out_mask_t,
-    typename op_mask_t,
-    const int BM, /* Threadgroup rows (in simdgroups) */
-    const int BN, /* Threadgroup cols (in simdgroups) */
-    const int SM, /* Simdgroup rows (in threads) */
-    const int SN, /* Simdgroup cols (in threads) */
-    const int TM, /* Thread rows (in elements) */
-    const int TN, /* Thread cols (in elements) */
-    const bool kDoNCBatch> /* Batch ndim > 1 */
-[[kernel, max_total_threads_per_threadgroup(BM* BN * 32)]] void gemv_masked(
-    const device T* mat [[buffer(0)]],
-    const device T* in_vec [[buffer(1)]],
-    device T* out_vec [[buffer(3)]],
-    const constant int& in_vec_size [[buffer(4)]],
-    const constant int& out_vec_size [[buffer(5)]],
-    const constant int& marix_ld [[buffer(6)]],
-    const constant int& batch_ndim [[buffer(9)]],
-    const constant int* batch_shape [[buffer(10)]],
-    const constant size_t* vector_batch_stride [[buffer(11)]],
-    const constant size_t* matrix_batch_stride [[buffer(12)]],
-    const device out_mask_t* out_mask [[buffer(20)]],
-    const device op_mask_t* mat_mask [[buffer(21)]],
-    const device op_mask_t* vec_mask [[buffer(22)]],
-    const constant int* mask_strides [[buffer(23)]],
-    const constant size_t* mask_batch_strides [[buffer(24)]],
-    uint3 tid [[threadgroup_position_in_grid]],
-    uint3 lid [[thread_position_in_threadgroup]],
-    uint simd_gid [[simdgroup_index_in_threadgroup]],
-    uint simd_lid [[thread_index_in_simdgroup]]) {
-  using gemv_kernel =
-      GEMVKernel<T, out_mask_t, op_mask_t, BM, BN, SM, SN, TM, TN>;
-  threadgroup T tgp_memory
-      [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];
-
-  constexpr bool has_operand_mask = !metal::is_same_v<op_mask_t, nomask_t>;
-  constexpr bool has_output_mask = !metal::is_same_v<out_mask_t, nomask_t>;
-
-  // Update batch offsets
-  if (kDoNCBatch) {
-    in_vec += elem_to_loc(tid.z, batch_shape, vector_batch_stride, batch_ndim);
-    mat += elem_to_loc(tid.z, batch_shape, matrix_batch_stride, batch_ndim);
-
-    if (has_output_mask) {
-      out_mask +=
-          elem_to_loc(tid.z, batch_shape, mask_batch_strides, batch_ndim);
-      mask_batch_strides += batch_ndim;
-    }
-
-    if (has_operand_mask) {
-      const constant size_t* mask_strides_mat = mask_batch_strides;
-      const constant size_t* mask_strides_vec = mask_strides_mat + batch_ndim;
-
-      ulong2 batch_offsets = elem_to_loc_broadcast(
-          tid.z, batch_shape, mask_strides_mat, mask_strides_vec, batch_ndim);
-
-      mat_mask += batch_offsets.x;
-      vec_mask += batch_offsets.y;
-    }
-
-  } else {
-    in_vec += tid.z * vector_batch_stride[0];
-    mat += tid.z * matrix_batch_stride[0];
-
-    if (has_output_mask) {
-      out_mask += tid.z * mask_batch_strides[0];
-      mask_batch_strides += batch_ndim;
-    }
-
-    if (has_operand_mask) {
-      mat_mask += tid.z * mask_batch_strides[0];
-      vec_mask += tid.z * mask_batch_strides[batch_ndim];
-    }
-  }
-
-  out_vec += tid.z * out_vec_size;
-
-  gemv_kernel::run(
-      mat,
-      in_vec,
-      out_vec,
-      in_vec_size,
-      out_vec_size,
-      marix_ld,
-      out_mask,
-      mat_mask,
-      vec_mask,
-      mask_strides,
-      gemv_kernel::tgp_mem_size == 0 ? nullptr : tgp_memory,
-      tid,
-      lid,
-      simd_gid,
-      simd_lid);
-}
-
 #define instantiate_gemv_helper( \
     outm_n, outm_t, opm_n, opm_t, name, itype, bm, bn, sm, sn, tm, tn, nc) \
@@ -754,7 +36,6 @@ template <
       uint simd_gid [[simdgroup_index_in_threadgroup]], \
       uint simd_lid [[thread_index_in_simdgroup]]);
 
-// clang-format off
 #define instantiate_gemv_base(name, itype, bm, bn, sm, sn, tm, tn, nc) \
   instantiate_gemv_helper(bool_, bool, bool_, bool, name, itype, bm, bn, sm, sn, tm, tn, nc) \
   instantiate_gemv_helper(name, itype, name, itype, name, itype, bm, bn, sm, sn, tm, tn, nc) \
@@ -763,125 +44,23 @@ template <
   instantiate_gemv_helper(nomask, nomask_t, name, itype, name, itype, bm, bn, sm, sn, tm, tn, nc) \
   instantiate_gemv_helper(nomask, nomask_t, bool_, bool, name, itype, bm, bn, sm, sn, tm, tn, nc) \
   instantiate_gemv_helper(bool_, bool, nomask, nomask_t, name, itype, bm, bn, sm, sn, tm, tn, nc) \
-  instantiate_gemv_helper(name, itype, nomask, nomask_t, name, itype, bm, bn, sm, sn, tm, tn, nc) // clang-format on
+  instantiate_gemv_helper(name, itype, nomask, nomask_t, name, itype, bm, bn, sm, sn, tm, tn, nc)
 
-// clang-format off
 #define instantiate_gemv(name, itype, bm, bn, sm, sn, tm, tn) \
   instantiate_gemv_base(name, itype, bm, bn, sm, sn, tm, tn, 0) \
-  instantiate_gemv_base(name, itype, bm, bn, sm, sn, tm, tn, 1) // clang-format on
+  instantiate_gemv_base(name, itype, bm, bn, sm, sn, tm, tn, 1)
 
-// clang-format off
 #define instantiate_gemv_blocks(name, itype) \
   instantiate_gemv(name, itype, 2, 1, 4, 8, 1, 4) \
  instantiate_gemv(name, itype, 2, 1, 4, 8, 4, 4) \
  instantiate_gemv(name, itype, 2, 1, 2, 16, 1, 4) \
  instantiate_gemv(name, itype, 2, 1, 2, 16, 4, 4) \
-  instantiate_gemv(name, itype, 4, 1, 2, 16, 4, 4) // clang-format on
+  instantiate_gemv(name, itype, 4, 1, 2, 16, 4, 4)
 
 instantiate_gemv_blocks(float32, float);
 instantiate_gemv_blocks(float16, half);
 instantiate_gemv_blocks(bfloat16, bfloat16_t);
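A note on what these macros expand to (illustrative, inferred from the macro bodies): each instantiate_gemv_blocks(name, itype) stamps out 5 tile configurations x 2 batch variants (nc = 0 and 1) x the mask-type combinations listed in instantiate_gemv_base, and every expansion is one [[host_name(...)]] specialization of gemv_masked that the runtime looks up by the same name string that BlockMaskedMM::eval_gpu assembles.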
-
-///////////////////////////////////////////////////////////////////////////////
-/// Vector matrix multiplication
-///////////////////////////////////////////////////////////////////////////////
-
-template <
-    typename T,
-    typename out_mask_t,
-    typename op_mask_t,
-    const int BM, /* Threadgroup rows (in simdgroups) */
-    const int BN, /* Threadgroup cols (in simdgroups) */
-    const int SM, /* Simdgroup rows (in threads) */
-    const int SN, /* Simdgroup cols (in threads) */
-    const int TM, /* Thread rows (in elements) */
-    const int TN, /* Thread cols (in elements) */
-    const bool kDoNCBatch> /* Batch ndim > 1 */
-[[kernel, max_total_threads_per_threadgroup(BM* BN * 32)]] void gemv_t_masked(
-    const device T* mat [[buffer(0)]],
-    const device T* in_vec [[buffer(1)]],
-    device T* out_vec [[buffer(3)]],
-    const constant int& in_vec_size [[buffer(4)]],
-    const constant int& out_vec_size [[buffer(5)]],
-    const constant int& marix_ld [[buffer(6)]],
-    const constant int& batch_ndim [[buffer(9)]],
-    const constant int* batch_shape [[buffer(10)]],
-    const constant size_t* vector_batch_stride [[buffer(11)]],
-    const constant size_t* matrix_batch_stride [[buffer(12)]],
-    const device out_mask_t* out_mask [[buffer(20)]],
-    const device op_mask_t* mat_mask [[buffer(21)]],
-    const device op_mask_t* vec_mask [[buffer(22)]],
-    const constant int* mask_strides [[buffer(23)]],
-    const constant size_t* mask_batch_strides [[buffer(24)]],
-    uint3 tid [[threadgroup_position_in_grid]],
-    uint3 lid [[thread_position_in_threadgroup]],
-    uint simd_gid [[simdgroup_index_in_threadgroup]],
-    uint simd_lid [[thread_index_in_simdgroup]]) {
-  using gemv_kernel =
-      GEMVTKernel<T, out_mask_t, op_mask_t, BM, BN, SM, SN, TM, TN>;
-  threadgroup T tgp_memory
-      [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];
-
-  constexpr bool has_operand_mask = !metal::is_same_v<op_mask_t, nomask_t>;
-  constexpr bool has_output_mask = !metal::is_same_v<out_mask_t, nomask_t>;
-
-  // Update batch offsets
-  if (kDoNCBatch) {
-    in_vec += elem_to_loc(tid.z, batch_shape, vector_batch_stride, batch_ndim);
-    mat += elem_to_loc(tid.z, batch_shape, matrix_batch_stride, batch_ndim);
-
-    if (has_output_mask) {
-      out_mask +=
-          elem_to_loc(tid.z, batch_shape, mask_batch_strides, batch_ndim);
-      mask_batch_strides += batch_ndim;
-    }
-
-    if (has_operand_mask) {
-      const constant size_t* mask_strides_mat = mask_batch_strides;
-      const constant size_t* mask_strides_vec = mask_strides_mat + batch_ndim;
-
-      ulong2 batch_offsets = elem_to_loc_broadcast(
-          tid.z, batch_shape, mask_strides_mat, mask_strides_vec, batch_ndim);
-
-      mat_mask += batch_offsets.x;
-      vec_mask += batch_offsets.y;
-    }
-
-  } else {
-    in_vec += tid.z * vector_batch_stride[0];
-    mat += tid.z * matrix_batch_stride[0];
-
-    if (has_output_mask) {
-      out_mask += tid.z * mask_batch_strides[0];
-      mask_batch_strides += batch_ndim;
-    }
-
-    if (has_operand_mask) {
-      mat_mask += tid.z * mask_batch_strides[0];
-      vec_mask += tid.z * mask_batch_strides[batch_ndim];
-    }
-  }
-
-  out_vec += tid.z * out_vec_size;
-
-  gemv_kernel::run(
-      mat,
-      in_vec,
-      out_vec,
-      in_vec_size,
-      out_vec_size,
-      marix_ld,
-      out_mask,
-      mat_mask,
-      vec_mask,
-      mask_strides,
-      gemv_kernel::tgp_mem_size == 0 ? nullptr : tgp_memory,
-      tid,
-      lid,
-      simd_gid,
-      simd_lid);
-}
-
 #define instantiate_gemv_t_helper( \
     outm_n, outm_t, opm_n, opm_t, name, itype, bm, bn, sm, sn, tm, tn, nc) \
   template [[host_name("gemv_t_outmask_" #outm_n "_opmask_" #opm_n "_" #name \
@@ -908,7 +87,6 @@ template <
       uint simd_gid [[simdgroup_index_in_threadgroup]], \
      uint simd_lid [[thread_index_in_simdgroup]]);
 
-// clang-format off
 #define instantiate_gemv_t_base(name, itype, bm, bn, sm, sn, tm, tn, nc) \
   instantiate_gemv_t_helper(bool_, bool, bool_, bool, name, itype, bm, bn, sm, sn, tm, tn, nc) \
   instantiate_gemv_t_helper(name, itype, name, itype, name, itype, bm, bn, sm, sn, tm, tn, nc) \
@@ -917,23 +95,20 @@ template <
   instantiate_gemv_t_helper(nomask, nomask_t, name, itype, name, itype, bm, bn, sm, sn, tm, tn, nc) \
   instantiate_gemv_t_helper(nomask, nomask_t, bool_, bool, name, itype, bm, bn, sm, sn, tm, tn, nc) \
   instantiate_gemv_t_helper(bool_, bool, nomask, nomask_t, name, itype, bm, bn, sm, sn, tm, tn, nc) \
-  instantiate_gemv_t_helper(name, itype, nomask, nomask_t, name, itype, bm, bn, sm, sn, tm, tn, nc) // clang-format on
+  instantiate_gemv_t_helper(name, itype, nomask, nomask_t, name, itype, bm, bn, sm, sn, tm, tn, nc)
 
-// clang-format off
 #define instantiate_gemv_t(name, itype, bm, bn, sm, sn, tm, tn) \
   instantiate_gemv_t_base(name, itype, bm, bn, sm, sn, tm, tn, 0) \
-  instantiate_gemv_t_base(name, itype, bm, bn, sm, sn, tm, tn, 1) // clang-format on
+  instantiate_gemv_t_base(name, itype, bm, bn, sm, sn, tm, tn, 1)
 
-// clang-format off
 #define instantiate_gemv_t_blocks(name, itype) \
   instantiate_gemv_t(name, itype, 1, 1, 8, 4, 4, 1) \
  instantiate_gemv_t(name, itype, 1, 2, 8, 4, 4, 4) \
  instantiate_gemv_t(name, itype, 1, 1, 8, 4, 8, 1) \
  instantiate_gemv_t(name, itype, 1, 1, 8, 4, 8, 4) \
  instantiate_gemv_t(name, itype, 1, 2, 8, 4, 8, 4) \
-  instantiate_gemv_t(name, itype, 1, 4, 8, 4, 8, 4) // clang-format on
+  instantiate_gemv_t(name, itype, 1, 4, 8, 4, 8, 4)
 
 instantiate_gemv_t_blocks(float32, float);
 instantiate_gemv_t_blocks(float16, half);
 instantiate_gemv_t_blocks(bfloat16, bfloat16_t); // clang-format on
@@ -11,187 +11,14 @@
 #include "mlx/backend/metal/kernels/defines.h"
 #include "mlx/backend/metal/kernels/steel/gemm/params.h"
 #include "mlx/backend/metal/matmul.h"
-#include "mlx/backend/metal/mps/gemm.h"
 #include "mlx/backend/metal/utils.h"
 #include "mlx/primitives.h"
 #include "mlx/utils.h"
 
 namespace mlx::core {
 
-///////////////////////////////////////////////////////////////////////////////
-// MPS Matmul fallback
-///////////////////////////////////////////////////////////////////////////////
-
 namespace {
 
-bool use_mps() {
-  auto get_val = []() {
-    if (const char* buff_str = std::getenv("MLX_USE_MPS")) {
-      return std::string(buff_str) != "OFF";
-    } else {
-      return false;
-    }
-  };
-  static bool use_mps_ = get_val();
-  return use_mps_;
-}
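The removed gate read an environment variable once on first call; usage was roughly (a sketch, assuming a POSIX shell):

// MLX_USE_MPS=1   <program>  -> use_mps() returns true (any value but "OFF")
// MLX_USE_MPS=OFF <program>  -> use_mps() returns false
// variable unset             -> false, the default, so the Steel path runs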
-
-#define MAX_OPS_PER_BUFFER max_ops_per_buffer()
-
-inline void mps_matmul(
-    const Stream& s,
-    metal::Device& d,
-    const array& a,
-    const array& b,
-    array& out,
-    int M,
-    int N,
-    int K,
-    int batch_size_out,
-    int lda,
-    int ldb,
-    bool transpose_a,
-    bool transpose_b,
-    std::vector<array>& copies,
-    float alpha = 1.0f,
-    float beta = 0.0f) {
-  MPS::DataType mps_dtype = MPS::DataTypeFloat32;
-
-  if (out.dtype() == float16) {
-    mps_dtype = MPS::DataTypeFloat16;
-  } else if (out.dtype() == bfloat16) {
-    mps_dtype = MPS::DataTypeBFloat16;
-  }
-
-  // Use batched MPSMatrixMultiplication if batch_size_out > 1
-  // We only accept the following cases:
-  //  1. Both a, b have batch_size_out matrices worth of data
-  //  2. Only one of a or b has batch_size_out matrices worth of data and
-  //     the other has matrix worth of data
-
-  // The matrix dimensions of a and b are sure to be regularly strided
-  if (batch_size_out > 1) {
-    // No broadcasting defaults
-    auto batch_size_a = a.data_size() / (M * K);
-    auto batch_size_b = b.data_size() / (K * N);
-
-    auto matrix_stride_a = M * K;
-    auto matrix_stride_b = K * N;
-    auto matrix_stride_out = M * N;
-
-    // At this point, batch_size_a, batch_size_b show the number of matrices
-    // in data, no broadcasted strides considered
-    if (batch_size_out == std::max(batch_size_a, batch_size_b)) {
-      // Handle simple broadcasting
-      if (std::min(batch_size_a, batch_size_b) == 1) {
-        matrix_stride_a = (batch_size_a == 1) ? 0 : matrix_stride_a;
-        matrix_stride_b = (batch_size_b == 1) ? 0 : matrix_stride_b;
-
-        batch_size_a = batch_size_out;
-        batch_size_b = batch_size_out;
-      }
-
-      // Only proceed if broadcasting between a and b is simple
-      // At this point, batch_size_a, batch_size_b show the number of matrices
-      // after broadcasting
-      if (batch_size_a == batch_size_b) {
-        auto a_desc = MPS::MatrixDescriptor::matrixDescriptor(
-            (M * K) / lda,
-            lda,
-            batch_size_a,
-            lda * a.itemsize(),
-            (matrix_stride_a * a.itemsize()),
-            mps_dtype);
-
-        auto b_desc = MPS::MatrixDescriptor::matrixDescriptor(
-            (K * N) / ldb,
-            ldb,
-            batch_size_b,
-            ldb * b.itemsize(),
-            (matrix_stride_b * b.itemsize()),
-            mps_dtype);
-
-        auto out_desc = MPS::MatrixDescriptor::matrixDescriptor(
-            M,
-            N,
-            batch_size_out,
-            N * out.itemsize(),
-            matrix_stride_out * out.itemsize(),
-            mps_dtype);
-
-        auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
-        auto a_mat = MPS::Matrix::alloc()->init(a_buf, a_desc);
-
-        auto b_buf = static_cast<const MTL::Buffer*>(b.buffer().ptr());
-        auto b_mat = MPS::Matrix::alloc()->init(b_buf, b_desc);
-
-        auto out_buf = static_cast<MTL::Buffer*>(out.buffer().ptr());
-        auto out_mat = MPS::Matrix::alloc()->init(out_buf, out_desc);
-
-        auto kernel = MPS::MatrixMultiplication::alloc()->init(
-            d.mtl_device(), transpose_a, transpose_b, M, N, K, alpha, beta);
-
-        auto command_buffer = d.get_command_buffer(s.index);
-        kernel->setBatchSize(batch_size_out);
-        kernel->setBatchStart(0);
-        kernel->encodeToCommandBuffer(command_buffer, a_mat, b_mat, out_mat);
-        command_buffer->addCompletedHandler(
-            [a_mat, b_mat, out_mat, kernel, copies](
-                MTL::CommandBuffer*) mutable {
-              a_mat->release();
-              b_mat->release();
-              out_mat->release();
-              kernel->release();
-              copies.clear();
-            });
-
-        return;
-      }
-    }
-  }
-
-  // Schedule as many calls to MPSMatrixMultiplication as needed otherwise
-  auto a_desc = MPS::MatrixDescriptor::matrixDescriptor(
-      a.data_size() / lda, lda, lda * a.itemsize(), mps_dtype);
-
-  auto b_desc = MPS::MatrixDescriptor::matrixDescriptor(
-      b.data_size() / ldb, ldb, ldb * b.itemsize(), mps_dtype);
-
-  auto out_desc = MPS::MatrixDescriptor::matrixDescriptor(
-      batch_size_out * M, N, N * out.itemsize(), mps_dtype);
-
-  auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
-  auto a_mat = MPS::Matrix::alloc()->init(a_buf, a_desc);
-
-  auto b_buf = static_cast<const MTL::Buffer*>(b.buffer().ptr());
-  auto b_mat = MPS::Matrix::alloc()->init(b_buf, b_desc);
-
-  auto out_buf = static_cast<MTL::Buffer*>(out.buffer().ptr());
-  auto out_mat = MPS::Matrix::alloc()->init(out_buf, out_desc);
-
-  auto kernel = MPS::MatrixMultiplication::alloc()->init(
-      d.mtl_device(), transpose_a, transpose_b, M, N, K, alpha, beta);
-
-  auto command_buffer = d.get_command_buffer(s.index);
-  for (int i = 0; i < batch_size_out; ++i) {
-    auto a_row = elem_to_loc(M * K * i, a.shape(), a.strides()) / lda;
-    auto b_row = elem_to_loc(K * N * i, b.shape(), b.strides()) / ldb;
-    kernel->setLeftMatrixOrigin({a_row, 0, 0});
-    kernel->setRightMatrixOrigin({b_row, 0, 0});
-    kernel->setResultMatrixOrigin({i * static_cast<size_t>(M), 0, 0});
-    kernel->encodeToCommandBuffer(command_buffer, a_mat, b_mat, out_mat);
-  }
-
-  command_buffer->addCompletedHandler(
-      [a_mat, b_mat, out_mat, kernel, copies](MTL::CommandBuffer*) mutable {
-        a_mat->release();
-        b_mat->release();
-        out_mat->release();
-        kernel->release();
-        copies.clear();
-      });
-}
-
 inline auto collapse_batches(const array& a, const array& b) {
   // Get and check the shape for the batched dims
   std::vector<int> A_bshape{a.shape().begin(), a.shape().end() - 2};
@@ -860,26 +687,6 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
   /////////////////////////////////////////////////////////////////////////////
   // Gemm specialization
 
-  if (use_mps()) {
-    d.end_encoding(s.index);
-
-    return mps_matmul(
-        s,
-        d,
-        a,
-        b,
-        out,
-        M,
-        N,
-        K,
-        batch_size_out,
-        a_cols,
-        b_cols,
-        a_transposed,
-        b_transposed,
-        copies);
-  }
-
   return steel_matmul(
       /* const Stream& s = */ s,
       /* metal::Device& d = */ d,
@@ -1529,8 +1336,22 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   kname << "_nc" << !contiguous_kernel;
 
   // Encode and dispatch kernel
+  auto kernel = get_gemv_masked_kernel(
+      d,
+      kname.str(),
+      out,
+      has_out_mask ? std::optional<array>{inputs[2]} : std::nullopt,
+      has_op_mask ? std::optional<array>{inputs.back()} : std::nullopt,
+      transpose_mat,
+      bm,
+      bn,
+      sm,
+      sn,
+      tm,
+      tn,
+      contiguous_kernel);
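The changed dispatch above is the heart of this file's diff: previously the pipeline had to already exist in the prebuilt metallib under kname.str(); with get_gemv_masked_kernel the same name can instead be JIT-compiled from the gemv_masked.h source on first use, with the mask types, transpose variant, and bm/bn/sm/sn/tm/tn tile parameters baked in (a summary of the surrounding diff, not additional code).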
 
   auto& compute_encoder = d.get_command_encoder(s.index);
-  auto kernel = d.get_kernel(kname.str());
   compute_encoder->setComputePipelineState(kernel);
 
   int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp;
@@ -1,14 +1,6 @@
 // Copyright © 2023 Apple Inc.
 
-#include <algorithm>
-#include <cassert>
-#include <sstream>
-
-#include "mlx/backend/metal/copy.h"
 #include "mlx/backend/metal/device.h"
-#include "mlx/backend/metal/mps/gemm.h"
-#include "mlx/backend/metal/utils.h"
-#include "mlx/utils.h"
 
 namespace mlx::core {
 
@@ -1,370 +0,0 @@
-// Copyright © 2023 Apple Inc.
-
-#pragma once
-
-#include <Metal/Metal.hpp>
-
-#define _MPS_PRIVATE_CLS(symbol) (MTL::Private::Class::s_k##symbol)
-#define _MPS_PRIVATE_SEL(accessor) (MTL::Private::Selector::s_k##accessor)
-
-namespace MTL::Private::Class {
-_MTL_PRIVATE_DEF_CLS(MPSMatrixDescriptor);
-_MTL_PRIVATE_DEF_CLS(MPSMatrix);
-_MTL_PRIVATE_DEF_CLS(MPSVectorDescriptor);
-_MTL_PRIVATE_DEF_CLS(MPSVector);
-_MTL_PRIVATE_DEF_CLS(MPSKernel);
-_MTL_PRIVATE_DEF_CLS(MPSMatrixMultiplication);
-_MTL_PRIVATE_DEF_CLS(MPSMatrixVectorMultiplication);
-} // namespace MTL::Private::Class
-
-namespace MTL::Private::Selector {
-_MTL_PRIVATE_DEF_SEL(
-    matrixDescriptorWithRows_columns_rowBytes_dataType,
-    "matrixDescriptorWithRows:columns:rowBytes:dataType:");
-_MTL_PRIVATE_DEF_SEL(
-    matrixDescriptorWithRows_columns_matrices_rowBytes_matrixBytes_dataType,
-    "matrixDescriptorWithRows:columns:matrices:rowBytes:matrixBytes:dataType:");
-_MTL_PRIVATE_DEF_SEL(rows, "rows");
-_MTL_PRIVATE_DEF_SEL(initWithBuffer_descriptor, "initWithBuffer:descriptor:");
-_MTL_PRIVATE_DEF_SEL(
-    initWithDevice_,
-    "initWithDevice:transposeLeft:transposeRight:"
-    "resultRows:resultColumns:interiorColumns:alpha:beta:");
-_MTL_PRIVATE_DEF_SEL(
-    encodeToCommandBuffer_leftMatrix_rightMatrix_resultMatrix,
-    "encodeToCommandBuffer:leftMatrix:rightMatrix:resultMatrix:");
-_MTL_PRIVATE_DEF_SEL(setLeftMatrixOrigin_, "setLeftMatrixOrigin:");
-_MTL_PRIVATE_DEF_SEL(setRightMatrixOrigin_, "setRightMatrixOrigin:");
-_MTL_PRIVATE_DEF_SEL(setResultMatrixOrigin_, "setResultMatrixOrigin:");
-_MTL_PRIVATE_DEF_SEL(setBatchStart_, "setBatchStart:");
-_MTL_PRIVATE_DEF_SEL(setBatchSize_, "setBatchSize:");
-_MTL_PRIVATE_DEF_SEL(
-    vectorDescriptorWithLength_dataType,
-    "vectorDescriptorWithLength:dataType:");
-_MTL_PRIVATE_DEF_SEL(
-    vectorDescriptorWithLength_vectors_vectorBytes_dataType,
-    "vectorDescriptorWithLength:vectors:vectorBytes:dataType:");
-_MTL_PRIVATE_DEF_SEL(
-    initWithDevice_transpose_rows_columns_alpha_beta,
-    "initWithDevice:transpose:rows:columns:alpha:beta:");
-_MTL_PRIVATE_DEF_SEL(
-    encodeToCommandBuffer_inputMatrix_inputVector_resultVector,
-    "encodeToCommandBuffer:inputMatrix:inputVector:resultVector:");
-} // namespace MTL::Private::Selector
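For context on the deleted glue: _MPS_PRIVATE_CLS and _MPS_PRIVATE_SEL resolve the MPSMatrix* Objective-C classes and selectors at runtime through metal-cpp's private symbol tables (the s_k... symbols defined above), which is what let the Object::sendMessage calls further down message MPS kernels without compiling against the MetalPerformanceShaders framework headers.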
-
-namespace MPS {
-
-typedef enum DataType : uint32_t {
-  DataTypeFloatBit = 0x10000000,
-  DataTypeAlternateEncodingBit = 0x80000000,
-  DataTypeFloat16 = DataTypeFloatBit | 16,
-  DataTypeFloat32 = DataTypeFloatBit | 32,
-  DataTypeBFloat16 = DataTypeAlternateEncodingBit | DataTypeFloat16
-} DataType;
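The enum arithmetic, worked out (an illustrative note):

// DataTypeFloat16  = 0x10000000 | 16         = 0x10000010
// DataTypeFloat32  = 0x10000000 | 32         = 0x10000020
// DataTypeBFloat16 = 0x80000000 | 0x10000010 = 0x90000010
// i.e. the low bits carry the width and the high bits flag the float
// family and the alternate (bfloat) encoding.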
-
-class MatrixDescriptor : public NS::Copying<MatrixDescriptor> {
- public:
-  static class MatrixDescriptor* matrixDescriptor(
-      NS::UInteger rows,
-      NS::UInteger columns,
-      NS::UInteger rowBytes,
-      NS::UInteger dataType);
-  static class MatrixDescriptor* matrixDescriptor(
-      NS::UInteger rows,
-      NS::UInteger columns,
-      NS::UInteger matrices,
-      NS::UInteger rowBytes,
-      NS::UInteger matrixBytes,
-      NS::UInteger dataType);
-  NS::UInteger rows() const;
-};
-
-class Matrix : public NS::Referencing<Matrix> {
- public:
-  static class Matrix* alloc();
-  Matrix* init(MTL::Buffer* buffer, MatrixDescriptor* descriptor);
-  Matrix* init(const MTL::Buffer* buffer, MatrixDescriptor* descriptor);
-};
-
-class Kernel : public NS::Referencing<Kernel> {
- public:
-  NS::String* label() const;
-  MTL::Device* device() const;
-};
-
-class MatrixMultiplication
-    : public NS::Referencing<MatrixMultiplication, Kernel> {
- public:
-  static class MatrixMultiplication* alloc();
-
-  MatrixMultiplication* init(
-      MTL::Device* device,
-      bool transposeLeft,
-      bool transposeRight,
-      NS::UInteger resultRows,
-      NS::UInteger resultColumns,
-      NS::UInteger interiorColumns,
-      double alpha,
-      double beta);
-
-  void encodeToCommandBuffer(
-      MTL::CommandBuffer* commandBuffer,
-      Matrix* leftMatrix,
-      Matrix* rightMatrix,
-      Matrix* resultMatrix);
-
-  void setLeftMatrixOrigin(MTL::Origin origin);
-  void setRightMatrixOrigin(MTL::Origin origin);
-  void setResultMatrixOrigin(MTL::Origin origin);
-  void setBatchStart(NS::UInteger batchStart);
-  void setBatchSize(NS::UInteger batchSize);
-};
-
-class VectorDescriptor : public NS::Copying<VectorDescriptor> {
- public:
-  static class VectorDescriptor* vectorDescriptor(
-      NS::UInteger length,
-      NS::UInteger dataType);
-  static class VectorDescriptor* vectorDescriptor(
-      NS::UInteger length,
-      NS::UInteger vectors,
-      NS::UInteger vectorBytes,
-      NS::UInteger dataType);
-};
-
-class Vector : public NS::Referencing<Vector> {
- public:
-  static class Vector* alloc();
-  Vector* init(MTL::Buffer* buffer, VectorDescriptor* descriptor);
-  Vector* init(const MTL::Buffer* buffer, VectorDescriptor* descriptor);
-};
-
-class MatrixVectorMultiplication
-    : public NS::Referencing<MatrixVectorMultiplication, Kernel> {
- public:
-  static class MatrixVectorMultiplication* alloc();
-
-  MatrixVectorMultiplication* init(
-      MTL::Device* device,
-      bool transpose,
-      NS::UInteger rows,
-      NS::UInteger columns,
-      double alpha,
-      double beta);
-
-  void encodeToCommandBuffer(
-      MTL::CommandBuffer* commandBuffer,
-      Matrix* inputMatrix,
-      Vector* inputVector,
-      Vector* resultVector);
-};
-
-_MTL_INLINE MatrixDescriptor* MatrixDescriptor::matrixDescriptor(
-    NS::UInteger rows,
-    NS::UInteger columns,
-    NS::UInteger rowBytes,
-    NS::UInteger dataType) {
-  return Object::sendMessage<MatrixDescriptor*>(
-      _MPS_PRIVATE_CLS(MPSMatrixDescriptor),
-      _MPS_PRIVATE_SEL(matrixDescriptorWithRows_columns_rowBytes_dataType),
-      rows,
-      columns,
-      rowBytes,
-      dataType);
-}
-
-_MTL_INLINE MatrixDescriptor* MatrixDescriptor::matrixDescriptor(
-    NS::UInteger rows,
-    NS::UInteger columns,
-    NS::UInteger matrices,
-    NS::UInteger rowBytes,
-    NS::UInteger matrixBytes,
-    NS::UInteger dataType) {
-  return Object::sendMessage<MatrixDescriptor*>(
-      _MPS_PRIVATE_CLS(MPSMatrixDescriptor),
-      _MPS_PRIVATE_SEL(
-          matrixDescriptorWithRows_columns_matrices_rowBytes_matrixBytes_dataType),
-      rows,
-      columns,
-      matrices,
-      rowBytes,
-      matrixBytes,
-      dataType);
-}
-
-_MTL_INLINE NS::UInteger MatrixDescriptor::rows() const {
-  return Object::sendMessage<NS::UInteger>(this, _MPS_PRIVATE_SEL(rows));
-}
-
-_MTL_INLINE Matrix* Matrix::alloc() {
-  return NS::Object::alloc<Matrix>(_MPS_PRIVATE_CLS(MPSMatrix));
-}
-
-_MTL_INLINE Matrix* Matrix::init(
-    MTL::Buffer* buffer,
-    MatrixDescriptor* descriptor) {
-  return Object::sendMessage<Matrix*>(
-      this, _MPS_PRIVATE_SEL(initWithBuffer_descriptor), buffer, descriptor);
-}
-
-_MTL_INLINE Matrix* Matrix::init(
-    const MTL::Buffer* buffer,
-    MatrixDescriptor* descriptor) {
-  return init(const_cast<MTL::Buffer*>(buffer), descriptor);
-}
-
-_MTL_INLINE NS::String* Kernel::label() const {
-  return Object::sendMessage<NS::String*>(this, _MPS_PRIVATE_SEL(label));
-}
-
-_MTL_INLINE MTL::Device* Kernel::device() const {
-  return Object::sendMessage<MTL::Device*>(this, _MPS_PRIVATE_SEL(device));
-}
-
-_MTL_INLINE MatrixMultiplication* MatrixMultiplication::alloc() {
-  return NS::Object::alloc<MatrixMultiplication>(
-      _MPS_PRIVATE_CLS(MPSMatrixMultiplication));
-}
-
-_MTL_INLINE MatrixMultiplication* MatrixMultiplication::init(
-    MTL::Device* device,
-    bool transposeLeft,
-    bool transposeRight,
-    NS::UInteger resultRows,
-    NS::UInteger resultColumns,
-    NS::UInteger interiorColumns,
-    double alpha,
-    double beta) {
-  return Object::sendMessage<MatrixMultiplication*>(
-      this,
-      _MPS_PRIVATE_SEL(initWithDevice_),
-      device,
-      transposeLeft,
-      transposeRight,
-      resultRows,
-      resultColumns,
-      interiorColumns,
-      alpha,
-      beta);
-}
-
-_MTL_INLINE void MatrixMultiplication::encodeToCommandBuffer(
-    MTL::CommandBuffer* commandBuffer,
-    Matrix* leftMatrix,
-    Matrix* rightMatrix,
-    Matrix* resultMatrix) {
-  return Object::sendMessage<void>(
-      this,
-      _MPS_PRIVATE_SEL(
-          encodeToCommandBuffer_leftMatrix_rightMatrix_resultMatrix),
-      commandBuffer,
-      leftMatrix,
-      rightMatrix,
-      resultMatrix);
-}
-
-_MTL_INLINE void MatrixMultiplication::setLeftMatrixOrigin(MTL::Origin origin) {
-  Object::sendMessage<void>(
-      this, _MPS_PRIVATE_SEL(setLeftMatrixOrigin_), origin);
-}
-
-_MTL_INLINE void MatrixMultiplication::setRightMatrixOrigin(
-    MTL::Origin origin) {
-  Object::sendMessage<void>(
-      this, _MPS_PRIVATE_SEL(setRightMatrixOrigin_), origin);
-}
-
-_MTL_INLINE void MatrixMultiplication::setResultMatrixOrigin(
-    MTL::Origin origin) {
-  Object::sendMessage<void>(
-      this, _MPS_PRIVATE_SEL(setResultMatrixOrigin_), origin);
-}
-
-_MTL_INLINE void MatrixMultiplication::setBatchStart(NS::UInteger batchStart) {
-  Object::sendMessage<void>(this, _MPS_PRIVATE_SEL(setBatchStart_), batchStart);
-}
-
-_MTL_INLINE void MatrixMultiplication::setBatchSize(NS::UInteger batchSize) {
-  Object::sendMessage<void>(this, _MPS_PRIVATE_SEL(setBatchSize_), batchSize);
-}
-
-_MTL_INLINE VectorDescriptor* VectorDescriptor::vectorDescriptor(
-    NS::UInteger length,
-    NS::UInteger dataType) {
-  return Object::sendMessage<VectorDescriptor*>(
-      _MPS_PRIVATE_CLS(MPSVectorDescriptor),
-      _MPS_PRIVATE_SEL(vectorDescriptorWithLength_dataType),
-      length,
-      dataType);
-}
-
-_MTL_INLINE VectorDescriptor* VectorDescriptor::vectorDescriptor(
-    NS::UInteger length,
-    NS::UInteger vectors,
-    NS::UInteger vectorBytes,
-    NS::UInteger dataType) {
-  return Object::sendMessage<VectorDescriptor*>(
-      _MPS_PRIVATE_CLS(MPSVectorDescriptor),
-      _MPS_PRIVATE_SEL(vectorDescriptorWithLength_vectors_vectorBytes_dataType),
-      length,
-      vectors,
-      vectorBytes,
-      dataType);
-}
-
-_MTL_INLINE Vector* Vector::alloc() {
-  return NS::Object::alloc<Vector>(_MPS_PRIVATE_CLS(MPSVector));
-}
-
-_MTL_INLINE Vector* Vector::init(
-    MTL::Buffer* buffer,
-    VectorDescriptor* descriptor) {
-  return Object::sendMessage<Vector*>(
-      this, _MPS_PRIVATE_SEL(initWithBuffer_descriptor), buffer, descriptor);
-}
-
-_MTL_INLINE Vector* Vector::init(
-    const MTL::Buffer* buffer,
-    VectorDescriptor* descriptor) {
-  return init(const_cast<MTL::Buffer*>(buffer), descriptor);
-}
-
-_MTL_INLINE MatrixVectorMultiplication* MatrixVectorMultiplication::alloc() {
-  return NS::Object::alloc<MatrixVectorMultiplication>(
-      _MPS_PRIVATE_CLS(MPSMatrixVectorMultiplication));
-}
-
-_MTL_INLINE MatrixVectorMultiplication* MatrixVectorMultiplication::init(
-    MTL::Device* device,
-    bool transpose,
-    NS::UInteger rows,
-    NS::UInteger columns,
-    double alpha,
-    double beta) {
|
|
||||||
return Object::sendMessage<MatrixVectorMultiplication*>(
|
|
||||||
this,
|
|
||||||
_MPS_PRIVATE_SEL(initWithDevice_transpose_rows_columns_alpha_beta),
|
|
||||||
device,
|
|
||||||
transpose,
|
|
||||||
rows,
|
|
||||||
columns,
|
|
||||||
alpha,
|
|
||||||
beta);
|
|
||||||
}
|
|
||||||
|
|
||||||
_MTL_INLINE void MatrixVectorMultiplication::encodeToCommandBuffer(
|
|
||||||
MTL::CommandBuffer* commandBuffer,
|
|
||||||
Matrix* inputMatrix,
|
|
||||||
Vector* inputVector,
|
|
||||||
Vector* resultVector) {
|
|
||||||
return Object::sendMessage<void>(
|
|
||||||
this,
|
|
||||||
_MPS_PRIVATE_SEL(
|
|
||||||
encodeToCommandBuffer_inputMatrix_inputVector_resultVector),
|
|
||||||
commandBuffer,
|
|
||||||
inputMatrix,
|
|
||||||
inputVector,
|
|
||||||
resultVector);
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace MPS
|
|
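For reference, a minimal sketch of how these MPS wrappers compose to encode a GEMM; `device`, `cmd_buf`, and the buffer arguments are assumed to exist, and `DataTypeFloat32` is assumed to mirror `MPSDataTypeFloat32` from the enum earlier in this header — none of this snippet is code from the commit:

// Hypothetical usage sketch: encode C = 1.0 * A @ B + 0.0 * C for row-major
// float32 buffers of shapes (M, K), (K, N), and (M, N).
void encode_gemm_example(
    MTL::Device* device,
    MTL::CommandBuffer* cmd_buf,
    MTL::Buffer* a_buf,
    MTL::Buffer* b_buf,
    MTL::Buffer* c_buf,
    int M, int N, int K) {
  auto* desc_a = MPS::MatrixDescriptor::matrixDescriptor(
      M, K, K * sizeof(float), MPS::DataTypeFloat32);
  auto* desc_b = MPS::MatrixDescriptor::matrixDescriptor(
      K, N, N * sizeof(float), MPS::DataTypeFloat32);
  auto* desc_c = MPS::MatrixDescriptor::matrixDescriptor(
      M, N, N * sizeof(float), MPS::DataTypeFloat32);

  auto* mat_a = MPS::Matrix::alloc()->init(a_buf, desc_a);
  auto* mat_b = MPS::Matrix::alloc()->init(b_buf, desc_b);
  auto* mat_c = MPS::Matrix::alloc()->init(c_buf, desc_c);

  // No transposes; alpha = 1.0, beta = 0.0
  auto* gemm = MPS::MatrixMultiplication::alloc()->init(
      device, false, false, M, N, K, 1.0, 0.0);
  gemm->encodeToCommandBuffer(cmd_buf, mat_a, mat_b, mat_c);
}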
@@ -169,6 +169,23 @@ MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
   return d.get_kernel(kernel_name);
 }
 
+MTL::ComputePipelineState* get_gemv_masked_kernel(
+    metal::Device& d,
+    const std::string& kernel_name,
+    const array&,
+    const std::optional<array>&,
+    const std::optional<array>&,
+    bool,
+    int,
+    int,
+    int,
+    int,
+    int,
+    int,
+    bool) {
+  return d.get_kernel(kernel_name);
+}
+
 MTL::ComputePipelineState* get_steel_conv_kernel(
     metal::Device& d,
     const std::string& kernel_name,
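A hedged call-site sketch for the stub above: in this precompiled (non-JIT) path the shape and mask arguments go unused and exist only to keep the signature aligned with the JIT variant, which needs them to instantiate the templated kernel source. The variable names and tile values below are illustrative, not taken from the commit:

// Hypothetical call site; d, kname, out, mask_out, and mask_op are assumed
// to be in scope, with kname holding the fully specialized kernel name.
auto kernel = get_gemv_masked_kernel(
    d,
    kname.str(),
    out,                    // used only by the JIT path
    mask_out,               // std::optional<array>, assumed in scope
    mask_op,                // std::optional<array>, assumed in scope
    /* transpose */ false,
    /* bm */ 8, /* bn */ 8,
    /* sm */ 4, /* sn */ 32,
    /* tm */ 1, /* tn */ 4,
    /* contiguous */ true);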
@@ -32,7 +32,7 @@ void ternary_op_gpu_inplace(
   auto& strides_c = strides[2];
   auto& strides_out = strides[3];
 
-  bool use_2d = out.data_size();
+  bool use_2d = out.data_size() > UINT_MAX;
   std::string kernel_name;
   {
     std::ostringstream kname;
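The one-line fix above is easy to misread, so spelled out: `out.data_size()` returns the element count, and the old code converted it straight to `bool`, so any non-empty output took the 2D path. The intent is to fall back to a 2D grid only when the count no longer fits a 32-bit grid dimension. A sketch of the reasoning, with `out` assumed in scope:

size_t n = out.data_size();  // total elements the kernel must index
bool use_2d = n > UINT_MAX;  // a 1D Metal grid can index at most UINT_MAX
                             // items; beyond that, split across two axes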
116  mlx/backend/metal/utils.cpp  Normal file
@@ -0,0 +1,116 @@
// Copyright © 2023-2024 Apple Inc.

#include "mlx/backend/metal/utils.h"

using namespace mlx;

namespace mlx::core {

std::string type_to_name(const array& a) {
  std::string tname;
  switch (a.dtype()) {
    case bool_:
      tname = "bool_";
      break;
    case uint8:
      tname = "uint8";
      break;
    case uint16:
      tname = "uint16";
      break;
    case uint32:
      tname = "uint32";
      break;
    case uint64:
      tname = "uint64";
      break;
    case int8:
      tname = "int8";
      break;
    case int16:
      tname = "int16";
      break;
    case int32:
      tname = "int32";
      break;
    case int64:
      tname = "int64";
      break;
    case float16:
      tname = "float16";
      break;
    case float32:
      tname = "float32";
      break;
    case bfloat16:
      tname = "bfloat16";
      break;
    case complex64:
      tname = "complex64";
      break;
  }
  return tname;
}

MTL::Size get_block_dims(int dim0, int dim1, int dim2) {
  int pows[3] = {0, 0, 0};
  int sum = 0;
  while (true) {
    int presum = sum;
    // Check all the pows
    if (dim0 >= (1 << (pows[0] + 1))) {
      pows[0]++;
      sum++;
    }
    if (sum == 10) {
      break;
    }
    if (dim1 >= (1 << (pows[1] + 1))) {
      pows[1]++;
      sum++;
    }
    if (sum == 10) {
      break;
    }
    if (dim2 >= (1 << (pows[2] + 1))) {
      pows[2]++;
      sum++;
    }
    if (sum == presum || sum == 10) {
      break;
    }
  }
  return MTL::Size{1ul << pows[0], 1ul << pows[1], 1ul << pows[2]};
}

MTL::Size get_2d_grid_dims(
    const std::vector<int>& shape,
    const std::vector<size_t>& strides) {
  // Dims with strides of 0 are ignored as they
  // correspond to broadcasted dimensions
  size_t grid_x = 1;
  size_t grid_y = 1;
  for (int i = 0; i < shape.size(); ++i) {
    if (strides[i] == 0) {
      continue;
    }
    if (grid_x * shape[i] < UINT32_MAX) {
      grid_x *= shape[i];
    } else {
      grid_y *= shape[i];
    }
  }
  if (grid_y > UINT32_MAX || grid_x > UINT32_MAX) {
    throw std::runtime_error("Unable to safely factor shape.");
  }
  return MTL::Size(
      static_cast<uint32_t>(grid_x), static_cast<uint32_t>(grid_y), 1);
}

std::string get_primitive_string(Primitive* primitive) {
  std::ostringstream op_t;
  primitive->print(op_t);
  return op_t.str();
}

} // namespace mlx::core
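A worked example of `get_block_dims` (a sketch, not part of the commit): the loop hands out powers of two round-robin across the three dims, stopping when no dim can still double or the product reaches 2^10 = 1024 threads:

// Assumes "mlx/backend/metal/utils.h" is included.
#include <cassert>

void block_dims_example() {
  using namespace mlx::core;

  // One long dim, two short ones: x absorbs most of the thread budget.
  MTL::Size b = get_block_dims(100, 3, 1);
  assert(b.width == 64 && b.height == 2 && b.depth == 1); // 128 threads

  // All dims large: the 1024-thread cap binds, split as 16 x 8 x 8.
  b = get_block_dims(1024, 1024, 1024);
  assert(b.width == 16 && b.height == 8 && b.depth == 8); // 1024 threads
}

`get_2d_grid_dims` is the complementary helper: it packs as many (non-broadcast) dims as possible into a 32-bit x extent and spills the remainder into y, throwing if even two 32-bit axes cannot hold the product.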
@@ -8,8 +8,6 @@
 
 namespace mlx::core {
 
-namespace {
-
 using metal::CommandEncoder;
 
 template <typename T>
@@ -27,82 +25,13 @@ set_vector_bytes(CommandEncoder& enc, const std::vector<T>& vec, int idx) {
   return set_vector_bytes(enc, vec, vec.size(), idx);
 }
 
-std::string type_to_name(const array& a) {
-  std::string tname;
-  switch (a.dtype()) {
-    case bool_:
-      tname = "bool_";
-      break;
-    case uint8:
-      tname = "uint8";
-      break;
-    case uint16:
-      tname = "uint16";
-      break;
-    case uint32:
-      tname = "uint32";
-      break;
-    case uint64:
-      tname = "uint64";
-      break;
-    case int8:
-      tname = "int8";
-      break;
-    case int16:
-      tname = "int16";
-      break;
-    case int32:
-      tname = "int32";
-      break;
-    case int64:
-      tname = "int64";
-      break;
-    case float16:
-      tname = "float16";
-      break;
-    case float32:
-      tname = "float32";
-      break;
-    case bfloat16:
-      tname = "bfloat16";
-      break;
-    case complex64:
-      tname = "complex64";
-      break;
-  }
-  return tname;
-}
+std::string type_to_name(const array& a);
 
-MTL::Size get_block_dims(int dim0, int dim1, int dim2) {
-  int pows[3] = {0, 0, 0};
-  int sum = 0;
-  while (true) {
-    int presum = sum;
-    // Check all the pows
-    if (dim0 >= (1 << (pows[0] + 1))) {
-      pows[0]++;
-      sum++;
-    }
-    if (sum == 10) {
-      break;
-    }
-    if (dim1 >= (1 << (pows[1] + 1))) {
-      pows[1]++;
-      sum++;
-    }
-    if (sum == 10) {
-      break;
-    }
-    if (dim2 >= (1 << (pows[2] + 1))) {
-      pows[2]++;
-      sum++;
-    }
-    if (sum == presum || sum == 10) {
-      break;
-    }
-  }
-  return MTL::Size{1ul << pows[0], 1ul << pows[1], 1ul << pows[2]};
-}
+// Compute the thread block dimensions which fit the given
+// input dimensions.
+// - The thread block dimensions will be powers of two
+// - The thread block size will be less than 1024
+MTL::Size get_block_dims(int dim0, int dim1, int dim2);
 
 // Computes a 2D grid where each element is < UINT_MAX
 // Assumes:
@@ -111,27 +40,7 @@ MTL::Size get_block_dims(int dim0, int dim1, int dim2) {
 // possibly broadcasted array
 MTL::Size get_2d_grid_dims(
     const std::vector<int>& shape,
-    const std::vector<size_t>& strides) {
-  // Dims with strides of 0 are ignored as they
-  // correspond to broadcasted dimensions
-  size_t grid_x = 1;
-  size_t grid_y = 1;
-  for (int i = 0; i < shape.size(); ++i) {
-    if (strides[i] == 0) {
-      continue;
-    }
-    if (grid_x * shape[i] < UINT32_MAX) {
-      grid_x *= shape[i];
-    } else {
-      grid_y *= shape[i];
-    }
-  }
-  if (grid_y > UINT32_MAX || grid_x > UINT32_MAX) {
-    throw std::runtime_error("Unable to safely factor shape.");
-  }
-  return MTL::Size(
-      static_cast<uint32_t>(grid_x), static_cast<uint32_t>(grid_y), 1);
-}
+    const std::vector<size_t>& strides);
 
 inline NS::String* make_string(std::ostringstream& os) {
   std::string string = os.str();
@@ -159,12 +68,6 @@ inline void debug_set_primitive_buffer_label(
 #endif
 }
 
-std::string get_primitive_string(Primitive* primitive) {
-  std::ostringstream op_t;
-  primitive->print(op_t);
-  return op_t.str();
-}
+std::string get_primitive_string(Primitive* primitive);
 
-} // namespace
-
 } // namespace mlx::core
@@ -1,11 +1,8 @@
 // Copyright © 2023-2024 Apple Inc.
 
 #include <cstdint>
-#include <sstream>
-#include <vector>
 
 #include "mlx/dtype.h"
-#include "mlx/utils.h"
 
 namespace mlx::core {
 
@@ -178,67 +175,4 @@ bool issubdtype(const Dtype::Category& a, const Dtype::Category& b) {
       [static_cast<uint32_t>(b)];
 }
 
-// Array protocol typestring for Dtype
-std::string dtype_to_array_protocol(const Dtype& t) {
-  std::ostringstream r;
-  if (size_of(t) > 1)
-    r << (is_big_endian() ? ">" : "<");
-  else
-    r << "|";
-  r << kindof(t) << (int)size_of(t);
-  return r.str();
-}
-
-// Dtype from array protocol type string
-Dtype dtype_from_array_protocol(std::string_view t) {
-  if (t.length() == 2 || t.length() == 3) {
-    std::string_view r = t.length() == 3 ? t.substr(1, 2) : t;
-
-    if (r == "V2") {
-      return bfloat16;
-    }
-
-    uint8_t size = r[1] - '0';
-
-    switch (r[0]) {
-      case 'b': {
-        if (size == 1)
-          return bool_;
-      }
-      case 'i': {
-        if (size == 1)
-          return int8;
-        else if (size == 2)
-          return int16;
-        else if (size == 4)
-          return int32;
-        else if (size == 8)
-          return int64;
-      }
-      case 'u': {
-        if (size == 1)
-          return uint8;
-        else if (size == 2)
-          return uint16;
-        else if (size == 4)
-          return uint32;
-        else if (size == 8)
-          return uint64;
-      }
-      case 'f': {
-        if (size == 2)
-          return float16;
-        else if (size == 4)
-          return float32;
-      }
-      case 'c': {
-        return complex64;
-      }
-    }
-  }
-
-  throw std::invalid_argument(
-      "[from_str] Invalid array protocol type-string: " + std::string(t));
-}
-
 } // namespace mlx::core
@@ -4,8 +4,6 @@
 
 #include <complex>
 #include <cstdint>
-#include <ostream>
-#include <string>
 
 #include "mlx/types/complex.h"
 #include "mlx/types/half_types.h"
@@ -103,9 +101,4 @@ struct TypeToDtype {
   operator Dtype();
 };
 
-// Array protocol typestring for Dtype
-std::string dtype_to_array_protocol(const Dtype& t);
-// Dtype from array protocol type string
-Dtype dtype_from_array_protocol(std::string_view t);
-
 } // namespace mlx::core
@@ -26,6 +26,80 @@ constexpr uint8_t MAGIC[] = {
     0x59,
 };
 
+inline bool is_big_endian() {
+  union ByteOrder {
+    int32_t i;
+    uint8_t c[4];
+  };
+  ByteOrder b = {0x01234567};
+
+  return b.c[0] == 0x01;
+}
+
+// Array protocol typestring for Dtype
+std::string dtype_to_array_protocol(const Dtype& t) {
+  std::ostringstream r;
+  if (size_of(t) > 1) {
+    r << (is_big_endian() ? ">" : "<");
+  } else {
+    r << "|";
+  }
+  r << kindof(t) << (int)size_of(t);
+  return r.str();
+}
+
+// Dtype from array protocol type string
+Dtype dtype_from_array_protocol(std::string_view t) {
+  if (t.length() == 2 || t.length() == 3) {
+    std::string_view r = t.length() == 3 ? t.substr(1, 2) : t;
+
+    if (r == "V2") {
+      return bfloat16;
+    }
+
+    uint8_t size = r[1] - '0';
+
+    switch (r[0]) {
+      case 'b': {
+        if (size == 1)
+          return bool_;
+      }
+      case 'i': {
+        if (size == 1)
+          return int8;
+        else if (size == 2)
+          return int16;
+        else if (size == 4)
+          return int32;
+        else if (size == 8)
+          return int64;
+      }
+      case 'u': {
+        if (size == 1)
+          return uint8;
+        else if (size == 2)
+          return uint16;
+        else if (size == 4)
+          return uint32;
+        else if (size == 8)
+          return uint64;
+      }
+      case 'f': {
+        if (size == 2)
+          return float16;
+        else if (size == 4)
+          return float32;
+      }
+      case 'c': {
+        return complex64;
+      }
+    }
+  }
+
+  throw std::invalid_argument(
+      "[from_str] Invalid array protocol type-string: " + std::string(t));
+}
+
 } // namespace
 
 /** Save array to out stream in .npy format */
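These two functions implement NumPy's array-protocol typestring convention; a few round trips on a little-endian machine, consistent with the unit-test expectations further below (a sketch assuming the dtype names are in scope via namespace mlx::core):

assert(dtype_to_array_protocol(float32) == "<f4");   // multi-byte: '<' prefix
assert(dtype_to_array_protocol(bool_) == "|b1");     // single-byte: '|' prefix
assert(dtype_from_array_protocol("<i8") == int64);
assert(dtype_from_array_protocol("V2") == bfloat16); // "<V2" is also accepted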
10  mlx/utils.h
@@ -84,16 +84,6 @@ int check_shape_dim(const T dim) {
   return static_cast<int>(dim);
 }
 
-inline bool is_big_endian() {
-  union ByteOrder {
-    int32_t i;
-    uint8_t c[4];
-  };
-  ByteOrder b = {0x01234567};
-
-  return b.c[0] == 0x01;
-}
-
 /**
  * Returns the axis normalized to be in the range [0, ndim).
  * Based on numpy's normalize_axis_index. See
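The union trick removed here (and re-added privately in load.cpp above) inspects the first byte of 0x01234567 in memory. As an aside, C++20 offers a constexpr equivalent, sketched below; the commit itself keeps the union form:

#include <bit>

constexpr bool is_big_endian_cpp20() {
  return std::endian::native == std::endian::big;
}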
@@ -161,6 +161,8 @@ TEST_CASE("test array types") {
   // bfloat16
   { basic_dtype_test(bfloat16_t, bfloat16); }
 
+#undef basic_dtype_test
+
   // uint32
   {
     uint32_t val = UINT_MAX;
@@ -233,31 +235,6 @@ TEST_CASE("test array types") {
     CHECK_EQ(x.dtype(), complex64);
     CHECK_EQ(x.item<complex64_t>(), v);
   }
 
-#undef basic_dtype_test
-
-#define basic_dtype_str_test(s, dtype) \
-  CHECK_EQ(s, dtype_to_array_protocol(dtype)); \
-  CHECK_EQ(dtype_from_array_protocol(s), dtype);
-
-  // To and from str
-  {
-    basic_dtype_str_test("|b1", bool_);
-    basic_dtype_str_test("|u1", uint8);
-    basic_dtype_str_test("<u2", uint16);
-    basic_dtype_str_test("<u4", uint32);
-    basic_dtype_str_test("<u8", uint64);
-    basic_dtype_str_test("|i1", int8);
-    basic_dtype_str_test("<i2", int16);
-    basic_dtype_str_test("<i4", int32);
-    basic_dtype_str_test("<i8", int64);
-    basic_dtype_str_test("<f2", float16);
-    basic_dtype_str_test("<f4", float32);
-    basic_dtype_str_test("<V2", bfloat16);
-    basic_dtype_str_test("<c8", complex64);
-  }
-
-#undef basic_dtype_str_test
 }
 
 TEST_CASE("test array metadata") {