Cuda bug fixes 2 (#2298)

* more bug fixes

* more bug fixes

* format
This commit is contained in:
Awni Hannun
2025-06-16 13:14:46 -07:00
committed by GitHub
parent c552ff2451
commit bc53f8293f
11 changed files with 143 additions and 107 deletions

View File

@@ -65,8 +65,8 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
Dtype idx_dtype = nidx > 0 ? inputs[1].dtype() : int32;
int32_t idx_ndim = nidx > 0 ? inputs[1].ndim() : 0;
bool large = (nidx > 0 && inputs[1].size() > UINT32_MAX) ||
(src.size() > UINT32_MAX) || (out.size() > UINT32_MAX);
bool large = (nidx > 0 && inputs[1].size() > INT32_MAX) ||
(src.size() > INT32_MAX) || (out.size() > INT32_MAX);
uint32_t slice_size = std::accumulate(
slice_sizes_.begin(), slice_sizes_.end(), 1, std::multiplies<uint32_t>());
@@ -88,7 +88,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
dtype_to_cuda_type(idx_dtype),
nidx,
ndim,
large ? "int64_t" : "uint32_t"));
large ? "int64_t" : "int32_t"));
}
}
return std::make_pair(jit_source_gather, std::move(kernel_names));
@@ -99,7 +99,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
if (large) {
mod.append_arg<int64_t>(out.size());
} else {
mod.append_arg<uint32_t>(out.size());
mod.append_arg<int32_t>(out.size());
}
mod.append_ndim_arg(src.shape());
mod.append_ndim_arg(src.strides());
@@ -115,7 +115,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
dtype_to_cuda_type(idx_dtype),
nidx,
idx_ndim,
large ? "int64_t" : "uint32_t");
large ? "int64_t" : "int32_t");
auto& encoder = cu::get_command_encoder(s);
for (const auto& in : inputs) {
@@ -152,14 +152,14 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
Dtype idx_dtype = nidx > 0 ? inputs[1].dtype() : int32;
int32_t idx_ndim = nidx > 0 ? inputs[1].ndim() : 0;
bool large = (nidx > 0 && inputs[1].size() > UINT32_MAX) ||
(upd.size() > UINT32_MAX) || (out.size() > UINT32_MAX);
bool large = (nidx > 0 && inputs[1].size() > INT32_MAX) ||
(upd.size() > INT32_MAX) || (out.size() > INT32_MAX);
uint32_t upd_post_idx_size = std::accumulate(
int32_t upd_post_idx_size = std::accumulate(
upd.shape().begin() + idx_ndim,
upd.shape().end(),
1,
std::multiplies<uint32_t>());
std::multiplies<int32_t>());
const char* op = g_scatter_ops[reduce_type_];
std::string module_name = fmt::format(
@@ -181,7 +181,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
op,
nidx,
ndim,
large ? "int64_t" : "uint32_t"));
large ? "int64_t" : "int32_t"));
}
}
return std::make_pair(jit_source_scatter, std::move(kernel_names));
@@ -192,7 +192,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
if (large) {
mod.append_arg<int64_t>(upd.size());
} else {
mod.append_arg<uint32_t>(upd.size());
mod.append_arg<int32_t>(upd.size());
}
mod.append_ndim_arg(upd.shape());
mod.append_ndim_arg(upd.strides());
@@ -200,7 +200,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
if (large) {
mod.append_arg<int64_t>(upd_post_idx_size);
} else {
mod.append_arg<uint32_t>(upd_post_idx_size);
mod.append_arg<int32_t>(upd_post_idx_size);
}
mod.append_ndim_arg(out.shape());
mod.append_ndim_arg(out.strides());
@@ -215,7 +215,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
op,
nidx,
idx_ndim,
large ? "int64_t" : "uint32_t");
large ? "int64_t" : "int32_t");
auto& encoder = cu::get_command_encoder(s);
for (const auto& in : inputs) {
@@ -238,7 +238,7 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
return;
}
bool large = idx.size() > UINT32_MAX || src.size() > UINT32_MAX;
bool large = idx.size() > INT32_MAX || src.size() > INT32_MAX;
std::string module_name = fmt::format(
"gather_axis_{}_{}",
@@ -258,7 +258,7 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
ndim,
contiguous & 1 ? true : false,
contiguous & 2 ? true : false,
large ? "int64_t" : "uint32_t"));
large ? "int64_t" : "int32_t"));
}
}
}
@@ -283,9 +283,9 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
mod.append_arg<int64_t>(idx_size_axis);
mod.append_arg<int64_t>(idx_size_post);
} else {
mod.append_arg<uint32_t>(idx_size_pre);
mod.append_arg<uint32_t>(idx_size_axis);
mod.append_arg<uint32_t>(idx_size_post);
mod.append_arg<int32_t>(idx_size_pre);
mod.append_arg<int32_t>(idx_size_axis);
mod.append_arg<int32_t>(idx_size_post);
}
mod.append_arg(remove_index(idx.shape(), axis_));
mod.append_arg(remove_index(src.strides(), axis_));
@@ -302,7 +302,7 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
src.ndim() - 1,
src.flags().row_contiguous,
idx.flags().row_contiguous,
large ? "int64_t" : "uint32_t");
large ? "int64_t" : "int32_t");
auto& encoder = cu::get_command_encoder(s);
for (const auto& in : inputs) {
@@ -337,7 +337,7 @@ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
return;
}
bool large = idx.size() > UINT32_MAX || src.size() > UINT32_MAX;
bool large = idx.size() > INT32_MAX || src.size() > INT32_MAX;
const char* op = reduce_type_ == ScatterAxis::Sum ? "Sum" : "Assign";
std::string module_name = fmt::format(
@@ -360,7 +360,7 @@ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
ndim,
contiguous & 1 ? true : false,
contiguous & 2 ? true : false,
large ? "int64_t" : "uint32_t"));
large ? "int64_t" : "int32_t"));
}
}
}
@@ -385,9 +385,9 @@ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
mod.append_arg<int64_t>(idx_size_axis);
mod.append_arg<int64_t>(idx_size_post);
} else {
mod.append_arg<uint32_t>(idx_size_pre);
mod.append_arg<uint32_t>(idx_size_axis);
mod.append_arg<uint32_t>(idx_size_post);
mod.append_arg<int32_t>(idx_size_pre);
mod.append_arg<int32_t>(idx_size_axis);
mod.append_arg<int32_t>(idx_size_post);
}
mod.append_arg(remove_index(idx.shape(), axis_));
mod.append_arg(remove_index(upd.strides(), axis_));
@@ -405,7 +405,7 @@ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
idx.ndim() - 1,
upd.flags().row_contiguous,
idx.flags().row_contiguous,
large ? "int64_t" : "uint32_t");
large ? "int64_t" : "int32_t");
auto& encoder = cu::get_command_encoder(s);
for (const auto& in : inputs) {