fix resource leaks in matmul and graph (#2383)

This commit is contained in:
Awni Hannun 2025-07-17 06:50:15 -07:00 committed by GitHub
parent 6b1b8ea91b
commit fbb3f65a1a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 36 additions and 22 deletions

View File

@ -66,7 +66,6 @@ CommandEncoder& Device::get_command_encoder(Stream s) {
} }
CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) { CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) {
CHECK_CUDA_ERROR(cudaGraphCreate(&graph, 0));
CHECK_CUDA_ERROR( CHECK_CUDA_ERROR(
cudaStreamBeginCapture(enc.stream(), cudaStreamCaptureModeGlobal)); cudaStreamBeginCapture(enc.stream(), cudaStreamCaptureModeGlobal));
} }

View File

@ -27,6 +27,35 @@ void check_cublas_error(const char* name, cublasStatus_t err) {
} }
} }
struct CublasPreference {
CublasPreference(Device& device) {
// The recommended cublas workspace size is 4 MiB for pre-Hopper and 32 MiB
// for Hopper+:
// https://docs.nvidia.com/cuda/cublas/#cublassetworkspace
uint64_t MiB = 1024 * 1024;
uint64_t workspace_size =
device.compute_capability_major() >= 9 ? 32 * MiB : 4 * MiB;
CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceCreate(&pref_));
CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceSetAttribute(
pref_,
CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
&workspace_size,
sizeof(uint64_t)));
}
~CublasPreference() {
CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceDestroy(pref_));
}
cublasLtMatmulPreference_t pref_{nullptr};
};
cublasLtMatmulPreference_t cublas_preference(Device& device) {
static CublasPreference pref(device);
return pref.pref_;
}
class MatMul { class MatMul {
public: public:
MatMul( MatMul(
@ -43,7 +72,7 @@ class MatMul {
int32_t batch_count, int32_t batch_count,
int64_t a_batch_stride, int64_t a_batch_stride,
int64_t b_batch_stride) int64_t b_batch_stride)
: handle_(device.lt_handle()) { : handle_(device.lt_handle()), pref_(cublas_preference(device)) {
heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED; heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED;
auto scale_type = dtype_to_cuda_type(dtype); auto scale_type = dtype_to_cuda_type(dtype);
@ -77,20 +106,6 @@ class MatMul {
type, b_rows, b_cols, b_transposed, ldb, batch_count, b_batch_stride); type, b_rows, b_cols, b_transposed, ldb, batch_count, b_batch_stride);
out_desc_ = create_matrix_layout( out_desc_ = create_matrix_layout(
type, a_rows, b_cols, false, b_cols, batch_count, a_rows * b_cols); type, a_rows, b_cols, false, b_cols, batch_count, a_rows * b_cols);
// The recommended cublas workspace size is 4 MiB for pre-Hopper and 32 MiB
// for Hopper+:
// https://docs.nvidia.com/cuda/cublas/#cublassetworkspace
uint64_t MiB = 1024 * 1024;
uint64_t workspace_size =
device.compute_capability_major() >= 9 ? 32 * MiB : 4 * MiB;
CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceCreate(&pref_));
CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceSetAttribute(
pref_,
CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
&workspace_size,
sizeof(uint64_t)));
} }
MatMul( MatMul(
@ -130,11 +145,11 @@ class MatMul {
} }
~MatMul() { ~MatMul() {
cublasLtMatrixLayoutDestroy(a_desc_); CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(a_desc_));
cublasLtMatrixLayoutDestroy(b_desc_); CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(b_desc_));
cublasLtMatrixLayoutDestroy(c_desc_); CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(c_desc_));
cublasLtMatrixLayoutDestroy(out_desc_); CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(out_desc_));
cublasLtMatmulDescDestroy(matmul_desc_); CHECK_CUBLAS_ERROR(cublasLtMatmulDescDestroy(matmul_desc_));
} }
void run( void run(
@ -259,9 +274,9 @@ class MatMul {
return desc; return desc;
} }
cublasLtMatmulPreference_t pref_{nullptr};
cublasLtHandle_t handle_{nullptr}; cublasLtHandle_t handle_{nullptr};
cublasLtMatmulDesc_t matmul_desc_{nullptr}; cublasLtMatmulDesc_t matmul_desc_{nullptr};
cublasLtMatmulPreference_t pref_{nullptr};
cublasLtMatrixLayout_t a_desc_{nullptr}; cublasLtMatrixLayout_t a_desc_{nullptr};
cublasLtMatrixLayout_t b_desc_{nullptr}; cublasLtMatrixLayout_t b_desc_{nullptr};
cublasLtMatrixLayout_t c_desc_{nullptr}; cublasLtMatrixLayout_t c_desc_{nullptr};