mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-03 09:16:44 +08:00
[CUDA] Fix segfault on exit (#2424)
* fix cuda segfault on exit * comment
This commit is contained in:
parent
4ad53414dd
commit
b9e88fb976
@ -1,6 +1,7 @@
|
|||||||
// Copyright © 2025 Apple Inc.
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
#include "mlx/backend/cuda/device.h"
|
#include "mlx/backend/cuda/device.h"
|
||||||
|
#include "mlx/backend/cuda/jit_module.h"
|
||||||
#include "mlx/backend/cuda/worker.h"
|
#include "mlx/backend/cuda/worker.h"
|
||||||
#include "mlx/utils.h"
|
#include "mlx/utils.h"
|
||||||
|
|
||||||
@ -54,6 +55,10 @@ Device::Device(int device) : device_(device) {
|
|||||||
CHECK_CUBLAS_ERROR(cublasLtCreate(<_));
|
CHECK_CUBLAS_ERROR(cublasLtCreate(<_));
|
||||||
// The cudnn handle is used by Convolution.
|
// The cudnn handle is used by Convolution.
|
||||||
CHECK_CUDNN_ERROR(cudnnCreate(&cudnn_));
|
CHECK_CUDNN_ERROR(cudnnCreate(&cudnn_));
|
||||||
|
|
||||||
|
// Initialize the jit module cache here ensures it is not
|
||||||
|
// unloaded before any evaluation is done
|
||||||
|
get_jit_module_cache();
|
||||||
}
|
}
|
||||||
|
|
||||||
Device::~Device() {
|
Device::~Device() {
|
||||||
|
@ -9,7 +9,6 @@
|
|||||||
#include <cstdlib>
|
#include <cstdlib>
|
||||||
#include <filesystem>
|
#include <filesystem>
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
#include <unordered_map>
|
|
||||||
|
|
||||||
#include <fmt/format.h>
|
#include <fmt/format.h>
|
||||||
#include <nvrtc.h>
|
#include <nvrtc.h>
|
||||||
@ -330,11 +329,16 @@ CUfunction JitModule::get_kernel(const std::string& kernel_name) {
|
|||||||
return it->second;
|
return it->second;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::unordered_map<std::string, JitModule>& get_jit_module_cache() {
|
||||||
|
static std::unordered_map<std::string, JitModule> map;
|
||||||
|
return map;
|
||||||
|
}
|
||||||
|
|
||||||
JitModule& get_jit_module(
|
JitModule& get_jit_module(
|
||||||
const mlx::core::Device& device,
|
const mlx::core::Device& device,
|
||||||
const std::string& name,
|
const std::string& name,
|
||||||
const KernelBuilder& builder) {
|
const KernelBuilder& builder) {
|
||||||
static std::unordered_map<std::string, JitModule> map;
|
auto& map = get_jit_module_cache();
|
||||||
auto it = map.find(name);
|
auto it = map.find(name);
|
||||||
if (it == map.end()) {
|
if (it == map.end()) {
|
||||||
it = map.try_emplace(name, cu::device(device), name, builder).first;
|
it = map.try_emplace(name, cu::device(device), name, builder).first;
|
||||||
|
@ -99,6 +99,8 @@ class JitModule {
|
|||||||
std::unordered_map<std::string, CUfunction> kernels_;
|
std::unordered_map<std::string, CUfunction> kernels_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
std::unordered_map<std::string, JitModule>& get_jit_module_cache();
|
||||||
|
|
||||||
JitModule& get_jit_module(
|
JitModule& get_jit_module(
|
||||||
const mlx::core::Device& device,
|
const mlx::core::Device& device,
|
||||||
const std::string& name,
|
const std::string& name,
|
||||||
|
Loading…
Reference in New Issue
Block a user