mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Fixes for large arrays with a few ops (#1299)
* fixes for large arrays with a few ops * fix bug * fix all of copy
This commit is contained in:
@@ -32,6 +32,7 @@ void ternary_op_gpu_inplace(
|
||||
auto& strides_c = strides[2];
|
||||
auto& strides_out = strides[3];
|
||||
|
||||
bool use_2d = out.data_size();
|
||||
std::string kernel_name;
|
||||
{
|
||||
std::ostringstream kname;
|
||||
@@ -40,6 +41,8 @@ void ternary_op_gpu_inplace(
|
||||
if (shape.size() <= MAX_TERNARY_SPECIALIZED_DIMS) {
|
||||
kname << shape.size();
|
||||
}
|
||||
} else if (use_2d) {
|
||||
kname << "v2";
|
||||
} else {
|
||||
kname << "v";
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user