mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Detect metal version and propagate correctly for JIT (#1109)
* detect metal version and propagate correctly for JIT * remove softmax * fix versions
This commit is contained in:
@@ -29,6 +29,14 @@ constexpr int MAX_DISPATCHES_PER_ENCODER = 2;
|
||||
|
||||
constexpr const char* default_mtllib_path = METAL_PATH;
|
||||
|
||||
constexpr auto get_metal_version() {
|
||||
#if defined METAL_3_1
|
||||
return MTL::LanguageVersion3_1;
|
||||
#else
|
||||
return MTL::LanguageVersion3_0;
|
||||
#endif
|
||||
}
|
||||
|
||||
auto load_device() {
|
||||
auto devices = MTL::CopyAllDevices();
|
||||
auto device = static_cast<MTL::Device*>(devices->object(0))
|
||||
@@ -275,7 +283,12 @@ MTL::Library* Device::get_library_(const std::string& source_string) {
|
||||
NS::String::string(source_string.c_str(), NS::ASCIIStringEncoding);
|
||||
|
||||
NS::Error* error = nullptr;
|
||||
auto mtl_lib = device_->newLibrary(ns_code, nullptr, &error);
|
||||
auto options = MTL::CompileOptions::alloc()->init();
|
||||
options->setFastMathEnabled(false);
|
||||
|
||||
options->setLanguageVersion(get_metal_version());
|
||||
auto mtl_lib = device_->newLibrary(ns_code, options, &error);
|
||||
options->release();
|
||||
|
||||
// Throw error if unable to compile library
|
||||
if (!mtl_lib) {
|
||||
|
||||
Reference in New Issue
Block a user