mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Ring distributed backend (#1784)
This commit is contained in:
committed by
GitHub
parent
2235dee906
commit
ccb61d7aae
16
mlx/ops.cpp
16
mlx/ops.cpp
@@ -652,7 +652,7 @@ void normalize_dynamic_slice_inputs(
|
||||
const array& a,
|
||||
const array& start,
|
||||
std::vector<int>& axes,
|
||||
const std::string prefix) {
|
||||
std::string_view prefix) {
|
||||
if (start.size() > a.ndim()) {
|
||||
std::ostringstream msg;
|
||||
msg << prefix << " Invalid number of starting positions for "
|
||||
@@ -690,7 +690,9 @@ void normalize_dynamic_slice_inputs(
|
||||
}
|
||||
std::set dims(axes.begin(), axes.end());
|
||||
if (dims.size() != axes.size()) {
|
||||
throw std::invalid_argument(prefix + " Repeat axes not allowed.");
|
||||
std::ostringstream msg;
|
||||
msg << prefix << " Repeat axes not allowed.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -927,7 +929,7 @@ split(const array& a, int num_splits, StreamOrDevice s /* = {} */) {
|
||||
std::vector<array> meshgrid(
|
||||
const std::vector<array>& arrays,
|
||||
bool sparse /* = false */,
|
||||
std::string indexing /* = "xy" */,
|
||||
const std::string& indexing /* = "xy" */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
if (indexing != "xy" && indexing != "ij") {
|
||||
throw std::invalid_argument(
|
||||
@@ -1186,7 +1188,7 @@ array pad(
|
||||
const Shape& low_pad_size,
|
||||
const Shape& high_pad_size,
|
||||
const array& pad_value /*= array(0)*/,
|
||||
const std::string mode /*= "constant"*/,
|
||||
const std::string& mode /*= "constant"*/,
|
||||
StreamOrDevice s /* = {}*/) {
|
||||
if (axes.size() != low_pad_size.size() ||
|
||||
axes.size() != high_pad_size.size()) {
|
||||
@@ -1238,7 +1240,7 @@ array pad(
|
||||
const array& a,
|
||||
const std::vector<std::pair<int, int>>& pad_width,
|
||||
const array& pad_value /*= array(0)*/,
|
||||
const std::string mode /*= "constant"*/,
|
||||
const std::string& mode /*= "constant"*/,
|
||||
StreamOrDevice s /*= {}*/) {
|
||||
std::vector<int> axes(a.ndim(), 0);
|
||||
std::iota(axes.begin(), axes.end(), 0);
|
||||
@@ -1258,7 +1260,7 @@ array pad(
|
||||
const array& a,
|
||||
const std::pair<int, int>& pad_width,
|
||||
const array& pad_value /*= array(0)*/,
|
||||
const std::string mode /*= "constant"*/,
|
||||
const std::string& mode /*= "constant"*/,
|
||||
StreamOrDevice s /*= {}*/) {
|
||||
return pad(
|
||||
a,
|
||||
@@ -1272,7 +1274,7 @@ array pad(
|
||||
const array& a,
|
||||
int pad_width,
|
||||
const array& pad_value /*= array(0)*/,
|
||||
const std::string mode /*= "constant"*/,
|
||||
const std::string& mode /*= "constant"*/,
|
||||
StreamOrDevice s /*= {}*/) {
|
||||
return pad(
|
||||
a,
|
||||
|
||||
Reference in New Issue
Block a user