mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-18 23:21:16 +08:00
fix resource leaks in matmul and graph (#2383)
This commit is contained in:
parent
6b1b8ea91b
commit
fbb3f65a1a
@ -66,7 +66,6 @@ CommandEncoder& Device::get_command_encoder(Stream s) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) {
|
CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) {
|
||||||
CHECK_CUDA_ERROR(cudaGraphCreate(&graph, 0));
|
|
||||||
CHECK_CUDA_ERROR(
|
CHECK_CUDA_ERROR(
|
||||||
cudaStreamBeginCapture(enc.stream(), cudaStreamCaptureModeGlobal));
|
cudaStreamBeginCapture(enc.stream(), cudaStreamCaptureModeGlobal));
|
||||||
}
|
}
|
||||||
|
@ -27,6 +27,35 @@ void check_cublas_error(const char* name, cublasStatus_t err) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct CublasPreference {
|
||||||
|
CublasPreference(Device& device) {
|
||||||
|
// The recommended cublas workspace size is 4 MiB for pre-Hopper and 32 MiB
|
||||||
|
// for Hopper+:
|
||||||
|
// https://docs.nvidia.com/cuda/cublas/#cublassetworkspace
|
||||||
|
uint64_t MiB = 1024 * 1024;
|
||||||
|
uint64_t workspace_size =
|
||||||
|
device.compute_capability_major() >= 9 ? 32 * MiB : 4 * MiB;
|
||||||
|
|
||||||
|
CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceCreate(&pref_));
|
||||||
|
CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceSetAttribute(
|
||||||
|
pref_,
|
||||||
|
CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
|
||||||
|
&workspace_size,
|
||||||
|
sizeof(uint64_t)));
|
||||||
|
}
|
||||||
|
|
||||||
|
~CublasPreference() {
|
||||||
|
CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceDestroy(pref_));
|
||||||
|
}
|
||||||
|
|
||||||
|
cublasLtMatmulPreference_t pref_{nullptr};
|
||||||
|
};
|
||||||
|
|
||||||
|
// Returns the shared cublasLt matmul preference handle.
//
// The owning CublasPreference is a function-local static, so it is constructed
// from `device` on the first call only; every later call returns that same
// handle regardless of the argument passed.
cublasLtMatmulPreference_t cublas_preference(Device& device) {
  static CublasPreference shared_preference(device);
  cublasLtMatmulPreference_t handle = shared_preference.pref_;
  return handle;
}
|
||||||
|
|
||||||
class MatMul {
|
class MatMul {
|
||||||
public:
|
public:
|
||||||
MatMul(
|
MatMul(
|
||||||
@ -43,7 +72,7 @@ class MatMul {
|
|||||||
int32_t batch_count,
|
int32_t batch_count,
|
||||||
int64_t a_batch_stride,
|
int64_t a_batch_stride,
|
||||||
int64_t b_batch_stride)
|
int64_t b_batch_stride)
|
||||||
: handle_(device.lt_handle()) {
|
: handle_(device.lt_handle()), pref_(cublas_preference(device)) {
|
||||||
heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED;
|
heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED;
|
||||||
|
|
||||||
auto scale_type = dtype_to_cuda_type(dtype);
|
auto scale_type = dtype_to_cuda_type(dtype);
|
||||||
@ -77,20 +106,6 @@ class MatMul {
|
|||||||
type, b_rows, b_cols, b_transposed, ldb, batch_count, b_batch_stride);
|
type, b_rows, b_cols, b_transposed, ldb, batch_count, b_batch_stride);
|
||||||
out_desc_ = create_matrix_layout(
|
out_desc_ = create_matrix_layout(
|
||||||
type, a_rows, b_cols, false, b_cols, batch_count, a_rows * b_cols);
|
type, a_rows, b_cols, false, b_cols, batch_count, a_rows * b_cols);
|
||||||
|
|
||||||
// The recommended cublas workspace size is 4 MiB for pre-Hopper and 32 MiB
|
|
||||||
// for Hopper+:
|
|
||||||
// https://docs.nvidia.com/cuda/cublas/#cublassetworkspace
|
|
||||||
uint64_t MiB = 1024 * 1024;
|
|
||||||
uint64_t workspace_size =
|
|
||||||
device.compute_capability_major() >= 9 ? 32 * MiB : 4 * MiB;
|
|
||||||
|
|
||||||
CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceCreate(&pref_));
|
|
||||||
CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceSetAttribute(
|
|
||||||
pref_,
|
|
||||||
CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
|
|
||||||
&workspace_size,
|
|
||||||
sizeof(uint64_t)));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
MatMul(
|
MatMul(
|
||||||
@ -130,11 +145,11 @@ class MatMul {
|
|||||||
}
|
}
|
||||||
|
|
||||||
~MatMul() {
|
~MatMul() {
|
||||||
cublasLtMatrixLayoutDestroy(a_desc_);
|
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(a_desc_));
|
||||||
cublasLtMatrixLayoutDestroy(b_desc_);
|
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(b_desc_));
|
||||||
cublasLtMatrixLayoutDestroy(c_desc_);
|
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(c_desc_));
|
||||||
cublasLtMatrixLayoutDestroy(out_desc_);
|
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(out_desc_));
|
||||||
cublasLtMatmulDescDestroy(matmul_desc_);
|
CHECK_CUBLAS_ERROR(cublasLtMatmulDescDestroy(matmul_desc_));
|
||||||
}
|
}
|
||||||
|
|
||||||
void run(
|
void run(
|
||||||
@ -259,9 +274,9 @@ class MatMul {
|
|||||||
return desc;
|
return desc;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
cublasLtMatmulPreference_t pref_{nullptr};
|
||||||
cublasLtHandle_t handle_{nullptr};
|
cublasLtHandle_t handle_{nullptr};
|
||||||
cublasLtMatmulDesc_t matmul_desc_{nullptr};
|
cublasLtMatmulDesc_t matmul_desc_{nullptr};
|
||||||
cublasLtMatmulPreference_t pref_{nullptr};
|
|
||||||
cublasLtMatrixLayout_t a_desc_{nullptr};
|
cublasLtMatrixLayout_t a_desc_{nullptr};
|
||||||
cublasLtMatrixLayout_t b_desc_{nullptr};
|
cublasLtMatrixLayout_t b_desc_{nullptr};
|
||||||
cublasLtMatrixLayout_t c_desc_{nullptr};
|
cublasLtMatrixLayout_t c_desc_{nullptr};
|
||||||
|
Loading…
Reference in New Issue
Block a user