Fixes for large arrays with a few ops (#1299)

* fixes for large arrays with a few ops

* fix bug

* fix all of copy
This commit is contained in:
Awni Hannun
2024-07-30 17:18:39 -07:00
committed by GitHub
parent c52d1600f0
commit 40b6d67333
21 changed files with 273 additions and 202 deletions

View File

@@ -32,6 +32,7 @@ void ternary_op_gpu_inplace(
auto& strides_c = strides[2];
auto& strides_out = strides[3];
bool use_2d = out.data_size() > UINT32_MAX;
std::string kernel_name;
{
std::ostringstream kname;
@@ -40,6 +41,8 @@ void ternary_op_gpu_inplace(
if (shape.size() <= MAX_TERNARY_SPECIALIZED_DIMS) {
kname << shape.size();
}
} else if (use_2d) {
kname << "v2";
} else {
kname << "v";
}