Ring distributed backend (#1784)

This commit is contained in:
Angelos Katharopoulos
2025-01-27 22:15:01 -08:00
committed by GitHub
parent 2235dee906
commit ccb61d7aae
17 changed files with 1078 additions and 44 deletions

View File

@@ -652,7 +652,7 @@ void normalize_dynamic_slice_inputs(
const array& a,
const array& start,
std::vector<int>& axes,
const std::string prefix) {
std::string_view prefix) {
if (start.size() > a.ndim()) {
std::ostringstream msg;
msg << prefix << " Invalid number of starting positions for "
@@ -690,7 +690,9 @@ void normalize_dynamic_slice_inputs(
}
std::set dims(axes.begin(), axes.end());
if (dims.size() != axes.size()) {
throw std::invalid_argument(prefix + " Repeat axes not allowed.");
std::ostringstream msg;
msg << prefix << " Repeat axes not allowed.";
throw std::invalid_argument(msg.str());
}
}
@@ -927,7 +929,7 @@ split(const array& a, int num_splits, StreamOrDevice s /* = {} */) {
std::vector<array> meshgrid(
const std::vector<array>& arrays,
bool sparse /* = false */,
std::string indexing /* = "xy" */,
const std::string& indexing /* = "xy" */,
StreamOrDevice s /* = {} */) {
if (indexing != "xy" && indexing != "ij") {
throw std::invalid_argument(
@@ -1186,7 +1188,7 @@ array pad(
const Shape& low_pad_size,
const Shape& high_pad_size,
const array& pad_value /*= array(0)*/,
const std::string mode /*= "constant"*/,
const std::string& mode /*= "constant"*/,
StreamOrDevice s /* = {}*/) {
if (axes.size() != low_pad_size.size() ||
axes.size() != high_pad_size.size()) {
@@ -1238,7 +1240,7 @@ array pad(
const array& a,
const std::vector<std::pair<int, int>>& pad_width,
const array& pad_value /*= array(0)*/,
const std::string mode /*= "constant"*/,
const std::string& mode /*= "constant"*/,
StreamOrDevice s /*= {}*/) {
std::vector<int> axes(a.ndim(), 0);
std::iota(axes.begin(), axes.end(), 0);
@@ -1258,7 +1260,7 @@ array pad(
const array& a,
const std::pair<int, int>& pad_width,
const array& pad_value /*= array(0)*/,
const std::string mode /*= "constant"*/,
const std::string& mode /*= "constant"*/,
StreamOrDevice s /*= {}*/) {
return pad(
a,
@@ -1272,7 +1274,7 @@ array pad(
const array& a,
int pad_width,
const array& pad_value /*= array(0)*/,
const std::string mode /*= "constant"*/,
const std::string& mode /*= "constant"*/,
StreamOrDevice s /*= {}*/) {
return pad(
a,