mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-27 03:21:19 +08:00
Make the GPU device more thread safe (#1478)
* gpu stream safety * comment * fix
This commit is contained in:
parent
c21331d47f
commit
bf6ec92216
@ -202,10 +202,7 @@ void Compiled::eval_gpu(
|
|||||||
// Get the kernel if someone else built it already
|
// Get the kernel if someone else built it already
|
||||||
auto& s = stream();
|
auto& s = stream();
|
||||||
auto& d = metal::device(s.device);
|
auto& d = metal::device(s.device);
|
||||||
auto lib = d.get_library(kernel_lib_);
|
auto lib = d.get_library(kernel_lib_, [&]() {
|
||||||
|
|
||||||
// If not we have to build it ourselves
|
|
||||||
if (lib == nullptr) {
|
|
||||||
std::ostringstream kernel;
|
std::ostringstream kernel;
|
||||||
kernel << metal::utils() << metal::unary_ops() << metal::binary_ops()
|
kernel << metal::utils() << metal::unary_ops() << metal::binary_ops()
|
||||||
<< metal::ternary_ops();
|
<< metal::ternary_ops();
|
||||||
@ -252,9 +249,8 @@ void Compiled::eval_gpu(
|
|||||||
/* contiguous = */ false,
|
/* contiguous = */ false,
|
||||||
/* ndim = */ 0,
|
/* ndim = */ 0,
|
||||||
/* dynamic_dims = */ true);
|
/* dynamic_dims = */ true);
|
||||||
|
return kernel.str();
|
||||||
lib = d.get_library(kernel_lib_, kernel.str());
|
});
|
||||||
}
|
|
||||||
|
|
||||||
// Figure out which kernel we are using
|
// Figure out which kernel we are using
|
||||||
auto& output_shape = outputs[0].shape();
|
auto& output_shape = outputs[0].shape();
|
||||||
|
@ -39,10 +39,8 @@ void CustomKernel::eval_gpu(
|
|||||||
|
|
||||||
auto& d = metal::device(s.device);
|
auto& d = metal::device(s.device);
|
||||||
const auto& lib_name = name_;
|
const auto& lib_name = name_;
|
||||||
auto lib = d.get_library(lib_name);
|
auto lib =
|
||||||
if (lib == nullptr) {
|
d.get_library(lib_name, [this] { return metal::utils() + source_; });
|
||||||
lib = d.get_library(lib_name, metal::utils() + source_);
|
|
||||||
}
|
|
||||||
auto kernel = d.get_kernel(name_, lib);
|
auto kernel = d.get_kernel(name_, lib);
|
||||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||||
compute_encoder->setComputePipelineState(kernel);
|
compute_encoder->setComputePipelineState(kernel);
|
||||||
|
@ -219,7 +219,6 @@ void Device::new_queue(int index) {
|
|||||||
|
|
||||||
// Multiple threads can ask the device for queues
|
// Multiple threads can ask the device for queues
|
||||||
// We lock this as a critical section for safety
|
// We lock this as a critical section for safety
|
||||||
const std::lock_guard<std::mutex> lock(mtx_);
|
|
||||||
auto q = device_->newCommandQueue(MAX_BUFFERS_PER_QUEUE);
|
auto q = device_->newCommandQueue(MAX_BUFFERS_PER_QUEUE);
|
||||||
debug_set_stream_queue_label(q, index);
|
debug_set_stream_queue_label(q, index);
|
||||||
if (!q) {
|
if (!q) {
|
||||||
@ -227,21 +226,21 @@ void Device::new_queue(int index) {
|
|||||||
"[metal::Device] Failed to make new command queue.");
|
"[metal::Device] Failed to make new command queue.");
|
||||||
}
|
}
|
||||||
queue_map_.insert({index, q});
|
queue_map_.insert({index, q});
|
||||||
|
buffer_map_.insert({index, {0, nullptr}});
|
||||||
|
encoder_map_.insert({index, nullptr});
|
||||||
}
|
}
|
||||||
|
|
||||||
int Device::get_command_buffer_ops(int index) {
|
int Device::get_command_buffer_ops(int index) {
|
||||||
auto bit = buffer_map_.find(index);
|
return buffer_map_[index].first;
|
||||||
return bit->second.first;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void Device::increment_command_buffer_ops(int index) {
|
void Device::increment_command_buffer_ops(int index) {
|
||||||
auto bit = buffer_map_.find(index);
|
buffer_map_[index].first++;
|
||||||
bit->second.first++;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
MTL::CommandBuffer* Device::get_command_buffer(int index) {
|
MTL::CommandBuffer* Device::get_command_buffer(int index) {
|
||||||
auto bit = buffer_map_.find(index);
|
auto bit = buffer_map_.find(index);
|
||||||
if (bit == buffer_map_.end()) {
|
if (bit->second.second == nullptr) {
|
||||||
auto qit = queue_map_.find(index);
|
auto qit = queue_map_.find(index);
|
||||||
if (qit == queue_map_.end()) {
|
if (qit == queue_map_.end()) {
|
||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
@ -258,7 +257,7 @@ MTL::CommandBuffer* Device::get_command_buffer(int index) {
|
|||||||
// Increment ref count so the buffer is not garbage collected
|
// Increment ref count so the buffer is not garbage collected
|
||||||
cb->retain();
|
cb->retain();
|
||||||
|
|
||||||
bit = buffer_map_.insert({index, {0, cb}}).first;
|
bit->second = {0, cb};
|
||||||
}
|
}
|
||||||
return bit->second.second;
|
return bit->second.second;
|
||||||
}
|
}
|
||||||
@ -267,19 +266,18 @@ void Device::commit_command_buffer(int index) {
|
|||||||
auto bit = buffer_map_.find(index);
|
auto bit = buffer_map_.find(index);
|
||||||
bit->second.second->commit();
|
bit->second.second->commit();
|
||||||
bit->second.second->release();
|
bit->second.second->release();
|
||||||
buffer_map_.erase(bit);
|
bit->second = {0, nullptr};
|
||||||
}
|
}
|
||||||
|
|
||||||
void Device::end_encoding(int index) {
|
void Device::end_encoding(int index) {
|
||||||
encoder_map_.erase(index);
|
encoder_map_[index] = nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
CommandEncoder& Device::get_command_encoder(int index) {
|
CommandEncoder& Device::get_command_encoder(int index) {
|
||||||
auto eit = encoder_map_.find(index);
|
auto eit = encoder_map_.find(index);
|
||||||
if (eit == encoder_map_.end()) {
|
if (eit->second == nullptr) {
|
||||||
auto cb = get_command_buffer(index);
|
auto cb = get_command_buffer(index);
|
||||||
eit =
|
eit->second = std::make_unique<CommandEncoder>(cb);
|
||||||
encoder_map_.emplace(index, std::make_unique<CommandEncoder>(cb)).first;
|
|
||||||
}
|
}
|
||||||
return *(eit->second);
|
return *(eit->second);
|
||||||
}
|
}
|
||||||
@ -293,20 +291,7 @@ void Device::register_library(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
MTL::Library* Device::get_library_cache_(const std::string& lib_name) {
|
MTL::Library* Device::build_library_(const std::string& source_string) {
|
||||||
// Search for cached metal lib
|
|
||||||
MTL::Library* mtl_lib;
|
|
||||||
if (auto it = library_map_.find(lib_name); it != library_map_.end()) {
|
|
||||||
mtl_lib = it->second;
|
|
||||||
} else { // Look for metallib alongside library
|
|
||||||
register_library(lib_name, get_colocated_mtllib_path(lib_name));
|
|
||||||
mtl_lib = library_map_[lib_name];
|
|
||||||
}
|
|
||||||
|
|
||||||
return mtl_lib;
|
|
||||||
}
|
|
||||||
|
|
||||||
MTL::Library* Device::get_library_(const std::string& source_string) {
|
|
||||||
auto pool = new_scoped_memory_pool();
|
auto pool = new_scoped_memory_pool();
|
||||||
|
|
||||||
auto ns_code =
|
auto ns_code =
|
||||||
@ -332,25 +317,6 @@ MTL::Library* Device::get_library_(const std::string& source_string) {
|
|||||||
return mtl_lib;
|
return mtl_lib;
|
||||||
}
|
}
|
||||||
|
|
||||||
MTL::Library* Device::get_library_(const MTL::StitchedLibraryDescriptor* desc) {
|
|
||||||
auto pool = new_scoped_memory_pool();
|
|
||||||
|
|
||||||
NS::Error* error = nullptr;
|
|
||||||
auto mtl_lib = device_->newLibrary(desc, &error);
|
|
||||||
|
|
||||||
// Throw error if unable to compile library
|
|
||||||
if (!mtl_lib) {
|
|
||||||
std::ostringstream msg;
|
|
||||||
msg << "[metal::Device] Unable to build stitched metal library" << "\n";
|
|
||||||
if (error) {
|
|
||||||
msg << error->localizedDescription()->utf8String() << "\n";
|
|
||||||
}
|
|
||||||
throw std::runtime_error(msg.str());
|
|
||||||
}
|
|
||||||
|
|
||||||
return mtl_lib;
|
|
||||||
}
|
|
||||||
|
|
||||||
MTL::Function* Device::get_function_(
|
MTL::Function* Device::get_function_(
|
||||||
const std::string& name,
|
const std::string& name,
|
||||||
MTL::Library* mtl_lib) {
|
MTL::Library* mtl_lib) {
|
||||||
@ -465,68 +431,32 @@ MTL::ComputePipelineState* Device::get_kernel_(
|
|||||||
return kernel;
|
return kernel;
|
||||||
}
|
}
|
||||||
|
|
||||||
MTL::Library* Device::get_library(const std::string& name) {
|
MTL::Library* Device::get_library_(const std::string& name) {
|
||||||
|
std::shared_lock lock(library_mtx_);
|
||||||
auto it = library_map_.find(name);
|
auto it = library_map_.find(name);
|
||||||
return (it != library_map_.end()) ? it->second : nullptr;
|
return (it != library_map_.end()) ? it->second : nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
MTL::Library* Device::get_library(
|
MTL::Library* Device::get_library(
|
||||||
const std::string& name,
|
const std::string& name,
|
||||||
const std::string& source,
|
const std::function<std::string(void)>& builder) {
|
||||||
bool cache /* = true */) {
|
{
|
||||||
if (cache) {
|
std::shared_lock rlock(library_mtx_);
|
||||||
if (auto it = library_map_.find(name); it != library_map_.end()) {
|
if (auto it = library_map_.find(name); it != library_map_.end()) {
|
||||||
return it->second;
|
return it->second;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
auto mtl_lib = get_library_(source);
|
std::unique_lock wlock(library_mtx_);
|
||||||
|
if (auto it = library_map_.find(name); it != library_map_.end()) {
|
||||||
if (cache) {
|
return it->second;
|
||||||
library_map_.insert({name, mtl_lib});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
auto mtl_lib = build_library_(builder());
|
||||||
|
library_map_.insert({name, mtl_lib});
|
||||||
return mtl_lib;
|
return mtl_lib;
|
||||||
}
|
}
|
||||||
|
|
||||||
MTL::Library* Device::get_library(
|
|
||||||
const std::string& name,
|
|
||||||
const MTL::StitchedLibraryDescriptor* desc,
|
|
||||||
bool cache /* = true */) {
|
|
||||||
if (cache) {
|
|
||||||
if (auto it = library_map_.find(name); it != library_map_.end()) {
|
|
||||||
return it->second;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
auto mtl_lib = get_library_(desc);
|
|
||||||
|
|
||||||
if (cache) {
|
|
||||||
library_map_.insert({name, mtl_lib});
|
|
||||||
}
|
|
||||||
|
|
||||||
return mtl_lib;
|
|
||||||
}
|
|
||||||
|
|
||||||
MTL::Function* Device::get_function(
|
|
||||||
const std::string& base_name,
|
|
||||||
MTL::Library* mtl_lib,
|
|
||||||
const std::string& specialized_name /* = "" */,
|
|
||||||
const MTLFCList& func_consts /* = {} */) {
|
|
||||||
return get_function_(base_name, specialized_name, func_consts, mtl_lib);
|
|
||||||
}
|
|
||||||
|
|
||||||
MTL::Function* Device::get_function(
|
|
||||||
const std::string& base_name,
|
|
||||||
const std::string& lib_name /* = "mlx" */,
|
|
||||||
const std::string& specialized_name /* = "" */,
|
|
||||||
const MTLFCList& func_consts /* = {} */) {
|
|
||||||
// Search for cached metal lib
|
|
||||||
MTL::Library* mtl_lib = get_library_cache_(lib_name);
|
|
||||||
|
|
||||||
return get_function(base_name, mtl_lib, specialized_name, func_consts);
|
|
||||||
}
|
|
||||||
|
|
||||||
MTL::LinkedFunctions* Device::get_linked_functions_(
|
MTL::LinkedFunctions* Device::get_linked_functions_(
|
||||||
const std::vector<MTL::Function*>& funcs) {
|
const std::vector<MTL::Function*>& funcs) {
|
||||||
if (funcs.empty()) {
|
if (funcs.empty()) {
|
||||||
@ -547,34 +477,55 @@ MTL::LinkedFunctions* Device::get_linked_functions_(
|
|||||||
return lfuncs;
|
return lfuncs;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
MTL::ComputePipelineState* Device::get_kernel_(
|
||||||
|
const std::string& base_name,
|
||||||
|
MTL::Library* mtl_lib,
|
||||||
|
const std::string& hash_name,
|
||||||
|
const MTLFCList& func_consts /* = {} */,
|
||||||
|
const std::vector<MTL::Function*>& linked_functions /* = {} */) {
|
||||||
|
// Single writer allowed
|
||||||
|
std::unique_lock wlock(kernel_mtx_);
|
||||||
|
|
||||||
|
// Try loading again to avoid loading twice
|
||||||
|
if (auto it = kernel_map_.find(hash_name); it != kernel_map_.end()) {
|
||||||
|
return it->second;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto pool = new_scoped_memory_pool();
|
||||||
|
|
||||||
|
// Pull kernel from library
|
||||||
|
auto mtl_function = get_function_(base_name, hash_name, func_consts, mtl_lib);
|
||||||
|
|
||||||
|
// Compile kernel to compute pipeline
|
||||||
|
auto mtl_linked_funcs = get_linked_functions_(linked_functions);
|
||||||
|
auto kernel = get_kernel_(hash_name, mtl_function, mtl_linked_funcs);
|
||||||
|
|
||||||
|
mtl_function->release();
|
||||||
|
mtl_linked_funcs->release();
|
||||||
|
|
||||||
|
// Add kernel to cache
|
||||||
|
auto inserted = kernel_map_.insert({hash_name, kernel});
|
||||||
|
|
||||||
|
return kernel;
|
||||||
|
}
|
||||||
|
|
||||||
MTL::ComputePipelineState* Device::get_kernel(
|
MTL::ComputePipelineState* Device::get_kernel(
|
||||||
const std::string& base_name,
|
const std::string& base_name,
|
||||||
MTL::Library* mtl_lib,
|
MTL::Library* mtl_lib,
|
||||||
const std::string& hash_name /* = "" */,
|
const std::string& hash_name /* = "" */,
|
||||||
const MTLFCList& func_consts /* = {} */,
|
const MTLFCList& func_consts /* = {} */,
|
||||||
const std::vector<MTL::Function*>& linked_functions /* = {} */) {
|
const std::vector<MTL::Function*>& linked_functions /* = {} */) {
|
||||||
auto pool = new_scoped_memory_pool();
|
|
||||||
|
|
||||||
// Look for cached kernel
|
|
||||||
const auto& kname = hash_name.empty() ? base_name : hash_name;
|
const auto& kname = hash_name.empty() ? base_name : hash_name;
|
||||||
if (auto it = kernel_map_.find(kname); it != kernel_map_.end()) {
|
{
|
||||||
return it->second;
|
// Multiple readers allowed
|
||||||
|
std::shared_lock lock(kernel_mtx_);
|
||||||
|
|
||||||
|
// Look for cached kernel
|
||||||
|
if (auto it = kernel_map_.find(kname); it != kernel_map_.end()) {
|
||||||
|
return it->second;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
return get_kernel_(base_name, mtl_lib, kname, func_consts, linked_functions);
|
||||||
// Pull kernel from library
|
|
||||||
auto mtl_function = get_function_(base_name, kname, func_consts, mtl_lib);
|
|
||||||
|
|
||||||
// Compile kernel to compute pipeline
|
|
||||||
auto mtl_linked_funcs = get_linked_functions_(linked_functions);
|
|
||||||
auto kernel = get_kernel_(kname, mtl_function, mtl_linked_funcs);
|
|
||||||
|
|
||||||
mtl_function->release();
|
|
||||||
mtl_linked_funcs->release();
|
|
||||||
|
|
||||||
// Add kernel to cache
|
|
||||||
kernel_map_.insert({kname, kernel});
|
|
||||||
|
|
||||||
return kernel;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
MTL::ComputePipelineState* Device::get_kernel(
|
MTL::ComputePipelineState* Device::get_kernel(
|
||||||
@ -583,16 +534,19 @@ MTL::ComputePipelineState* Device::get_kernel(
|
|||||||
const std::string& hash_name /* = "" */,
|
const std::string& hash_name /* = "" */,
|
||||||
const MTLFCList& func_consts /* = {} */,
|
const MTLFCList& func_consts /* = {} */,
|
||||||
const std::vector<MTL::Function*>& linked_functions /* = {} */) {
|
const std::vector<MTL::Function*>& linked_functions /* = {} */) {
|
||||||
// Look for cached kernel
|
|
||||||
const auto& kname = hash_name.size() == 0 ? base_name : hash_name;
|
const auto& kname = hash_name.size() == 0 ? base_name : hash_name;
|
||||||
if (auto it = kernel_map_.find(kname); it != kernel_map_.end()) {
|
{
|
||||||
return it->second;
|
// Multiple readers allowed
|
||||||
|
std::shared_lock lock(kernel_mtx_);
|
||||||
|
|
||||||
|
// Look for cached kernel
|
||||||
|
if (auto it = kernel_map_.find(kname); it != kernel_map_.end()) {
|
||||||
|
return it->second;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Search for cached metal lib
|
// Search for cached metal lib
|
||||||
MTL::Library* mtl_lib = get_library_cache_(lib_name);
|
MTL::Library* mtl_lib = get_library_(lib_name);
|
||||||
|
return get_kernel_(base_name, mtl_lib, kname, func_consts, linked_functions);
|
||||||
return get_kernel(base_name, mtl_lib, kname, func_consts, linked_functions);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Device& device(mlx::core::Device) {
|
Device& device(mlx::core::Device) {
|
||||||
|
@ -7,6 +7,7 @@
|
|||||||
#include <filesystem>
|
#include <filesystem>
|
||||||
#include <functional>
|
#include <functional>
|
||||||
#include <mutex>
|
#include <mutex>
|
||||||
|
#include <shared_mutex>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include <unordered_set>
|
#include <unordered_set>
|
||||||
@ -114,29 +115,9 @@ class Device {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
MTL::Library* get_library(const std::string& name);
|
|
||||||
|
|
||||||
MTL::Library* get_library(
|
MTL::Library* get_library(
|
||||||
const std::string& name,
|
const std::string& name,
|
||||||
const std::string& source_string,
|
const std::function<std::string(void)>& builder);
|
||||||
bool cache = true);
|
|
||||||
|
|
||||||
MTL::Library* get_library(
|
|
||||||
const std::string& name,
|
|
||||||
const MTL::StitchedLibraryDescriptor* desc,
|
|
||||||
bool cache = true);
|
|
||||||
|
|
||||||
MTL::Function* get_function(
|
|
||||||
const std::string& base_name,
|
|
||||||
MTL::Library* mtl_lib,
|
|
||||||
const std::string& specialized_name = "",
|
|
||||||
const MTLFCList& func_consts = {});
|
|
||||||
|
|
||||||
MTL::Function* get_function(
|
|
||||||
const std::string& base_name,
|
|
||||||
const std::string& lib_name = "mlx",
|
|
||||||
const std::string& specialized_name = "",
|
|
||||||
const MTLFCList& func_consts = {});
|
|
||||||
|
|
||||||
MTL::ComputePipelineState* get_kernel(
|
MTL::ComputePipelineState* get_kernel(
|
||||||
const std::string& base_name,
|
const std::string& base_name,
|
||||||
@ -158,8 +139,8 @@ class Device {
|
|||||||
private:
|
private:
|
||||||
MTL::Library* get_library_cache_(const std::string& name);
|
MTL::Library* get_library_cache_(const std::string& name);
|
||||||
|
|
||||||
MTL::Library* get_library_(const std::string& source_string);
|
MTL::Library* get_library_(const std::string& name);
|
||||||
MTL::Library* get_library_(const MTL::StitchedLibraryDescriptor* desc);
|
MTL::Library* build_library_(const std::string& source_string);
|
||||||
|
|
||||||
MTL::Function* get_function_(const std::string& name, MTL::Library* mtl_lib);
|
MTL::Function* get_function_(const std::string& name, MTL::Library* mtl_lib);
|
||||||
|
|
||||||
@ -181,13 +162,23 @@ class Device {
|
|||||||
const MTL::Function* mtl_function,
|
const MTL::Function* mtl_function,
|
||||||
const MTL::LinkedFunctions* linked_functions);
|
const MTL::LinkedFunctions* linked_functions);
|
||||||
|
|
||||||
|
MTL::ComputePipelineState* get_kernel_(
|
||||||
|
const std::string& base_name,
|
||||||
|
MTL::Library* mtl_lib,
|
||||||
|
const std::string& hash_name,
|
||||||
|
const MTLFCList& func_consts = {},
|
||||||
|
const std::vector<MTL::Function*>& linked_functions = {});
|
||||||
|
|
||||||
MTL::Device* device_;
|
MTL::Device* device_;
|
||||||
std::unordered_map<int32_t, MTL::CommandQueue*> queue_map_;
|
std::unordered_map<int32_t, MTL::CommandQueue*> queue_map_;
|
||||||
std::unordered_map<int32_t, std::pair<int, MTL::CommandBuffer*>> buffer_map_;
|
std::unordered_map<int32_t, std::pair<int, MTL::CommandBuffer*>> buffer_map_;
|
||||||
std::unordered_map<int32_t, std::unique_ptr<CommandEncoder>> encoder_map_;
|
std::unordered_map<int32_t, std::unique_ptr<CommandEncoder>> encoder_map_;
|
||||||
|
|
||||||
|
std::shared_mutex kernel_mtx_;
|
||||||
std::unordered_map<std::string, MTL::ComputePipelineState*> kernel_map_;
|
std::unordered_map<std::string, MTL::ComputePipelineState*> kernel_map_;
|
||||||
|
|
||||||
|
std::shared_mutex library_mtx_;
|
||||||
std::unordered_map<std::string, MTL::Library*> library_map_;
|
std::unordered_map<std::string, MTL::Library*> library_map_;
|
||||||
std::mutex mtx_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
Device& device(mlx::core::Device);
|
Device& device(mlx::core::Device);
|
||||||
|
@ -60,32 +60,6 @@ std::string gen_hadamard_codelet(int m) {
|
|||||||
return source.str();
|
return source.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
void launch_hadamard(
|
|
||||||
const array& in,
|
|
||||||
array& out,
|
|
||||||
int batch_size,
|
|
||||||
int threads_per,
|
|
||||||
const std::string kernel_name,
|
|
||||||
float scale,
|
|
||||||
const Stream& s) {
|
|
||||||
auto& d = metal::device(s.device);
|
|
||||||
|
|
||||||
const auto& lib_name = kernel_name.substr(1);
|
|
||||||
auto lib = d.get_library(lib_name);
|
|
||||||
auto kernel = d.get_kernel(kernel_name, lib);
|
|
||||||
assert(threads_per <= kernel->maxTotalThreadsPerThreadgroup());
|
|
||||||
|
|
||||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
|
||||||
compute_encoder->setComputePipelineState(kernel);
|
|
||||||
compute_encoder.set_input_array(in, 0);
|
|
||||||
compute_encoder.set_output_array(out, 1);
|
|
||||||
compute_encoder->setBytes(&scale, sizeof(float), 2);
|
|
||||||
|
|
||||||
MTL::Size group_dims = MTL::Size(1, threads_per, 1);
|
|
||||||
MTL::Size grid_dims = MTL::Size(batch_size, threads_per, 1);
|
|
||||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
|
||||||
}
|
|
||||||
|
|
||||||
void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
|
void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
auto& s = stream();
|
auto& s = stream();
|
||||||
|
|
||||||
@ -113,7 +87,8 @@ void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||||
}
|
}
|
||||||
|
|
||||||
auto [n, m] = decompose_hadamard(in.shape(axis));
|
int n, m;
|
||||||
|
std::tie(n, m) = decompose_hadamard(in.shape(axis));
|
||||||
|
|
||||||
if (n * (int)size_of(in.dtype()) > MAX_HADAMARD_BYTES) {
|
if (n * (int)size_of(in.dtype()) > MAX_HADAMARD_BYTES) {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
@ -129,8 +104,7 @@ void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
auto kernel_name = kname.str();
|
auto kernel_name = kname.str();
|
||||||
auto& d = metal::device(s.device);
|
auto& d = metal::device(s.device);
|
||||||
const auto& lib_name = kernel_name;
|
const auto& lib_name = kernel_name;
|
||||||
auto lib = d.get_library(lib_name);
|
auto lib = d.get_library(lib_name, [&]() {
|
||||||
if (lib == nullptr) {
|
|
||||||
std::ostringstream kernel_source;
|
std::ostringstream kernel_source;
|
||||||
auto codelet = gen_hadamard_codelet(m);
|
auto codelet = gen_hadamard_codelet(m);
|
||||||
kernel_source << metal::utils() << codelet << metal::hadamard();
|
kernel_source << metal::utils() << codelet << metal::hadamard();
|
||||||
@ -148,12 +122,31 @@ void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
n,
|
n,
|
||||||
m,
|
m,
|
||||||
read_width);
|
read_width);
|
||||||
lib = d.get_library(lib_name, kernel_source.str());
|
return kernel_source.str();
|
||||||
}
|
});
|
||||||
|
|
||||||
int batch_size = in.size() / n;
|
int batch_size = in.size() / n;
|
||||||
int threads_per = n / max_radix;
|
int threads_per = n / max_radix;
|
||||||
|
|
||||||
|
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||||
|
|
||||||
|
auto launch_hadamard = [&](const array& in,
|
||||||
|
array& out,
|
||||||
|
const std::string& kernel_name,
|
||||||
|
float scale) {
|
||||||
|
auto kernel = d.get_kernel(kernel_name, lib);
|
||||||
|
assert(threads_per <= kernel->maxTotalThreadsPerThreadgroup());
|
||||||
|
|
||||||
|
compute_encoder->setComputePipelineState(kernel);
|
||||||
|
compute_encoder.set_input_array(in, 0);
|
||||||
|
compute_encoder.set_output_array(out, 1);
|
||||||
|
compute_encoder->setBytes(&scale, sizeof(float), 2);
|
||||||
|
|
||||||
|
MTL::Size group_dims = MTL::Size(1, threads_per, 1);
|
||||||
|
MTL::Size grid_dims = MTL::Size(batch_size, threads_per, 1);
|
||||||
|
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||||
|
};
|
||||||
|
|
||||||
if (m > 1) {
|
if (m > 1) {
|
||||||
// When m is greater than 1, we decompose the
|
// When m is greater than 1, we decompose the
|
||||||
// computation into two uploads to the GPU:
|
// computation into two uploads to the GPU:
|
||||||
@ -171,29 +164,14 @@ void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
temp.set_data(allocator::malloc_or_wait(temp.nbytes()));
|
temp.set_data(allocator::malloc_or_wait(temp.nbytes()));
|
||||||
copies.push_back(temp);
|
copies.push_back(temp);
|
||||||
|
|
||||||
launch_hadamard(
|
launch_hadamard(in_contiguous, temp, "n" + kernel_name, 1.0);
|
||||||
in_contiguous,
|
|
||||||
temp,
|
|
||||||
batch_size,
|
|
||||||
threads_per,
|
|
||||||
"n" + kernel_name,
|
|
||||||
1.0,
|
|
||||||
s);
|
|
||||||
|
|
||||||
// Metal sometimes reports 256 max threads per group for hadamard_m kernel
|
// Metal sometimes reports 256 max threads per group for hadamard_m kernel
|
||||||
threads_per = std::min(n / read_width, MAX_HADAMARD_THREADS_PER_GROUP);
|
threads_per = std::min(n / read_width, MAX_HADAMARD_THREADS_PER_GROUP);
|
||||||
batch_size = in.size() / m / read_width / threads_per;
|
batch_size = in.size() / m / read_width / threads_per;
|
||||||
launch_hadamard(
|
launch_hadamard(temp, out, "m" + kernel_name, scale_);
|
||||||
temp, out, batch_size, threads_per, "m" + kernel_name, scale_, s);
|
|
||||||
} else {
|
} else {
|
||||||
launch_hadamard(
|
launch_hadamard(in_contiguous, out, "n" + kernel_name, scale_);
|
||||||
in_contiguous,
|
|
||||||
out,
|
|
||||||
batch_size,
|
|
||||||
threads_per,
|
|
||||||
"n" + kernel_name,
|
|
||||||
scale_,
|
|
||||||
s);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!copies.empty()) {
|
if (!copies.empty()) {
|
||||||
|
@ -64,8 +64,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
kernel_name = lib_name;
|
kernel_name = lib_name;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto lib = d.get_library(lib_name);
|
auto lib = d.get_library(lib_name, [&]() {
|
||||||
if (lib == nullptr) {
|
|
||||||
std::ostringstream kernel_source;
|
std::ostringstream kernel_source;
|
||||||
kernel_source << metal::utils() << metal::gather();
|
kernel_source << metal::utils() << metal::gather();
|
||||||
std::string out_type_str = get_type_string(out.dtype());
|
std::string out_type_str = get_type_string(out.dtype());
|
||||||
@ -83,8 +82,8 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
idx_args,
|
idx_args,
|
||||||
idx_arr,
|
idx_arr,
|
||||||
idx_ndim);
|
idx_ndim);
|
||||||
lib = d.get_library(lib_name, kernel_source.str());
|
return kernel_source.str();
|
||||||
}
|
});
|
||||||
|
|
||||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||||
auto kernel = d.get_kernel(kernel_name, lib);
|
auto kernel = d.get_kernel(kernel_name, lib);
|
||||||
@ -236,8 +235,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
kernel_name = kname.str();
|
kernel_name = kname.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
auto lib = d.get_library(lib_name);
|
auto lib = d.get_library(lib_name, [&]() {
|
||||||
if (lib == nullptr) {
|
|
||||||
std::ostringstream kernel_source;
|
std::ostringstream kernel_source;
|
||||||
kernel_source << metal::utils() << metal::reduce_utils()
|
kernel_source << metal::utils() << metal::reduce_utils()
|
||||||
<< metal::scatter();
|
<< metal::scatter();
|
||||||
@ -277,8 +275,8 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
nidx,
|
nidx,
|
||||||
idx_args,
|
idx_args,
|
||||||
idx_arr);
|
idx_arr);
|
||||||
lib = d.get_library(lib_name, kernel_source.str());
|
return kernel_source.str();
|
||||||
}
|
});
|
||||||
|
|
||||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||||
auto kernel = d.get_kernel(kernel_name, lib);
|
auto kernel = d.get_kernel(kernel_name, lib);
|
||||||
|
@ -25,15 +25,15 @@ MTL::ComputePipelineState* get_arange_kernel(
|
|||||||
metal::Device& d,
|
metal::Device& d,
|
||||||
const std::string& kernel_name,
|
const std::string& kernel_name,
|
||||||
const array& out) {
|
const array& out) {
|
||||||
const auto& lib_name = kernel_name;
|
auto lib = d.get_library(kernel_name, [&]() {
|
||||||
auto lib = d.get_library(lib_name);
|
|
||||||
if (lib == nullptr) {
|
|
||||||
std::ostringstream kernel_source;
|
std::ostringstream kernel_source;
|
||||||
kernel_source
|
kernel_source << metal::utils() << metal::arange()
|
||||||
<< metal::utils() << metal::arange()
|
<< fmt::format(
|
||||||
<< fmt::format(arange_kernels, lib_name, get_type_string(out.dtype()));
|
arange_kernels,
|
||||||
lib = d.get_library(lib_name, kernel_source.str());
|
kernel_name,
|
||||||
}
|
get_type_string(out.dtype()));
|
||||||
|
return kernel_source.str();
|
||||||
|
});
|
||||||
return d.get_kernel(kernel_name, lib);
|
return d.get_kernel(kernel_name, lib);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -43,8 +43,7 @@ MTL::ComputePipelineState* get_unary_kernel(
|
|||||||
Dtype out_type,
|
Dtype out_type,
|
||||||
const std::string op) {
|
const std::string op) {
|
||||||
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
|
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
|
||||||
auto lib = d.get_library(lib_name);
|
auto lib = d.get_library(lib_name, [&]() {
|
||||||
if (lib == nullptr) {
|
|
||||||
std::ostringstream kernel_source;
|
std::ostringstream kernel_source;
|
||||||
kernel_source << metal::utils() << metal::unary_ops() << metal::unary();
|
kernel_source << metal::utils() << metal::unary_ops() << metal::unary();
|
||||||
kernel_source << get_template_definition(
|
kernel_source << get_template_definition(
|
||||||
@ -55,8 +54,8 @@ MTL::ComputePipelineState* get_unary_kernel(
|
|||||||
"g_" + lib_name, "unary_g", get_type_string(out_type), op);
|
"g_" + lib_name, "unary_g", get_type_string(out_type), op);
|
||||||
kernel_source << get_template_definition(
|
kernel_source << get_template_definition(
|
||||||
"gn4_" + lib_name, "unary_g", get_type_string(out_type), op, 4);
|
"gn4_" + lib_name, "unary_g", get_type_string(out_type), op, 4);
|
||||||
lib = d.get_library(lib_name, kernel_source.str());
|
return kernel_source.str();
|
||||||
}
|
});
|
||||||
return d.get_kernel(kernel_name, lib);
|
return d.get_kernel(kernel_name, lib);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -105,13 +104,12 @@ MTL::ComputePipelineState* get_binary_kernel(
|
|||||||
Dtype out_type,
|
Dtype out_type,
|
||||||
const std::string op) {
|
const std::string op) {
|
||||||
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
|
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
|
||||||
auto lib = d.get_library(lib_name);
|
auto lib = d.get_library(lib_name, [&]() {
|
||||||
if (lib == nullptr) {
|
|
||||||
std::ostringstream kernel_source;
|
std::ostringstream kernel_source;
|
||||||
kernel_source << metal::utils() << metal::binary_ops() << metal::binary();
|
kernel_source << metal::utils() << metal::binary_ops() << metal::binary();
|
||||||
add_binary_kernels(lib_name, in_type, out_type, op, kernel_source);
|
add_binary_kernels(lib_name, in_type, out_type, op, kernel_source);
|
||||||
lib = d.get_library(lib_name, kernel_source.str());
|
return kernel_source.str();
|
||||||
}
|
});
|
||||||
return d.get_kernel(kernel_name, lib);
|
return d.get_kernel(kernel_name, lib);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -122,14 +120,13 @@ MTL::ComputePipelineState* get_binary_two_kernel(
|
|||||||
Dtype out_type,
|
Dtype out_type,
|
||||||
const std::string op) {
|
const std::string op) {
|
||||||
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
|
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
|
||||||
auto lib = d.get_library(lib_name);
|
auto lib = d.get_library(lib_name, [&]() {
|
||||||
if (lib == nullptr) {
|
|
||||||
std::ostringstream kernel_source;
|
std::ostringstream kernel_source;
|
||||||
kernel_source << metal::utils() << metal::binary_ops()
|
kernel_source << metal::utils() << metal::binary_ops()
|
||||||
<< metal::binary_two();
|
<< metal::binary_two();
|
||||||
add_binary_kernels(lib_name, in_type, out_type, op, kernel_source);
|
add_binary_kernels(lib_name, in_type, out_type, op, kernel_source);
|
||||||
lib = d.get_library(lib_name, kernel_source.str());
|
return kernel_source.str();
|
||||||
}
|
});
|
||||||
return d.get_kernel(kernel_name, lib);
|
return d.get_kernel(kernel_name, lib);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -139,8 +136,7 @@ MTL::ComputePipelineState* get_ternary_kernel(
|
|||||||
Dtype type,
|
Dtype type,
|
||||||
const std::string op) {
|
const std::string op) {
|
||||||
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
|
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
|
||||||
auto lib = d.get_library(lib_name);
|
auto lib = d.get_library(lib_name, [&]() {
|
||||||
if (lib == nullptr) {
|
|
||||||
std::ostringstream kernel_source;
|
std::ostringstream kernel_source;
|
||||||
const std::array<std::pair<std::string, std::string>, 6> kernel_types = {{
|
const std::array<std::pair<std::string, std::string>, 6> kernel_types = {{
|
||||||
{"v", "ternary_v"},
|
{"v", "ternary_v"},
|
||||||
@ -159,8 +155,8 @@ MTL::ComputePipelineState* get_ternary_kernel(
|
|||||||
}
|
}
|
||||||
kernel_source << get_template_definition(
|
kernel_source << get_template_definition(
|
||||||
"gn4_" + lib_name, "ternary_g", get_type_string(type), op, 4);
|
"gn4_" + lib_name, "ternary_g", get_type_string(type), op, 4);
|
||||||
lib = d.get_library(lib_name, kernel_source.str());
|
return kernel_source.str();
|
||||||
}
|
});
|
||||||
return d.get_kernel(kernel_name, lib);
|
return d.get_kernel(kernel_name, lib);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -170,8 +166,7 @@ MTL::ComputePipelineState* get_copy_kernel(
|
|||||||
const array& in,
|
const array& in,
|
||||||
const array& out) {
|
const array& out) {
|
||||||
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
|
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
|
||||||
auto lib = d.get_library(lib_name);
|
auto lib = d.get_library(lib_name, [&]() {
|
||||||
if (lib == nullptr) {
|
|
||||||
std::ostringstream kernel_source;
|
std::ostringstream kernel_source;
|
||||||
auto in_type = get_type_string(in.dtype());
|
auto in_type = get_type_string(in.dtype());
|
||||||
auto out_type = get_type_string(out.dtype());
|
auto out_type = get_type_string(out.dtype());
|
||||||
@ -198,8 +193,8 @@ MTL::ComputePipelineState* get_copy_kernel(
|
|||||||
"gg_" + lib_name, "copy_gg", in_type, out_type)
|
"gg_" + lib_name, "copy_gg", in_type, out_type)
|
||||||
<< get_template_definition(
|
<< get_template_definition(
|
||||||
"ggn4_" + lib_name, "copy_gg", in_type, out_type, 4);
|
"ggn4_" + lib_name, "copy_gg", in_type, out_type, 4);
|
||||||
lib = d.get_library(lib_name, kernel_source.str());
|
return kernel_source.str();
|
||||||
}
|
});
|
||||||
return d.get_kernel(kernel_name, lib);
|
return d.get_kernel(kernel_name, lib);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -209,8 +204,7 @@ MTL::ComputePipelineState* get_softmax_kernel(
|
|||||||
bool precise,
|
bool precise,
|
||||||
const array& out) {
|
const array& out) {
|
||||||
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
|
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
|
||||||
auto lib = d.get_library(lib_name);
|
auto lib = d.get_library(lib_name, [&] {
|
||||||
if (lib == nullptr) {
|
|
||||||
std::ostringstream kernel_source;
|
std::ostringstream kernel_source;
|
||||||
kernel_source << metal::utils() << metal::softmax()
|
kernel_source << metal::utils() << metal::softmax()
|
||||||
<< fmt::format(
|
<< fmt::format(
|
||||||
@ -218,8 +212,8 @@ MTL::ComputePipelineState* get_softmax_kernel(
|
|||||||
lib_name,
|
lib_name,
|
||||||
get_type_string(out.dtype()),
|
get_type_string(out.dtype()),
|
||||||
get_type_string(precise ? float32 : out.dtype()));
|
get_type_string(precise ? float32 : out.dtype()));
|
||||||
lib = d.get_library(lib_name, kernel_source.str());
|
return kernel_source.str();
|
||||||
}
|
});
|
||||||
return d.get_kernel(kernel_name, lib);
|
return d.get_kernel(kernel_name, lib);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -232,8 +226,7 @@ MTL::ComputePipelineState* get_scan_kernel(
|
|||||||
const array& in,
|
const array& in,
|
||||||
const array& out) {
|
const array& out) {
|
||||||
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
|
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
|
||||||
auto lib = d.get_library(lib_name);
|
auto lib = d.get_library(lib_name, [&]() {
|
||||||
if (lib == nullptr) {
|
|
||||||
std::string op_name = "Cum" + reduce_type;
|
std::string op_name = "Cum" + reduce_type;
|
||||||
op_name[3] = toupper(op_name[3]);
|
op_name[3] = toupper(op_name[3]);
|
||||||
std::ostringstream kernel_source;
|
std::ostringstream kernel_source;
|
||||||
@ -246,8 +239,8 @@ MTL::ComputePipelineState* get_scan_kernel(
|
|||||||
op_name,
|
op_name,
|
||||||
inclusive,
|
inclusive,
|
||||||
reverse);
|
reverse);
|
||||||
lib = d.get_library(lib_name, kernel_source.str());
|
return kernel_source.str();
|
||||||
}
|
});
|
||||||
return d.get_kernel(kernel_name, lib);
|
return d.get_kernel(kernel_name, lib);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -259,8 +252,7 @@ MTL::ComputePipelineState* get_sort_kernel(
|
|||||||
int bn,
|
int bn,
|
||||||
int tn) {
|
int tn) {
|
||||||
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
|
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
|
||||||
auto lib = d.get_library(lib_name);
|
auto lib = d.get_library(lib_name, [&]() {
|
||||||
if (lib == nullptr) {
|
|
||||||
std::ostringstream kernel_source;
|
std::ostringstream kernel_source;
|
||||||
auto in_type = get_type_string(in.dtype());
|
auto in_type = get_type_string(in.dtype());
|
||||||
auto out_type = get_type_string(out.dtype());
|
auto out_type = get_type_string(out.dtype());
|
||||||
@ -285,8 +277,8 @@ MTL::ComputePipelineState* get_sort_kernel(
|
|||||||
bn,
|
bn,
|
||||||
tn);
|
tn);
|
||||||
}
|
}
|
||||||
lib = d.get_library(lib_name, kernel_source.str());
|
return kernel_source.str();
|
||||||
}
|
});
|
||||||
return d.get_kernel(kernel_name, lib);
|
return d.get_kernel(kernel_name, lib);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -298,8 +290,7 @@ MTL::ComputePipelineState* get_mb_sort_kernel(
|
|||||||
int bn,
|
int bn,
|
||||||
int tn) {
|
int tn) {
|
||||||
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
|
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
|
||||||
auto lib = d.get_library(lib_name);
|
auto lib = d.get_library(lib_name, [&]() {
|
||||||
if (lib == nullptr) {
|
|
||||||
std::ostringstream kernel_source;
|
std::ostringstream kernel_source;
|
||||||
kernel_source << metal::utils() << metal::sort();
|
kernel_source << metal::utils() << metal::sort();
|
||||||
std::array<std::pair<std::string, std::string>, 3> kernel_types = {
|
std::array<std::pair<std::string, std::string>, 3> kernel_types = {
|
||||||
@ -316,8 +307,8 @@ MTL::ComputePipelineState* get_mb_sort_kernel(
|
|||||||
bn,
|
bn,
|
||||||
tn);
|
tn);
|
||||||
}
|
}
|
||||||
lib = d.get_library(lib_name, kernel_source.str());
|
return kernel_source.str();
|
||||||
}
|
});
|
||||||
return d.get_kernel(kernel_name, lib);
|
return d.get_kernel(kernel_name, lib);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -325,8 +316,7 @@ MTL::ComputePipelineState* get_reduce_init_kernel(
|
|||||||
metal::Device& d,
|
metal::Device& d,
|
||||||
const std::string& kernel_name,
|
const std::string& kernel_name,
|
||||||
const array& out) {
|
const array& out) {
|
||||||
auto lib = d.get_library(kernel_name);
|
auto lib = d.get_library(kernel_name, [&]() {
|
||||||
if (lib == nullptr) {
|
|
||||||
std::ostringstream kernel_source;
|
std::ostringstream kernel_source;
|
||||||
std::string op_type = op_name(out);
|
std::string op_type = op_name(out);
|
||||||
op_type[0] = std::toupper(op_name(out)[0]);
|
op_type[0] = std::toupper(op_name(out)[0]);
|
||||||
@ -335,8 +325,8 @@ MTL::ComputePipelineState* get_reduce_init_kernel(
|
|||||||
kernel_source << metal::utils() << metal::reduce_utils() << metal::reduce();
|
kernel_source << metal::utils() << metal::reduce_utils() << metal::reduce();
|
||||||
kernel_source << get_template_definition(
|
kernel_source << get_template_definition(
|
||||||
kernel_name, "init_reduce", out_type, op);
|
kernel_name, "init_reduce", out_type, op);
|
||||||
lib = d.get_library(kernel_name, kernel_source.str());
|
return kernel_source.str();
|
||||||
}
|
});
|
||||||
return d.get_kernel(kernel_name, lib);
|
return d.get_kernel(kernel_name, lib);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -350,8 +340,7 @@ MTL::ComputePipelineState* get_reduce_kernel(
|
|||||||
int ndim /* = -1 */,
|
int ndim /* = -1 */,
|
||||||
int bm /* = -1 */,
|
int bm /* = -1 */,
|
||||||
int bn /* = -1 */) {
|
int bn /* = -1 */) {
|
||||||
auto lib = d.get_library(kernel_name);
|
auto lib = d.get_library(kernel_name, [&]() {
|
||||||
if (lib == nullptr) {
|
|
||||||
std::string op_type = op_name;
|
std::string op_type = op_name;
|
||||||
op_type[0] = std::toupper(op_name[0]);
|
op_type[0] = std::toupper(op_name[0]);
|
||||||
std::ostringstream kernel_source;
|
std::ostringstream kernel_source;
|
||||||
@ -369,8 +358,8 @@ MTL::ComputePipelineState* get_reduce_kernel(
|
|||||||
kernel_source << get_template_definition(
|
kernel_source << get_template_definition(
|
||||||
kernel_name, func_name, in_type, out_type, op);
|
kernel_name, func_name, in_type, out_type, op);
|
||||||
}
|
}
|
||||||
lib = d.get_library(kernel_name, kernel_source.str());
|
return kernel_source.str();
|
||||||
}
|
});
|
||||||
auto st = d.get_kernel(kernel_name, lib);
|
auto st = d.get_kernel(kernel_name, lib);
|
||||||
return st;
|
return st;
|
||||||
}
|
}
|
||||||
@ -389,8 +378,7 @@ MTL::ComputePipelineState* get_steel_gemm_fused_kernel(
|
|||||||
int wm,
|
int wm,
|
||||||
int wn) {
|
int wn) {
|
||||||
const auto& lib_name = kernel_name;
|
const auto& lib_name = kernel_name;
|
||||||
auto lib = d.get_library(lib_name);
|
auto lib = d.get_library(lib_name, [&]() {
|
||||||
if (lib == nullptr) {
|
|
||||||
std::ostringstream kernel_source;
|
std::ostringstream kernel_source;
|
||||||
kernel_source << metal::utils() << metal::gemm()
|
kernel_source << metal::utils() << metal::gemm()
|
||||||
<< metal::steel_gemm_fused()
|
<< metal::steel_gemm_fused()
|
||||||
@ -405,8 +393,8 @@ MTL::ComputePipelineState* get_steel_gemm_fused_kernel(
|
|||||||
"wn"_a = wn,
|
"wn"_a = wn,
|
||||||
"trans_a"_a = transpose_a,
|
"trans_a"_a = transpose_a,
|
||||||
"trans_b"_a = transpose_b);
|
"trans_b"_a = transpose_b);
|
||||||
lib = d.get_library(lib_name, kernel_source.str());
|
return kernel_source.str();
|
||||||
}
|
});
|
||||||
return d.get_kernel(kernel_name, lib, hash_name, func_consts);
|
return d.get_kernel(kernel_name, lib, hash_name, func_consts);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -425,8 +413,7 @@ MTL::ComputePipelineState* get_steel_gemm_splitk_kernel(
|
|||||||
bool mn_aligned,
|
bool mn_aligned,
|
||||||
bool k_aligned) {
|
bool k_aligned) {
|
||||||
const auto& lib_name = kernel_name;
|
const auto& lib_name = kernel_name;
|
||||||
auto lib = d.get_library(lib_name);
|
auto lib = d.get_library(lib_name, [&]() {
|
||||||
if (lib == nullptr) {
|
|
||||||
std::ostringstream kernel_source;
|
std::ostringstream kernel_source;
|
||||||
kernel_source << metal::utils() << metal::gemm()
|
kernel_source << metal::utils() << metal::gemm()
|
||||||
<< metal::steel_gemm_splitk()
|
<< metal::steel_gemm_splitk()
|
||||||
@ -444,8 +431,8 @@ MTL::ComputePipelineState* get_steel_gemm_splitk_kernel(
|
|||||||
"trans_b"_a = transpose_b,
|
"trans_b"_a = transpose_b,
|
||||||
"mn_aligned"_a = mn_aligned,
|
"mn_aligned"_a = mn_aligned,
|
||||||
"k_aligned"_a = k_aligned);
|
"k_aligned"_a = k_aligned);
|
||||||
lib = d.get_library(lib_name, kernel_source.str());
|
return kernel_source.str();
|
||||||
}
|
});
|
||||||
return d.get_kernel(kernel_name, lib);
|
return d.get_kernel(kernel_name, lib);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -456,8 +443,7 @@ MTL::ComputePipelineState* get_steel_gemm_splitk_accum_kernel(
|
|||||||
const array& out,
|
const array& out,
|
||||||
bool axbpy) {
|
bool axbpy) {
|
||||||
const auto& lib_name = kernel_name;
|
const auto& lib_name = kernel_name;
|
||||||
auto lib = d.get_library(lib_name);
|
auto lib = d.get_library(lib_name, [&]() {
|
||||||
if (lib == nullptr) {
|
|
||||||
std::ostringstream kernel_source;
|
std::ostringstream kernel_source;
|
||||||
kernel_source << metal::utils() << metal::gemm()
|
kernel_source << metal::utils() << metal::gemm()
|
||||||
<< metal::steel_gemm_splitk()
|
<< metal::steel_gemm_splitk()
|
||||||
@ -467,8 +453,8 @@ MTL::ComputePipelineState* get_steel_gemm_splitk_accum_kernel(
|
|||||||
"name"_a = lib_name,
|
"name"_a = lib_name,
|
||||||
"atype"_a = get_type_string(in.dtype()),
|
"atype"_a = get_type_string(in.dtype()),
|
||||||
"otype"_a = get_type_string(out.dtype()));
|
"otype"_a = get_type_string(out.dtype()));
|
||||||
lib = d.get_library(lib_name, kernel_source.str());
|
return kernel_source.str();
|
||||||
}
|
});
|
||||||
return d.get_kernel(kernel_name, lib);
|
return d.get_kernel(kernel_name, lib);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -488,8 +474,7 @@ MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
|
|||||||
bool mn_aligned,
|
bool mn_aligned,
|
||||||
bool k_aligned) {
|
bool k_aligned) {
|
||||||
const auto& lib_name = kernel_name;
|
const auto& lib_name = kernel_name;
|
||||||
auto lib = d.get_library(lib_name);
|
auto lib = d.get_library(lib_name, [&]() {
|
||||||
if (lib == nullptr) {
|
|
||||||
std::ostringstream kernel_source;
|
std::ostringstream kernel_source;
|
||||||
auto out_mask_type = mask_out.has_value()
|
auto out_mask_type = mask_out.has_value()
|
||||||
? get_type_string((*mask_out).dtype())
|
? get_type_string((*mask_out).dtype())
|
||||||
@ -513,8 +498,8 @@ MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
|
|||||||
"trans_b"_a = transpose_b,
|
"trans_b"_a = transpose_b,
|
||||||
"mn_aligned"_a = mn_aligned,
|
"mn_aligned"_a = mn_aligned,
|
||||||
"k_aligned"_a = k_aligned);
|
"k_aligned"_a = k_aligned);
|
||||||
lib = d.get_library(lib_name, kernel_source.str());
|
return kernel_source.str();
|
||||||
}
|
});
|
||||||
return d.get_kernel(kernel_name, lib);
|
return d.get_kernel(kernel_name, lib);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -533,8 +518,7 @@ MTL::ComputePipelineState* get_gemv_masked_kernel(
|
|||||||
int tn,
|
int tn,
|
||||||
bool contiguous) {
|
bool contiguous) {
|
||||||
const auto& lib_name = kernel_name;
|
const auto& lib_name = kernel_name;
|
||||||
auto lib = d.get_library(lib_name);
|
auto lib = d.get_library(lib_name, [&]() {
|
||||||
if (lib == nullptr) {
|
|
||||||
std::ostringstream kernel_source;
|
std::ostringstream kernel_source;
|
||||||
auto out_mask_type = mask_out.has_value()
|
auto out_mask_type = mask_out.has_value()
|
||||||
? get_type_string((*mask_out).dtype())
|
? get_type_string((*mask_out).dtype())
|
||||||
@ -556,8 +540,8 @@ MTL::ComputePipelineState* get_gemv_masked_kernel(
|
|||||||
"tn"_a = tn,
|
"tn"_a = tn,
|
||||||
"trans"_a = transpose_mat ? "t_" : "",
|
"trans"_a = transpose_mat ? "t_" : "",
|
||||||
"nc"_a = contiguous ? "0" : "1");
|
"nc"_a = contiguous ? "0" : "1");
|
||||||
lib = d.get_library(lib_name, kernel_source.str());
|
return kernel_source.str();
|
||||||
}
|
});
|
||||||
return d.get_kernel(kernel_name, lib);
|
return d.get_kernel(kernel_name, lib);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -573,8 +557,7 @@ MTL::ComputePipelineState* get_steel_conv_kernel(
|
|||||||
int n_channel_specialization,
|
int n_channel_specialization,
|
||||||
bool small_filter) {
|
bool small_filter) {
|
||||||
const auto& lib_name = kernel_name;
|
const auto& lib_name = kernel_name;
|
||||||
auto lib = d.get_library(lib_name);
|
auto lib = d.get_library(lib_name, [&]() {
|
||||||
if (lib == nullptr) {
|
|
||||||
std::ostringstream kernel_source;
|
std::ostringstream kernel_source;
|
||||||
kernel_source << metal::utils() << metal::conv() << metal::steel_conv()
|
kernel_source << metal::utils() << metal::conv() << metal::steel_conv()
|
||||||
<< fmt::format(
|
<< fmt::format(
|
||||||
@ -588,8 +571,8 @@ MTL::ComputePipelineState* get_steel_conv_kernel(
|
|||||||
"wn"_a = wn,
|
"wn"_a = wn,
|
||||||
"n_channels"_a = n_channel_specialization,
|
"n_channels"_a = n_channel_specialization,
|
||||||
"small_filter"_a = small_filter);
|
"small_filter"_a = small_filter);
|
||||||
lib = d.get_library(lib_name, kernel_source.str());
|
return kernel_source.str();
|
||||||
}
|
});
|
||||||
return d.get_kernel(kernel_name, lib);
|
return d.get_kernel(kernel_name, lib);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -603,8 +586,7 @@ MTL::ComputePipelineState* get_steel_conv_general_kernel(
|
|||||||
int wm,
|
int wm,
|
||||||
int wn) {
|
int wn) {
|
||||||
const auto& lib_name = kernel_name;
|
const auto& lib_name = kernel_name;
|
||||||
auto lib = d.get_library(lib_name);
|
auto lib = d.get_library(lib_name, [&]() {
|
||||||
if (lib == nullptr) {
|
|
||||||
std::ostringstream kernel_source;
|
std::ostringstream kernel_source;
|
||||||
kernel_source << metal::utils() << metal::conv()
|
kernel_source << metal::utils() << metal::conv()
|
||||||
<< metal::steel_conv_general()
|
<< metal::steel_conv_general()
|
||||||
@ -617,8 +599,8 @@ MTL::ComputePipelineState* get_steel_conv_general_kernel(
|
|||||||
"bk"_a = bk,
|
"bk"_a = bk,
|
||||||
"wm"_a = wm,
|
"wm"_a = wm,
|
||||||
"wn"_a = wn);
|
"wn"_a = wn);
|
||||||
lib = d.get_library(lib_name, kernel_source.str());
|
return kernel_source.str();
|
||||||
}
|
});
|
||||||
return d.get_kernel(kernel_name, lib);
|
return d.get_kernel(kernel_name, lib);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -629,13 +611,12 @@ MTL::ComputePipelineState* get_fft_kernel(
|
|||||||
const metal::MTLFCList& func_consts,
|
const metal::MTLFCList& func_consts,
|
||||||
const std::string& template_def) {
|
const std::string& template_def) {
|
||||||
const auto& lib_name = kernel_name;
|
const auto& lib_name = kernel_name;
|
||||||
auto lib = d.get_library(lib_name);
|
auto lib = d.get_library(lib_name, [&]() {
|
||||||
if (lib == nullptr) {
|
|
||||||
std::ostringstream kernel_source;
|
std::ostringstream kernel_source;
|
||||||
std::string kernel_string;
|
std::string kernel_string;
|
||||||
kernel_source << metal::fft() << template_def;
|
kernel_source << metal::fft() << template_def;
|
||||||
lib = d.get_library(lib_name, kernel_source.str());
|
return kernel_source.str();
|
||||||
}
|
});
|
||||||
return d.get_kernel(kernel_name, lib, hash_name, func_consts);
|
return d.get_kernel(kernel_name, lib, hash_name, func_consts);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -644,13 +625,12 @@ MTL::ComputePipelineState* get_quantized_kernel(
|
|||||||
const std::string& kernel_name,
|
const std::string& kernel_name,
|
||||||
const std::string& template_def) {
|
const std::string& template_def) {
|
||||||
const auto& lib_name = kernel_name;
|
const auto& lib_name = kernel_name;
|
||||||
auto lib = d.get_library(lib_name);
|
auto lib = d.get_library(lib_name, [&]() {
|
||||||
if (lib == nullptr) {
|
|
||||||
std::ostringstream kernel_source;
|
std::ostringstream kernel_source;
|
||||||
kernel_source << metal::utils() << metal::gemm() << metal::quantized()
|
kernel_source << metal::utils() << metal::gemm() << metal::quantized()
|
||||||
<< template_def;
|
<< template_def;
|
||||||
lib = d.get_library(lib_name, kernel_source.str());
|
return kernel_source.str();
|
||||||
}
|
});
|
||||||
return d.get_kernel(kernel_name, lib);
|
return d.get_kernel(kernel_name, lib);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -122,6 +122,21 @@ class TestEval(mlx_tests.MLXTestCase):
|
|||||||
out = mx.vjp(fn, (x,), (y,))
|
out = mx.vjp(fn, (x,), (y,))
|
||||||
self.assertEqual(peak_mem, mx.metal.get_peak_memory())
|
self.assertEqual(peak_mem, mx.metal.get_peak_memory())
|
||||||
|
|
||||||
|
def test_async_eval_with_multiple_streams(self):
|
||||||
|
x = mx.array([1.0])
|
||||||
|
y = mx.array([1.0])
|
||||||
|
a = mx.array([1.0])
|
||||||
|
b = mx.array([1.0])
|
||||||
|
|
||||||
|
d = mx.default_device()
|
||||||
|
s2 = mx.new_stream(d)
|
||||||
|
|
||||||
|
for _ in range(50):
|
||||||
|
for _ in range(20):
|
||||||
|
x = x + y
|
||||||
|
mx.async_eval(x)
|
||||||
|
mx.eval(a + b)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user