minor fixes (#1194)

* minor fixes

* fix build errors
This commit is contained in:
Fangjun Kuang
2024-06-13 13:06:49 +08:00
committed by GitHub
parent 934683088e
commit f20e97b092
16 changed files with 239 additions and 238 deletions

View File

@@ -116,7 +116,7 @@ std::vector<array> Primitive::jvp(
print(msg);
msg << ".";
throw std::invalid_argument(msg.str());
};
}
std::vector<array> Primitive::vjp(
const std::vector<array>&,
@@ -128,7 +128,7 @@ std::vector<array> Primitive::vjp(
print(msg);
msg << ".";
throw std::invalid_argument(msg.str());
};
}
std::pair<std::vector<array>, std::vector<int>> Primitive::vmap(
const std::vector<array>&,
@@ -138,7 +138,7 @@ std::pair<std::vector<array>, std::vector<int>> Primitive::vmap(
print(msg);
msg << ".";
throw std::invalid_argument(msg.str());
};
}
std::vector<std::vector<int>> Primitive::output_shapes(
const std::vector<array>&) {
@@ -147,7 +147,7 @@ std::vector<std::vector<int>> Primitive::output_shapes(
this->print(msg);
msg << " cannot infer output shapes.";
throw std::invalid_argument(msg.str());
};
}
std::vector<array> Abs::vjp(
const std::vector<array>& primals,
@@ -3430,7 +3430,7 @@ std::pair<std::vector<array>, std::vector<int>> StopGradient::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
return {{stop_gradient(inputs[0], stream())}, axes};
};
}
std::vector<array> Subtract::vjp(
const std::vector<array>& primals,