mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-31 07:12:20 +08:00
Compare commits
5 Commits
c5be966863
...
c5ecc0c5ab
Author | SHA1 | Date | |
---|---|---|---|
![]() |
c5ecc0c5ab | ||
![]() |
a14aaa7c9d | ||
![]() |
a6d780154f | ||
![]() |
688e421184 | ||
![]() |
9ffe88841c |
@ -1,5 +1,4 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/iterators/strided_iterator.cuh"
|
||||
@ -113,7 +112,7 @@ __global__ void arg_reduce_general(
|
||||
|
||||
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
|
||||
T vals[N_READS];
|
||||
auto tid = r * BLOCK_DIM + block.thread_index().z;
|
||||
auto tid = r * BLOCK_DIM + block.thread_index().x;
|
||||
cub::LoadDirectBlocked(
|
||||
tid, strided_iterator(in + in_idx, axis_stride), vals, axis_size, init);
|
||||
best = op.reduce_many(best, vals, tid * N_READS);
|
||||
@ -158,7 +157,7 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
constexpr uint32_t N_READS = 4;
|
||||
MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
|
||||
dim3 num_blocks = get_2d_grid_dims(out.shape(), out.strides());
|
||||
dim3 block_dims{1, 1, BLOCK_DIM};
|
||||
dim3 block_dims{BLOCK_DIM, 1, 1};
|
||||
auto kernel = &cu::arg_reduce_general<
|
||||
InType,
|
||||
cu::ArgMax<InType>,
|
||||
|
@ -5,6 +5,7 @@
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
#include <cublasLt.h>
|
||||
#include <fmt/format.h>
|
||||
@ -44,9 +45,12 @@ class MatMul {
|
||||
int64_t b_batch_stride) {
|
||||
heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED;
|
||||
|
||||
auto type = dtype_to_cuda_type(dtype);
|
||||
auto scale_type = dtype_to_cuda_type(dtype);
|
||||
if (dtype == bfloat16 || dtype == float16) {
|
||||
scale_type = CUDA_R_32F;
|
||||
}
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate(
|
||||
&matmul_desc_, dtype_to_compute_type(dtype), type));
|
||||
&matmul_desc_, dtype_to_compute_type(dtype), scale_type));
|
||||
int32_t pointer_mode = CUBLASLT_POINTER_MODE_HOST;
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
|
||||
matmul_desc_,
|
||||
@ -65,6 +69,7 @@ class MatMul {
|
||||
&op,
|
||||
sizeof(cublasOperation_t)));
|
||||
|
||||
auto type = dtype_to_cuda_type(dtype);
|
||||
a_desc_ = create_matrix_layout(
|
||||
type, a_rows, a_cols, a_transposed, lda, batch_count, a_batch_stride);
|
||||
b_desc_ = create_matrix_layout(
|
||||
@ -187,17 +192,13 @@ class MatMul {
|
||||
private:
|
||||
cublasComputeType_t dtype_to_compute_type(Dtype dtype) {
|
||||
switch (dtype) {
|
||||
case uint8:
|
||||
case uint16:
|
||||
case int8:
|
||||
case int16:
|
||||
case int32:
|
||||
return CUBLAS_COMPUTE_32I;
|
||||
case float16:
|
||||
case bfloat16:
|
||||
return CUBLAS_COMPUTE_16F;
|
||||
case float32:
|
||||
return CUBLAS_COMPUTE_32F;
|
||||
case bfloat16:
|
||||
return CUBLAS_COMPUTE_32F;
|
||||
case float32:
|
||||
return mlx::core::env::enable_tf32() ? CUBLAS_COMPUTE_32F_FAST_TF32
|
||||
: CUBLAS_COMPUTE_32F;
|
||||
case float64:
|
||||
case complex64:
|
||||
return CUBLAS_COMPUTE_64F;
|
||||
@ -209,16 +210,6 @@ class MatMul {
|
||||
|
||||
cudaDataType_t dtype_to_cuda_type(Dtype dtype) {
|
||||
switch (dtype) {
|
||||
case uint8:
|
||||
return CUDA_R_8U;
|
||||
case uint16:
|
||||
return CUDA_R_16U;
|
||||
case int8:
|
||||
return CUDA_R_8I;
|
||||
case int16:
|
||||
return CUDA_R_16I;
|
||||
case int32:
|
||||
return CUDA_R_32I;
|
||||
case float16:
|
||||
return CUDA_R_16F;
|
||||
case bfloat16:
|
||||
|
@ -2,6 +2,7 @@
|
||||
#include <algorithm>
|
||||
#include <deque>
|
||||
#include <future>
|
||||
#include <mutex>
|
||||
#include <numeric>
|
||||
#include <set>
|
||||
#include <sstream>
|
||||
@ -36,6 +37,42 @@ class Synchronizer : public Primitive {
|
||||
DEFINE_PRINT(Synchronize);
|
||||
};
|
||||
|
||||
class Interrupt {
|
||||
private:
|
||||
static std::mutex mutex_;
|
||||
static bool eval_running_;
|
||||
static bool interrupt_;
|
||||
|
||||
public:
|
||||
Interrupt() {
|
||||
std::unique_lock lk(mutex_);
|
||||
eval_running_ = true;
|
||||
}
|
||||
|
||||
static bool interrupt() {
|
||||
std::unique_lock lk(mutex_);
|
||||
if (eval_running_) {
|
||||
interrupt_ = true;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
static bool interrupted() {
|
||||
std::unique_lock lk(mutex_);
|
||||
return interrupt_;
|
||||
}
|
||||
|
||||
~Interrupt() {
|
||||
std::unique_lock lk(mutex_);
|
||||
eval_running_ = false;
|
||||
interrupt_ = false;
|
||||
}
|
||||
};
|
||||
std::mutex Interrupt::mutex_{};
|
||||
bool Interrupt::eval_running_ = false;
|
||||
bool Interrupt::interrupt_ = false;
|
||||
|
||||
// Initialize the static tracing members from transforms_impl.h
|
||||
//
|
||||
// These are used to implement the in_tracing() function the returns true if we
|
||||
@ -50,6 +87,8 @@ int detail::InTracing::grad_counter{0};
|
||||
int detail::RetainGraph::tracing_counter{0};
|
||||
|
||||
array eval_impl(std::vector<array> outputs, bool async) {
|
||||
Interrupt interrupt;
|
||||
|
||||
std::deque<array> tape;
|
||||
|
||||
// Make an effort to choose a good output stream
|
||||
@ -260,6 +299,11 @@ array eval_impl(std::vector<array> outputs, bool async) {
|
||||
if (!arr.is_tracer()) {
|
||||
arr.detach();
|
||||
}
|
||||
|
||||
if (Interrupt::interrupted()) {
|
||||
synchronizer.attach_event(Event{stream});
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Signal the event in its stream
|
||||
@ -274,6 +318,10 @@ array eval_impl(std::vector<array> outputs, bool async) {
|
||||
return synchronizer;
|
||||
}
|
||||
|
||||
bool interrupt_eval() {
|
||||
return Interrupt::interrupt();
|
||||
}
|
||||
|
||||
void async_eval(std::vector<array> outputs) {
|
||||
if (outputs.empty()) {
|
||||
return;
|
||||
|
@ -22,6 +22,14 @@ void eval(Arrays&&... outputs) {
|
||||
eval(std::vector<array>{std::forward<Arrays>(outputs)...});
|
||||
}
|
||||
|
||||
/**
|
||||
* Interrupt an ongoing eval.
|
||||
*
|
||||
* Leaves the graph in a valid state. Returns true if an ongoing eval was
|
||||
* interrupted and false otherwise.
|
||||
*/
|
||||
bool interrupt_eval();
|
||||
|
||||
/**
|
||||
* Computes the output and vector-Jacobian product (VJP) of a function.
|
||||
*
|
||||
|
@ -149,6 +149,11 @@ inline bool metal_fast_synch() {
|
||||
return metal_fast_synch;
|
||||
}
|
||||
|
||||
inline bool enable_tf32() {
|
||||
static bool enable_tf32_ = get_var("MLX_ENABLE_TF32", 1);
|
||||
return enable_tf32_;
|
||||
}
|
||||
|
||||
} // namespace env
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -1,5 +1,4 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#include "doctest/doctest.h"
|
||||
|
||||
#include "mlx/mlx.h"
|
||||
@ -91,3 +90,17 @@ TEST_CASE("test eval graph retention when not tracing") {
|
||||
CHECK(!a.has_primitive());
|
||||
CHECK(a.is_available());
|
||||
}
|
||||
|
||||
TEST_CASE("test interrupt eval") {
|
||||
auto x = zeros({1024}, int32);
|
||||
for (int i = 0; i < 1000; ++i) {
|
||||
x = x + 1;
|
||||
}
|
||||
std::thread t([x]() { eval(x); });
|
||||
while (!interrupt_eval()) {
|
||||
}
|
||||
t.join();
|
||||
// Check that x is not evaluated
|
||||
CHECK(!x.is_available());
|
||||
CHECK(array_equal(x, full({1024}, 1000, int32)).item<bool>());
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user