Adds nuclear norm support (#1894)

* adjust norm unit test tolerance
This commit is contained in:
Abe Leininger
2025-03-04 15:26:02 -06:00
committed by GitHub
parent 9680f72cca
commit 3835a428c5
11 changed files with 260 additions and 55 deletions

View File

@@ -2287,7 +2287,8 @@ class QRF : public Primitive {
/* SVD primitive. */
class SVD : public Primitive {
public:
explicit SVD(Stream stream) : Primitive(stream) {}
explicit SVD(Stream stream, bool compute_uv)
: Primitive(stream), compute_uv_(compute_uv) {}
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
@@ -2296,6 +2297,12 @@ class SVD : public Primitive {
DEFINE_VMAP()
DEFINE_PRINT(SVD)
auto state() const {
return compute_uv_;
}
private:
bool compute_uv_;
};
/* Matrix inversion primitive. */