mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-02 16:56:46 +08:00
Fix ternary for large arrays (#1359)
* fix ternary for large arrays * fix
This commit is contained in:
parent
860d3a50d7
commit
2fdf9eb535
@ -12,6 +12,7 @@ namespace {
|
|||||||
// TODO: Add support for more combinations of input types.
|
// TODO: Add support for more combinations of input types.
|
||||||
enum class TernaryOpType {
|
enum class TernaryOpType {
|
||||||
ScalarScalarScalar,
|
ScalarScalarScalar,
|
||||||
|
VectorVectorVector,
|
||||||
General,
|
General,
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -20,6 +21,12 @@ get_ternary_op_type(const array& a, const array& b, const array& c) {
|
|||||||
TernaryOpType topt;
|
TernaryOpType topt;
|
||||||
if (a.data_size() == 1 && b.data_size() == 1 && c.data_size() == 1) {
|
if (a.data_size() == 1 && b.data_size() == 1 && c.data_size() == 1) {
|
||||||
topt = TernaryOpType::ScalarScalarScalar;
|
topt = TernaryOpType::ScalarScalarScalar;
|
||||||
|
} else if (
|
||||||
|
(a.flags().row_contiguous && b.flags().row_contiguous &&
|
||||||
|
c.flags().row_contiguous) ||
|
||||||
|
(a.flags().col_contiguous && b.flags().col_contiguous &&
|
||||||
|
c.flags().col_contiguous)) {
|
||||||
|
topt = TernaryOpType::VectorVectorVector;
|
||||||
} else {
|
} else {
|
||||||
topt = TernaryOpType::General;
|
topt = TernaryOpType::General;
|
||||||
}
|
}
|
||||||
@ -33,11 +40,32 @@ void set_ternary_op_output_data(
|
|||||||
array& out,
|
array& out,
|
||||||
TernaryOpType topt,
|
TernaryOpType topt,
|
||||||
bool donate_with_move = false) {
|
bool donate_with_move = false) {
|
||||||
|
auto maybe_donate = [&out, donate_with_move](const array& x) {
|
||||||
|
if (x.is_donatable() && x.itemsize() == out.itemsize()) {
|
||||||
|
if (donate_with_move) {
|
||||||
|
out.move_shared_buffer(x);
|
||||||
|
} else {
|
||||||
|
out.copy_shared_buffer(x);
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
};
|
||||||
|
|
||||||
switch (topt) {
|
switch (topt) {
|
||||||
case TernaryOpType::ScalarScalarScalar:
|
case TernaryOpType::ScalarScalarScalar:
|
||||||
out.set_data(
|
out.set_data(
|
||||||
allocator::malloc_or_wait(out.itemsize()), 1, b.strides(), b.flags());
|
allocator::malloc_or_wait(out.itemsize()), 1, b.strides(), b.flags());
|
||||||
break;
|
break;
|
||||||
|
case TernaryOpType::VectorVectorVector:
|
||||||
|
if (!(maybe_donate(a) || maybe_donate(b) || maybe_donate(c))) {
|
||||||
|
out.set_data(
|
||||||
|
allocator::malloc_or_wait(out.itemsize() * b.data_size()),
|
||||||
|
b.data_size(),
|
||||||
|
b.strides(),
|
||||||
|
b.flags());
|
||||||
|
}
|
||||||
|
break;
|
||||||
case TernaryOpType::General:
|
case TernaryOpType::General:
|
||||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||||
break;
|
break;
|
||||||
|
@ -56,9 +56,12 @@ void ternary_op_gpu_inplace(
|
|||||||
|
|
||||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||||
compute_encoder->setComputePipelineState(kernel);
|
compute_encoder->setComputePipelineState(kernel);
|
||||||
compute_encoder.set_input_array(a, 0);
|
bool donate_a = a.data_shared_ptr() == nullptr;
|
||||||
compute_encoder.set_input_array(b, 1);
|
bool donate_b = b.data_shared_ptr() == nullptr;
|
||||||
compute_encoder.set_input_array(c, 2);
|
bool donate_c = c.data_shared_ptr() == nullptr;
|
||||||
|
compute_encoder.set_input_array(donate_a ? out : a, 0);
|
||||||
|
compute_encoder.set_input_array(donate_b ? out : b, 1);
|
||||||
|
compute_encoder.set_input_array(donate_c ? out : c, 2);
|
||||||
compute_encoder.set_output_array(out, 3);
|
compute_encoder.set_output_array(out, 3);
|
||||||
|
|
||||||
if (topt == TernaryOpType::General) {
|
if (topt == TernaryOpType::General) {
|
||||||
@ -91,9 +94,10 @@ void ternary_op_gpu_inplace(
|
|||||||
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
|
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
|
||||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||||
} else {
|
} else {
|
||||||
// Launch a 1D grid of threads
|
// Launch a 1D or 2D grid of threads
|
||||||
size_t nthreads = out.data_size();
|
size_t nthreads = out.data_size();
|
||||||
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
|
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
|
||||||
|
: MTL::Size(nthreads, 1, 1);
|
||||||
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||||
if (thread_group_size > nthreads) {
|
if (thread_group_size > nthreads) {
|
||||||
thread_group_size = nthreads;
|
thread_group_size = nthreads;
|
||||||
|
Loading…
Reference in New Issue
Block a user