mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Compare commits
3 Commits
85873cb162
...
afb9817599
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
afb9817599 | ||
|
|
8fb3e7a26c | ||
|
|
8c7bc30ce4 |
@@ -203,6 +203,11 @@ void time_reductions() {
|
||||
TIME(max_along_0);
|
||||
auto max_along_1 = [&b]() { return mx::max(b, 1, false); };
|
||||
TIME(max_along_1);
|
||||
|
||||
auto min_along_0 = [&b]() { return mx::min(b, 0, false); };
|
||||
TIME(min_along_0);
|
||||
auto min_along_1 = [&b]() { return mx::min(b, 1, false); };
|
||||
TIME(min_along_1);
|
||||
}
|
||||
|
||||
void time_gather_scatter() {
|
||||
|
||||
@@ -58,6 +58,13 @@ def time_max():
|
||||
time_fn(mx.max, a, 0)
|
||||
|
||||
|
||||
def time_min():
|
||||
a = mx.random.uniform(shape=(32, 1024, 1024))
|
||||
a[1, 1] = mx.nan
|
||||
mx.eval(a)
|
||||
time_fn(mx.min, a, 0)
|
||||
|
||||
|
||||
def time_negative():
|
||||
a = mx.random.uniform(shape=(10000, 1000))
|
||||
mx.eval(a)
|
||||
@@ -115,6 +122,7 @@ if __name__ == "__main__":
|
||||
|
||||
time_add()
|
||||
time_matmul()
|
||||
time_min()
|
||||
time_max()
|
||||
time_maximum()
|
||||
time_exp()
|
||||
|
||||
@@ -350,7 +350,15 @@ struct MinReduce {
|
||||
};
|
||||
|
||||
template <int N, typename T>
|
||||
T operator()(simd::Simd<T, N> x) {
|
||||
std::enable_if_t<std::is_integral_v<T>, T> operator()(simd::Simd<T, N> x) {
|
||||
return simd::min(x);
|
||||
};
|
||||
|
||||
template <int N, typename T>
|
||||
std::enable_if_t<!std::is_integral_v<T>, T> operator()(simd::Simd<T, N> x) {
|
||||
if (simd::any(x != x)) {
|
||||
return static_cast<T>(NAN);
|
||||
}
|
||||
return simd::min(x);
|
||||
};
|
||||
};
|
||||
|
||||
@@ -57,6 +57,14 @@ void Device::make_current() {
|
||||
}
|
||||
}
|
||||
|
||||
CommandEncoder& Device::get_command_encoder(Stream s) {
|
||||
auto it = encoders_.find(s.index);
|
||||
if (it == encoders_.end()) {
|
||||
it = encoders_.try_emplace(s.index, *this).first;
|
||||
}
|
||||
return it->second;
|
||||
}
|
||||
|
||||
CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) {
|
||||
CHECK_CUDA_ERROR(cudaGraphCreate(&graph, 0));
|
||||
CHECK_CUDA_ERROR(
|
||||
@@ -168,15 +176,7 @@ void CommandEncoder::insert_graph_dependencies(std::vector<GraphNode> nodes) {
|
||||
}
|
||||
}
|
||||
|
||||
CommandEncoder& Device::get_command_encoder(Stream s) {
|
||||
auto it = encoders_.find(s.index);
|
||||
if (it == encoders_.end()) {
|
||||
it = encoders_.try_emplace(s.index, *this).first;
|
||||
}
|
||||
return it->second;
|
||||
}
|
||||
|
||||
CommandEncoder::CommandEncoder(Device& d) : stream_(d) {
|
||||
CommandEncoder::CommandEncoder(Device& d) : device_(d), stream_(d) {
|
||||
CHECK_CUDA_ERROR(cudaGraphCreate(&graph_, 0));
|
||||
}
|
||||
|
||||
@@ -287,6 +287,7 @@ void CommandEncoder::commit() {
|
||||
CHECK_CUDA_ERROR(
|
||||
cudaGraphInstantiate(&graph_exec, graph_, NULL, NULL, 0));
|
||||
}
|
||||
device_.make_current();
|
||||
CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_));
|
||||
|
||||
// TODO smarter cache policy
|
||||
|
||||
@@ -93,6 +93,7 @@ class CommandEncoder {
|
||||
void insert_graph_dependencies(GraphNode node);
|
||||
void insert_graph_dependencies(std::vector<GraphNode> nodes);
|
||||
|
||||
Device& device_;
|
||||
CudaStream stream_;
|
||||
cudaGraph_t graph_;
|
||||
Worker worker_;
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
#include "mlx/backend/cuda/jit_module.h"
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/version.h"
|
||||
|
||||
#include "cuda_jit_sources.h"
|
||||
|
||||
@@ -53,10 +54,11 @@ const std::string& cuda_home() {
|
||||
const std::filesystem::path& ptx_cache_dir() {
|
||||
static std::filesystem::path cache = []() -> std::filesystem::path {
|
||||
std::filesystem::path cache;
|
||||
if (auto c = std::getenv("MLX_PTX_CACHE"); c) {
|
||||
if (auto c = std::getenv("MLX_PTX_CACHE_DIR"); c) {
|
||||
cache = c;
|
||||
} else {
|
||||
cache = std::filesystem::temp_directory_path() / "mlx" / "ptx";
|
||||
cache =
|
||||
std::filesystem::temp_directory_path() / "mlx" / version() / "ptx";
|
||||
}
|
||||
if (!std::filesystem::exists(cache)) {
|
||||
std::error_code error;
|
||||
|
||||
@@ -164,7 +164,15 @@ struct Min {
|
||||
DEFINE_SIMD_REDUCE()
|
||||
|
||||
template <typename T>
|
||||
T simd_reduce_impl(T val) {
|
||||
metal::enable_if_t<metal::is_integral_v<T>, T> simd_reduce_impl(T val) {
|
||||
return simd_min(val);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
metal::enable_if_t<!metal::is_integral_v<T>, T> simd_reduce_impl(T val) {
|
||||
if (simd_any(val != val)) {
|
||||
return static_cast<T>(NAN);
|
||||
}
|
||||
return simd_min(val);
|
||||
}
|
||||
|
||||
@@ -176,11 +184,38 @@ struct Min {
|
||||
}
|
||||
|
||||
// Operator
|
||||
U operator()(U a, U b) {
|
||||
template <typename T>
|
||||
metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T a, T b) {
|
||||
return a < b ? a : b;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T a, T b) {
|
||||
if (metal::isnan(a) || metal::isnan(b)) {
|
||||
return static_cast<T>(NAN);
|
||||
} else {
|
||||
return a < b ? a : b;
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
complex64_t operator()(complex64_t a, complex64_t b) {
|
||||
bool real_is_nan = metal::isnan(a.real) || metal::isnan(b.real);
|
||||
bool imag_is_nan = metal::isnan(a.imag) || metal::isnan(b.imag);
|
||||
|
||||
if (!real_is_nan && !imag_is_nan) {
|
||||
return a < b ? a : b;
|
||||
} else if (real_is_nan && !imag_is_nan) {
|
||||
return complex64_t(
|
||||
static_cast<float>(NAN), a.imag < b.imag ? a.imag : b.imag);
|
||||
} else if (!real_is_nan && imag_is_nan) {
|
||||
return complex64_t(
|
||||
a.real < b.real ? a.real : b.real, static_cast<float>(NAN));
|
||||
} else {
|
||||
return complex64_t(static_cast<float>(NAN), static_cast<float>(NAN));
|
||||
}
|
||||
};
|
||||
};
|
||||
template <typename U>
|
||||
struct Max {
|
||||
DEFINE_SIMD_REDUCE()
|
||||
|
||||
@@ -173,7 +173,7 @@ class TestReduce(mlx_tests.MLXTestCase):
|
||||
x[idx[0], idx[1]] = mx.nan
|
||||
x_np = np.array(x)
|
||||
|
||||
for op in ["max"]:
|
||||
for op in ["max", "min"]:
|
||||
for axis in [0, 1]:
|
||||
out = getattr(mx, op)(x, axis=axis)
|
||||
ref = getattr(np, op)(x_np, axis=axis)
|
||||
@@ -205,7 +205,7 @@ class TestReduce(mlx_tests.MLXTestCase):
|
||||
np_arrays,
|
||||
):
|
||||
for axis in [0, 1]:
|
||||
for op in ["max"]:
|
||||
for op in ["max", "min"]:
|
||||
out = getattr(mx, op)(mx_arr, axis=axis)
|
||||
ref = getattr(np, op)(np_arr, axis=axis)
|
||||
self.assertTrue(np.array_equal(out, ref, equal_nan=True))
|
||||
|
||||
Reference in New Issue
Block a user