Compare commits

..

3 Commits

Author        SHA1         Message                                                       Date
Cheng         afb9817599   [CUDA] Put version in ptx cache dir path (#2352)              2025-07-10 07:24:21 -07:00
Cheng         8fb3e7a26c   [CUDA] Set current device before cudaGraphLaunch (#2351)      2025-07-10 07:24:02 -07:00
jhavukainen   8c7bc30ce4   Align mlx::core::min op nan propagation with NumPy (#2346)    2025-07-10 06:20:43 -07:00
8 changed files with 77 additions and 17 deletions

View File

@@ -203,6 +203,11 @@ void time_reductions() {
   TIME(max_along_0);
   auto max_along_1 = [&b]() { return mx::max(b, 1, false); };
   TIME(max_along_1);
+
+  auto min_along_0 = [&b]() { return mx::min(b, 0, false); };
+  TIME(min_along_0);
+  auto min_along_1 = [&b]() { return mx::min(b, 1, false); };
+  TIME(min_along_1);
 }

 void time_gather_scatter() {
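The TIME(...) calls above time each reduction lambda; the macro's definition is not part of this diff, and the real benchmark presumably also forces evaluation of the lazy mx arrays. For orientation, a minimal self-contained sketch of a comparable lambda-timing helper using std::chrono (the helper name and iteration count are illustrative, not the benchmark's actual code):

#include <chrono>
#include <cstdio>

// Hypothetical stand-in for the benchmark's TIME macro: run the callable a
// fixed number of times and report the average wall-clock time per call.
template <typename F>
void time_it(const char* name, F&& fn, int iterations = 100) {
  auto start = std::chrono::high_resolution_clock::now();
  for (int i = 0; i < iterations; ++i) {
    fn();
  }
  auto stop = std::chrono::high_resolution_clock::now();
  auto usec =
      std::chrono::duration_cast<std::chrono::microseconds>(stop - start).count();
  std::printf("%s: %.3f usec/call\n", name, static_cast<double>(usec) / iterations);
}

int main() {
  time_it("noop", []() {});
}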

View File

@@ -58,6 +58,13 @@ def time_max():
     time_fn(mx.max, a, 0)


+def time_min():
+    a = mx.random.uniform(shape=(32, 1024, 1024))
+    a[1, 1] = mx.nan
+    mx.eval(a)
+    time_fn(mx.min, a, 0)
+
+
 def time_negative():
     a = mx.random.uniform(shape=(10000, 1000))
     mx.eval(a)
@@ -115,6 +122,7 @@ if __name__ == "__main__":
     time_add()
     time_matmul()
+    time_min()
     time_max()
     time_maximum()
     time_exp()

View File

@@ -350,7 +350,15 @@ struct MinReduce {
   };

   template <int N, typename T>
-  T operator()(simd::Simd<T, N> x) {
+  std::enable_if_t<std::is_integral_v<T>, T> operator()(simd::Simd<T, N> x) {
+    return simd::min(x);
+  };
+
+  template <int N, typename T>
+  std::enable_if_t<!std::is_integral_v<T>, T> operator()(simd::Simd<T, N> x) {
+    if (simd::any(x != x)) {
+      return static_cast<T>(NAN);
+    }
     return simd::min(x);
   };
 };
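The CPU path now dispatches MinReduce::operator() on whether T is an integral type, using std::enable_if_t on the return type, and detects NaN lanes with the self-comparison x != x (NaN is the only value that compares unequal to itself). A standalone sketch of the same dispatch and detection pattern over a plain std::vector instead of simd::Simd (names are illustrative, not MLX code):

#include <cmath>
#include <cstdio>
#include <type_traits>
#include <vector>

// Integral types cannot hold NaN, so a plain minimum is sufficient.
template <typename T>
std::enable_if_t<std::is_integral_v<T>, T> reduce_min(const std::vector<T>& v) {
  T out = v[0];
  for (T x : v) out = x < out ? x : out;
  return out;
}

// Floating-point types: propagate NaN, matching np.min semantics.
// x != x is true only when x is NaN.
template <typename T>
std::enable_if_t<!std::is_integral_v<T>, T> reduce_min(const std::vector<T>& v) {
  for (T x : v) {
    if (x != x) return static_cast<T>(NAN);
  }
  T out = v[0];
  for (T x : v) out = x < out ? x : out;
  return out;
}

int main() {
  std::printf("%d\n", reduce_min(std::vector<int>{3, 1, 2}));          // 1
  std::printf("%f\n", reduce_min(std::vector<float>{3.f, NAN, 2.f}));  // nan
}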

View File

@@ -57,6 +57,14 @@ void Device::make_current() {
   }
 }

+CommandEncoder& Device::get_command_encoder(Stream s) {
+  auto it = encoders_.find(s.index);
+  if (it == encoders_.end()) {
+    it = encoders_.try_emplace(s.index, *this).first;
+  }
+  return it->second;
+}
+
 CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) {
   CHECK_CUDA_ERROR(cudaGraphCreate(&graph, 0));
   CHECK_CUDA_ERROR(
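Device::get_command_encoder lazily builds one CommandEncoder per stream index: try_emplace constructs the value in place on first use and is a no-op when the key already exists. A minimal sketch of that find-or-construct pattern with standard containers (illustrative types, not the MLX classes):

#include <cstdio>
#include <unordered_map>

// Illustrative stand-ins for Device and CommandEncoder.
struct Encoder {
  explicit Encoder(int owner_id) : owner_id(owner_id) {
    std::printf("encoder constructed for owner %d\n", owner_id);
  }
  int owner_id;
};

struct Owner {
  int id = 0;
  std::unordered_map<int, Encoder> encoders;

  // Find-or-construct: try_emplace builds the Encoder in place on the first
  // lookup for a key and leaves the map untouched if the key already exists.
  Encoder& get(int key) {
    auto it = encoders.find(key);
    if (it == encoders.end()) {
      it = encoders.try_emplace(key, id).first;
    }
    return it->second;
  }
};

int main() {
  Owner o;
  o.get(3);  // constructs
  o.get(3);  // returns the cached encoder, no construction
}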
@@ -168,15 +176,7 @@ void CommandEncoder::insert_graph_dependencies(std::vector<GraphNode> nodes) {
   }
 }

-CommandEncoder& Device::get_command_encoder(Stream s) {
-  auto it = encoders_.find(s.index);
-  if (it == encoders_.end()) {
-    it = encoders_.try_emplace(s.index, *this).first;
-  }
-  return it->second;
-}
-
-CommandEncoder::CommandEncoder(Device& d) : stream_(d) {
+CommandEncoder::CommandEncoder(Device& d) : device_(d), stream_(d) {
   CHECK_CUDA_ERROR(cudaGraphCreate(&graph_, 0));
 }
@@ -287,6 +287,7 @@ void CommandEncoder::commit() {
     CHECK_CUDA_ERROR(
         cudaGraphInstantiate(&graph_exec, graph_, NULL, NULL, 0));
   }
+  device_.make_current();
   CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_));

   // TODO smarter cache policy
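The fix calls device_.make_current() immediately before cudaGraphLaunch, so the encoder's device is the current one when the graph is launched; in a multi-device process another code path may have switched the active device in between. Presumably make_current wraps cudaSetDevice (its definition is not shown in this diff). A self-contained sketch of the same ordering against the raw CUDA runtime API, assuming device 0 and an empty graph:

#include <cuda_runtime.h>

#include <cstdio>
#include <cstdlib>

// Minimal local error check (not MLX's CHECK_CUDA_ERROR).
void check(cudaError_t err) {
  if (err != cudaSuccess) {
    std::printf("CUDA error: %s\n", cudaGetErrorString(err));
    std::exit(1);
  }
}

int main() {
  cudaGraph_t graph;
  cudaGraphExec_t graph_exec;
  cudaStream_t stream;

  check(cudaSetDevice(0));  // analogous to Device::make_current()
  check(cudaStreamCreate(&stream));
  check(cudaGraphCreate(&graph, 0));
  check(cudaGraphInstantiate(&graph_exec, graph, nullptr, nullptr, 0));

  // Re-assert the intended device right before launch: in a multi-GPU
  // process something else may have changed the current device between
  // instantiation and launch.
  check(cudaSetDevice(0));
  check(cudaGraphLaunch(graph_exec, stream));
  check(cudaStreamSynchronize(stream));

  check(cudaGraphExecDestroy(graph_exec));
  check(cudaGraphDestroy(graph));
  check(cudaStreamDestroy(stream));
  return 0;
}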

View File

@@ -93,6 +93,7 @@ class CommandEncoder {
   void insert_graph_dependencies(GraphNode node);
   void insert_graph_dependencies(std::vector<GraphNode> nodes);

+  Device& device_;
   CudaStream stream_;
   cudaGraph_t graph_;
   Worker worker_;

View File

@@ -2,6 +2,7 @@

 #include "mlx/backend/cuda/jit_module.h"

 #include "mlx/backend/cuda/device.h"
+#include "mlx/version.h"

 #include "cuda_jit_sources.h"
@@ -53,10 +54,11 @@ const std::string& cuda_home() {
 const std::filesystem::path& ptx_cache_dir() {
   static std::filesystem::path cache = []() -> std::filesystem::path {
     std::filesystem::path cache;
-    if (auto c = std::getenv("MLX_PTX_CACHE"); c) {
+    if (auto c = std::getenv("MLX_PTX_CACHE_DIR"); c) {
       cache = c;
     } else {
-      cache = std::filesystem::temp_directory_path() / "mlx" / "ptx";
+      cache =
+          std::filesystem::temp_directory_path() / "mlx" / version() / "ptx";
     }
     if (!std::filesystem::exists(cache)) {
       std::error_code error;
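With this change the PTX cache location is controlled by the MLX_PTX_CACHE_DIR environment variable (previously MLX_PTX_CACHE), and the default path now includes the library version, so cached PTX from one MLX version is not picked up by another. A small standalone sketch of the same lookup logic; the version() stub below is a placeholder, not MLX's real version string:

#include <cstdlib>
#include <filesystem>
#include <iostream>
#include <string>

// Placeholder for mlx::core::version(), which returns the library version.
std::string version() { return "x.y.z"; }

std::filesystem::path ptx_cache_dir() {
  // The environment variable overrides the default location.
  if (const char* c = std::getenv("MLX_PTX_CACHE_DIR"); c) {
    return std::filesystem::path(c);
  }
  // Default: a per-version directory, e.g. /tmp/mlx/<version>/ptx.
  return std::filesystem::temp_directory_path() / "mlx" / version() / "ptx";
}

int main() {
  std::cout << ptx_cache_dir() << "\n";
}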

View File

@@ -164,7 +164,15 @@ struct Min {
   DEFINE_SIMD_REDUCE()

   template <typename T>
-  T simd_reduce_impl(T val) {
+  metal::enable_if_t<metal::is_integral_v<T>, T> simd_reduce_impl(T val) {
+    return simd_min(val);
+  }
+
+  template <typename T>
+  metal::enable_if_t<!metal::is_integral_v<T>, T> simd_reduce_impl(T val) {
+    if (simd_any(val != val)) {
+      return static_cast<T>(NAN);
+    }
     return simd_min(val);
   }
@@ -176,11 +184,38 @@ struct Min {
   }

   // Operator
-  U operator()(U a, U b) {
+  template <typename T>
+  metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T a, T b) {
     return a < b ? a : b;
   }
-};
+
+  template <typename T>
+  metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T a, T b) {
+    if (metal::isnan(a) || metal::isnan(b)) {
+      return static_cast<T>(NAN);
+    } else {
+      return a < b ? a : b;
+    }
+  }
+
+  template <>
+  complex64_t operator()(complex64_t a, complex64_t b) {
+    bool real_is_nan = metal::isnan(a.real) || metal::isnan(b.real);
+    bool imag_is_nan = metal::isnan(a.imag) || metal::isnan(b.imag);
+    if (!real_is_nan && !imag_is_nan) {
+      return a < b ? a : b;
+    } else if (real_is_nan && !imag_is_nan) {
+      return complex64_t(
+          static_cast<float>(NAN), a.imag < b.imag ? a.imag : b.imag);
+    } else if (!real_is_nan && imag_is_nan) {
+      return complex64_t(
+          a.real < b.real ? a.real : b.real, static_cast<float>(NAN));
+    } else {
+      return complex64_t(static_cast<float>(NAN), static_cast<float>(NAN));
+    }
+  };
+};

 template <typename U>
 struct Max {
   DEFINE_SIMD_REDUCE()
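For complex64, the Metal kernel propagates NaN per component: a NaN in either real part makes the result's real part NaN, and likewise for the imaginary part. A host-side C++ sketch of the same rule using std::complex<float>; the complex_less ordering below (real part first, then imaginary part) is an assumption made only to keep the sketch self-contained:

#include <cmath>
#include <complex>
#include <cstdio>

using complex64 = std::complex<float>;

// Assumed ordering for complex values; illustrative only.
bool complex_less(complex64 a, complex64 b) {
  return a.real() < b.real() || (a.real() == b.real() && a.imag() < b.imag());
}

// Component-wise NaN propagation mirroring the Metal Min kernel above: a NaN
// in either real part poisons the real component of the result, and likewise
// for the imaginary component.
complex64 complex_min(complex64 a, complex64 b) {
  bool real_is_nan = std::isnan(a.real()) || std::isnan(b.real());
  bool imag_is_nan = std::isnan(a.imag()) || std::isnan(b.imag());
  if (!real_is_nan && !imag_is_nan) {
    return complex_less(a, b) ? a : b;
  } else if (real_is_nan && !imag_is_nan) {
    return {NAN, a.imag() < b.imag() ? a.imag() : b.imag()};
  } else if (!real_is_nan && imag_is_nan) {
    return {a.real() < b.real() ? a.real() : b.real(), NAN};
  } else {
    return {NAN, NAN};
  }
}

int main() {
  auto out = complex_min({NAN, 2.0f}, {1.0f, 5.0f});
  std::printf("(%f, %f)\n", out.real(), out.imag());  // (nan, 2.000000)
}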

View File

@@ -173,7 +173,7 @@ class TestReduce(mlx_tests.MLXTestCase):
         x[idx[0], idx[1]] = mx.nan
         x_np = np.array(x)

-        for op in ["max"]:
+        for op in ["max", "min"]:
             for axis in [0, 1]:
                 out = getattr(mx, op)(x, axis=axis)
                 ref = getattr(np, op)(x_np, axis=axis)
@@ -205,7 +205,7 @@ class TestReduce(mlx_tests.MLXTestCase):
             np_arrays,
         ):
             for axis in [0, 1]:
-                for op in ["max"]:
+                for op in ["max", "min"]:
                     out = getattr(mx, op)(mx_arr, axis=axis)
                     ref = getattr(np, op)(np_arr, axis=axis)
                     self.assertTrue(np.array_equal(out, ref, equal_nan=True))
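The tests now compare mx.min against np.min on inputs containing NaN, i.e. a NaN anywhere along the reduced axis makes the result NaN. A short sketch of what the aligned behavior looks like from the C++ API, using the same mx::min signature as the benchmark above (the output comment is indicative, not an exact capture):

#include <cmath>
#include <iostream>

#include "mlx/mlx.h"

namespace mx = mlx::core;

int main() {
  // One NaN in the input; with NumPy-aligned propagation the reduction
  // yields NaN instead of ignoring it.
  auto a = mx::array({3.0f, NAN, 1.0f});
  auto out = mx::min(a, 0, false);
  std::cout << out << std::endl;  // prints a NaN scalar array
}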