Add gemv masked to JIT plus some fixes (#1310)

* add gemv masked to JIT plus some fixes

* some cleanup

* add utils

* fix

* fix 2

* more cleaning

* fix

* remove unused mps matmul support

* one more nit

* revert
This commit is contained in:
Awni Hannun
2024-08-07 13:38:07 -07:00
committed by GitHub
parent 635ccd9e25
commit 30bbea2f08
25 changed files with 1230 additions and 1702 deletions

View File

@@ -14,7 +14,6 @@
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/metal.h"
#include "mlx/backend/metal/metal_impl.h"
#include "mlx/backend/metal/mps/gemm.h"
#include "mlx/backend/metal/utils.h"
namespace fs = std::filesystem;
@@ -39,6 +38,20 @@ constexpr auto get_metal_version() {
#endif
}
std::string get_colocated_mtllib_path(const std::string& lib_name) {
Dl_info info;
std::string mtllib_path;
std::string lib_ext = lib_name + ".metallib";
int success = dladdr((void*)get_colocated_mtllib_path, &info);
if (success) {
auto mtllib = fs::path(info.dli_fname).remove_filename() / lib_ext;
mtllib_path = mtllib.c_str();
}
return mtllib_path;
}
auto load_device() {
auto devices = MTL::CopyAllDevices();
auto device = static_cast<MTL::Device*>(devices->object(0))
@@ -126,6 +139,49 @@ MTL::Library* load_library(
} // namespace
CommandEncoder::CommandEncoder(MTL::CommandBuffer* cbuf) : cbuf(cbuf) {
enc = cbuf->computeCommandEncoder(MTL::DispatchTypeConcurrent);
enc->retain();
}
CommandEncoder::~CommandEncoder() {
enc->endEncoding();
enc->release();
}
void CommandEncoder::set_input_array(
const array& a,
int idx,
int64_t offset /* = 0 */) {
auto r_buf = static_cast<MTL::Resource*>(const_cast<void*>(a.buffer().ptr()));
if (auto it = outputs.find(r_buf); it != outputs.end()) {
// Insert a barrier
enc->memoryBarrier(&r_buf, 1);
// Remove the output
outputs.erase(it);
}
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
auto base_offset = a.data<char>() -
static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
base_offset += offset;
enc->setBuffer(a_buf, base_offset, idx);
}
void CommandEncoder::set_output_array(
array& a,
int idx,
int64_t offset /* = 0 */) {
// Add barriers before adding the output to the output set
set_input_array(a, idx, offset);
auto buf = static_cast<MTL::Resource*>(a.buffer().ptr());
if (concurrent) {
concurrent_outputs.insert(buf);
} else {
outputs.insert(buf);
}
}
void CommandEncoder::dispatchThreadgroups(
MTL::Size grid_dims,
MTL::Size group_dims) {
@@ -255,13 +311,9 @@ void Device::register_library(
}
}
void Device::register_library(
const std::string& lib_name,
const std::function<std::string(const std::string&)>& lib_path_func) {
void Device::register_library(const std::string& lib_name) {
if (auto it = library_map_.find(lib_name); it == library_map_.end()) {
std::string new_lib_path = lib_path_func(lib_name);
auto new_lib = load_library(device_, lib_name, new_lib_path.c_str());
library_map_.insert({lib_name, new_lib});
register_library(lib_name, get_colocated_mtllib_path(lib_name));
}
}
@@ -271,7 +323,7 @@ MTL::Library* Device::get_library_cache_(const std::string& lib_name) {
if (auto it = library_map_.find(lib_name); it != library_map_.end()) {
mtl_lib = it->second;
} else { // Look for metallib alongside library
register_library(lib_name);
register_library(lib_name, get_colocated_mtllib_path(lib_name));
mtl_lib = library_map_[lib_name];
}