mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Compare commits
3 Commits
85873cb162
...
afb9817599
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
afb9817599 | ||
|
|
8fb3e7a26c | ||
|
|
8c7bc30ce4 |
@@ -203,6 +203,11 @@ void time_reductions() {
|
|||||||
TIME(max_along_0);
|
TIME(max_along_0);
|
||||||
auto max_along_1 = [&b]() { return mx::max(b, 1, false); };
|
auto max_along_1 = [&b]() { return mx::max(b, 1, false); };
|
||||||
TIME(max_along_1);
|
TIME(max_along_1);
|
||||||
|
|
||||||
|
auto min_along_0 = [&b]() { return mx::min(b, 0, false); };
|
||||||
|
TIME(min_along_0);
|
||||||
|
auto min_along_1 = [&b]() { return mx::min(b, 1, false); };
|
||||||
|
TIME(min_along_1);
|
||||||
}
|
}
|
||||||
|
|
||||||
void time_gather_scatter() {
|
void time_gather_scatter() {
|
||||||
|
|||||||
@@ -58,6 +58,13 @@ def time_max():
|
|||||||
time_fn(mx.max, a, 0)
|
time_fn(mx.max, a, 0)
|
||||||
|
|
||||||
|
|
||||||
|
def time_min():
|
||||||
|
a = mx.random.uniform(shape=(32, 1024, 1024))
|
||||||
|
a[1, 1] = mx.nan
|
||||||
|
mx.eval(a)
|
||||||
|
time_fn(mx.min, a, 0)
|
||||||
|
|
||||||
|
|
||||||
def time_negative():
|
def time_negative():
|
||||||
a = mx.random.uniform(shape=(10000, 1000))
|
a = mx.random.uniform(shape=(10000, 1000))
|
||||||
mx.eval(a)
|
mx.eval(a)
|
||||||
@@ -115,6 +122,7 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
time_add()
|
time_add()
|
||||||
time_matmul()
|
time_matmul()
|
||||||
|
time_min()
|
||||||
time_max()
|
time_max()
|
||||||
time_maximum()
|
time_maximum()
|
||||||
time_exp()
|
time_exp()
|
||||||
|
|||||||
@@ -350,7 +350,15 @@ struct MinReduce {
|
|||||||
};
|
};
|
||||||
|
|
||||||
template <int N, typename T>
|
template <int N, typename T>
|
||||||
T operator()(simd::Simd<T, N> x) {
|
std::enable_if_t<std::is_integral_v<T>, T> operator()(simd::Simd<T, N> x) {
|
||||||
|
return simd::min(x);
|
||||||
|
};
|
||||||
|
|
||||||
|
template <int N, typename T>
|
||||||
|
std::enable_if_t<!std::is_integral_v<T>, T> operator()(simd::Simd<T, N> x) {
|
||||||
|
if (simd::any(x != x)) {
|
||||||
|
return static_cast<T>(NAN);
|
||||||
|
}
|
||||||
return simd::min(x);
|
return simd::min(x);
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -57,6 +57,14 @@ void Device::make_current() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
CommandEncoder& Device::get_command_encoder(Stream s) {
|
||||||
|
auto it = encoders_.find(s.index);
|
||||||
|
if (it == encoders_.end()) {
|
||||||
|
it = encoders_.try_emplace(s.index, *this).first;
|
||||||
|
}
|
||||||
|
return it->second;
|
||||||
|
}
|
||||||
|
|
||||||
CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) {
|
CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) {
|
||||||
CHECK_CUDA_ERROR(cudaGraphCreate(&graph, 0));
|
CHECK_CUDA_ERROR(cudaGraphCreate(&graph, 0));
|
||||||
CHECK_CUDA_ERROR(
|
CHECK_CUDA_ERROR(
|
||||||
@@ -168,15 +176,7 @@ void CommandEncoder::insert_graph_dependencies(std::vector<GraphNode> nodes) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
CommandEncoder& Device::get_command_encoder(Stream s) {
|
CommandEncoder::CommandEncoder(Device& d) : device_(d), stream_(d) {
|
||||||
auto it = encoders_.find(s.index);
|
|
||||||
if (it == encoders_.end()) {
|
|
||||||
it = encoders_.try_emplace(s.index, *this).first;
|
|
||||||
}
|
|
||||||
return it->second;
|
|
||||||
}
|
|
||||||
|
|
||||||
CommandEncoder::CommandEncoder(Device& d) : stream_(d) {
|
|
||||||
CHECK_CUDA_ERROR(cudaGraphCreate(&graph_, 0));
|
CHECK_CUDA_ERROR(cudaGraphCreate(&graph_, 0));
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -287,6 +287,7 @@ void CommandEncoder::commit() {
|
|||||||
CHECK_CUDA_ERROR(
|
CHECK_CUDA_ERROR(
|
||||||
cudaGraphInstantiate(&graph_exec, graph_, NULL, NULL, 0));
|
cudaGraphInstantiate(&graph_exec, graph_, NULL, NULL, 0));
|
||||||
}
|
}
|
||||||
|
device_.make_current();
|
||||||
CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_));
|
CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_));
|
||||||
|
|
||||||
// TODO smarter cache policy
|
// TODO smarter cache policy
|
||||||
|
|||||||
@@ -93,6 +93,7 @@ class CommandEncoder {
|
|||||||
void insert_graph_dependencies(GraphNode node);
|
void insert_graph_dependencies(GraphNode node);
|
||||||
void insert_graph_dependencies(std::vector<GraphNode> nodes);
|
void insert_graph_dependencies(std::vector<GraphNode> nodes);
|
||||||
|
|
||||||
|
Device& device_;
|
||||||
CudaStream stream_;
|
CudaStream stream_;
|
||||||
cudaGraph_t graph_;
|
cudaGraph_t graph_;
|
||||||
Worker worker_;
|
Worker worker_;
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
#include "mlx/backend/cuda/jit_module.h"
|
#include "mlx/backend/cuda/jit_module.h"
|
||||||
#include "mlx/backend/cuda/device.h"
|
#include "mlx/backend/cuda/device.h"
|
||||||
|
#include "mlx/version.h"
|
||||||
|
|
||||||
#include "cuda_jit_sources.h"
|
#include "cuda_jit_sources.h"
|
||||||
|
|
||||||
@@ -53,10 +54,11 @@ const std::string& cuda_home() {
|
|||||||
const std::filesystem::path& ptx_cache_dir() {
|
const std::filesystem::path& ptx_cache_dir() {
|
||||||
static std::filesystem::path cache = []() -> std::filesystem::path {
|
static std::filesystem::path cache = []() -> std::filesystem::path {
|
||||||
std::filesystem::path cache;
|
std::filesystem::path cache;
|
||||||
if (auto c = std::getenv("MLX_PTX_CACHE"); c) {
|
if (auto c = std::getenv("MLX_PTX_CACHE_DIR"); c) {
|
||||||
cache = c;
|
cache = c;
|
||||||
} else {
|
} else {
|
||||||
cache = std::filesystem::temp_directory_path() / "mlx" / "ptx";
|
cache =
|
||||||
|
std::filesystem::temp_directory_path() / "mlx" / version() / "ptx";
|
||||||
}
|
}
|
||||||
if (!std::filesystem::exists(cache)) {
|
if (!std::filesystem::exists(cache)) {
|
||||||
std::error_code error;
|
std::error_code error;
|
||||||
|
|||||||
@@ -164,7 +164,15 @@ struct Min {
|
|||||||
DEFINE_SIMD_REDUCE()
|
DEFINE_SIMD_REDUCE()
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
T simd_reduce_impl(T val) {
|
metal::enable_if_t<metal::is_integral_v<T>, T> simd_reduce_impl(T val) {
|
||||||
|
return simd_min(val);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
metal::enable_if_t<!metal::is_integral_v<T>, T> simd_reduce_impl(T val) {
|
||||||
|
if (simd_any(val != val)) {
|
||||||
|
return static_cast<T>(NAN);
|
||||||
|
}
|
||||||
return simd_min(val);
|
return simd_min(val);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -176,11 +184,38 @@ struct Min {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Operator
|
// Operator
|
||||||
U operator()(U a, U b) {
|
template <typename T>
|
||||||
|
metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T a, T b) {
|
||||||
return a < b ? a : b;
|
return a < b ? a : b;
|
||||||
}
|
}
|
||||||
};
|
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T a, T b) {
|
||||||
|
if (metal::isnan(a) || metal::isnan(b)) {
|
||||||
|
return static_cast<T>(NAN);
|
||||||
|
} else {
|
||||||
|
return a < b ? a : b;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
complex64_t operator()(complex64_t a, complex64_t b) {
|
||||||
|
bool real_is_nan = metal::isnan(a.real) || metal::isnan(b.real);
|
||||||
|
bool imag_is_nan = metal::isnan(a.imag) || metal::isnan(b.imag);
|
||||||
|
|
||||||
|
if (!real_is_nan && !imag_is_nan) {
|
||||||
|
return a < b ? a : b;
|
||||||
|
} else if (real_is_nan && !imag_is_nan) {
|
||||||
|
return complex64_t(
|
||||||
|
static_cast<float>(NAN), a.imag < b.imag ? a.imag : b.imag);
|
||||||
|
} else if (!real_is_nan && imag_is_nan) {
|
||||||
|
return complex64_t(
|
||||||
|
a.real < b.real ? a.real : b.real, static_cast<float>(NAN));
|
||||||
|
} else {
|
||||||
|
return complex64_t(static_cast<float>(NAN), static_cast<float>(NAN));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
};
|
||||||
template <typename U>
|
template <typename U>
|
||||||
struct Max {
|
struct Max {
|
||||||
DEFINE_SIMD_REDUCE()
|
DEFINE_SIMD_REDUCE()
|
||||||
|
|||||||
@@ -173,7 +173,7 @@ class TestReduce(mlx_tests.MLXTestCase):
|
|||||||
x[idx[0], idx[1]] = mx.nan
|
x[idx[0], idx[1]] = mx.nan
|
||||||
x_np = np.array(x)
|
x_np = np.array(x)
|
||||||
|
|
||||||
for op in ["max"]:
|
for op in ["max", "min"]:
|
||||||
for axis in [0, 1]:
|
for axis in [0, 1]:
|
||||||
out = getattr(mx, op)(x, axis=axis)
|
out = getattr(mx, op)(x, axis=axis)
|
||||||
ref = getattr(np, op)(x_np, axis=axis)
|
ref = getattr(np, op)(x_np, axis=axis)
|
||||||
@@ -205,7 +205,7 @@ class TestReduce(mlx_tests.MLXTestCase):
|
|||||||
np_arrays,
|
np_arrays,
|
||||||
):
|
):
|
||||||
for axis in [0, 1]:
|
for axis in [0, 1]:
|
||||||
for op in ["max"]:
|
for op in ["max", "min"]:
|
||||||
out = getattr(mx, op)(mx_arr, axis=axis)
|
out = getattr(mx, op)(mx_arr, axis=axis)
|
||||||
ref = getattr(np, op)(np_arr, axis=axis)
|
ref = getattr(np, op)(np_arr, axis=axis)
|
||||||
self.assertTrue(np.array_equal(out, ref, equal_nan=True))
|
self.assertTrue(np.array_equal(out, ref, equal_nan=True))
|
||||||
|
|||||||
Reference in New Issue
Block a user