mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-12 20:26:40 +08:00
[CUDA] Save primitive inputs faster (#2449)
* Add more nvtx loggings * [CUDA] Saving primitive inputs faster * Remove unneeded check
This commit is contained in:
parent
86c6a15571
commit
b26d88591c
@ -269,6 +269,7 @@ void CommandEncoder::add_graph_node(cudaGraph_t child) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void CommandEncoder::commit() {
|
void CommandEncoder::commit() {
|
||||||
|
nvtx3::scoped_range r("CommandEncoder::commit");
|
||||||
if (!temporaries_.empty()) {
|
if (!temporaries_.empty()) {
|
||||||
add_completed_handler([temporaries = std::move(temporaries_)]() {});
|
add_completed_handler([temporaries = std::move(temporaries_)]() {});
|
||||||
}
|
}
|
||||||
|
@ -36,18 +36,15 @@ void eval(array& arr) {
|
|||||||
|
|
||||||
auto& encoder = cu::get_command_encoder(arr.primitive().stream());
|
auto& encoder = cu::get_command_encoder(arr.primitive().stream());
|
||||||
// Keep used buffers alive until kernel finishes running.
|
// Keep used buffers alive until kernel finishes running.
|
||||||
std::unordered_set<std::shared_ptr<array::Data>> buffers;
|
|
||||||
for (auto& in : arr.inputs()) {
|
for (auto& in : arr.inputs()) {
|
||||||
buffers.insert(in.data_shared_ptr());
|
// Except for the donated one.
|
||||||
|
if (in.data_shared_ptr() != arr.data_shared_ptr()) {
|
||||||
|
encoder.add_temporary(in);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
for (auto& s : arr.siblings()) {
|
for (auto& s : arr.siblings()) {
|
||||||
buffers.insert(s.data_shared_ptr());
|
encoder.add_temporary(s);
|
||||||
}
|
}
|
||||||
// Remove the output if it was donated to by an input.
|
|
||||||
if (auto it = buffers.find(arr.data_shared_ptr()); it != buffers.end()) {
|
|
||||||
buffers.erase(it);
|
|
||||||
}
|
|
||||||
encoder.add_completed_handler([buffers = std::move(buffers)]() {});
|
|
||||||
encoder.maybe_commit();
|
encoder.maybe_commit();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -5,6 +5,8 @@
|
|||||||
#include "mlx/backend/gpu/copy.h"
|
#include "mlx/backend/gpu/copy.h"
|
||||||
#include "mlx/fast_primitives.h"
|
#include "mlx/fast_primitives.h"
|
||||||
|
|
||||||
|
#include <nvtx3/nvtx3.hpp>
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
@ -42,6 +44,7 @@ inline array ensure_row_contiguous_matrix(
|
|||||||
void fast::AffineQuantize::eval_gpu(
|
void fast::AffineQuantize::eval_gpu(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
std::vector<array>& outputs) {
|
std::vector<array>& outputs) {
|
||||||
|
nvtx3::scoped_range r("AffineQuantize::eval_gpu");
|
||||||
auto& s = stream();
|
auto& s = stream();
|
||||||
auto& d = cu::device(s.device);
|
auto& d = cu::device(s.device);
|
||||||
auto& enc = d.get_command_encoder(s);
|
auto& enc = d.get_command_encoder(s);
|
||||||
|
@ -192,7 +192,7 @@ void ternary_op_gpu(
|
|||||||
}
|
}
|
||||||
|
|
||||||
void Select::eval_gpu(const std::vector<array>& inputs, array& out) {
|
void Select::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
nvtx3::scoped_range r("select::eval_gpu");
|
nvtx3::scoped_range r("Select::eval_gpu");
|
||||||
auto& s = out.primitive().stream();
|
auto& s = out.primitive().stream();
|
||||||
ternary_op_gpu<cu::Select>(inputs, out, s);
|
ternary_op_gpu<cu::Select>(inputs, out, s);
|
||||||
}
|
}
|
||||||
|
@ -133,6 +133,7 @@ void NumberOfElements::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void Pad::eval_gpu(const std::vector<array>& inputs, array& out) {
|
void Pad::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
MLX_PROFILER_RANGE("Pad::eval_gpu");
|
||||||
// Inputs must be base input array and scalar val array
|
// Inputs must be base input array and scalar val array
|
||||||
assert(inputs.size() == 2);
|
assert(inputs.size() == 2);
|
||||||
auto& in = inputs[0];
|
auto& in = inputs[0];
|
||||||
|
Loading…
Reference in New Issue
Block a user