mirror of
https://github.com/ml-explore/mlx.git
synced 2025-11-01 08:38:12 +08:00
Add more nvtx range for debug
This commit is contained in:
@@ -67,6 +67,8 @@ std::shared_ptr<fe::graph::Graph> get_sdpa_forward_graph(
|
|||||||
return it->second;
|
return it->second;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
nvtx3::scoped_range r("get_sdpa_forward_graph");
|
||||||
|
|
||||||
// Set up new graph
|
// Set up new graph
|
||||||
auto graph = std::make_shared<fe::graph::Graph>();
|
auto graph = std::make_shared<fe::graph::Graph>();
|
||||||
|
|
||||||
@@ -141,6 +143,8 @@ std::shared_ptr<fe::graph::Graph> get_sdpa_forward_graph(
|
|||||||
|
|
||||||
// cuDNN only supports native CUDA graphs for sdpa in 9.6 or above.
|
// cuDNN only supports native CUDA graphs for sdpa in 9.6 or above.
|
||||||
if (cudnnGetVersion() < 90600) {
|
if (cudnnGetVersion() < 90600) {
|
||||||
|
nvtx3::scoped_range r("get_sdpa_forward_graph::graph_building");
|
||||||
|
|
||||||
auto build_status = graph->build(handle, {fe::HeurMode_t::A});
|
auto build_status = graph->build(handle, {fe::HeurMode_t::A});
|
||||||
if (!build_status.is_good()) {
|
if (!build_status.is_good()) {
|
||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
@@ -437,6 +441,10 @@ void ScaledDotProductAttention::eval_gpu(
|
|||||||
const auto& k = copy_unless(is_matrix_contiguous, k_pre);
|
const auto& k = copy_unless(is_matrix_contiguous, k_pre);
|
||||||
const auto& v = copy_unless(is_matrix_contiguous, v_pre);
|
const auto& v = copy_unless(is_matrix_contiguous, v_pre);
|
||||||
|
|
||||||
|
for (const auto& cp : copies) {
|
||||||
|
encoder.add_temporary(cp);
|
||||||
|
}
|
||||||
|
|
||||||
int64_t str_oD = 1;
|
int64_t str_oD = 1;
|
||||||
int64_t str_oH = o.shape(3);
|
int64_t str_oH = o.shape(3);
|
||||||
int64_t str_oL = o.shape(1) * str_oH;
|
int64_t str_oL = o.shape(1) * str_oH;
|
||||||
|
|||||||
Reference in New Issue
Block a user