mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
@@ -2287,7 +2287,8 @@ class QRF : public Primitive {
|
||||
/* SVD primitive. */
|
||||
class SVD : public Primitive {
|
||||
public:
|
||||
explicit SVD(Stream stream) : Primitive(stream) {}
|
||||
explicit SVD(Stream stream, bool compute_uv)
|
||||
: Primitive(stream), compute_uv_(compute_uv) {}
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override;
|
||||
@@ -2296,6 +2297,12 @@ class SVD : public Primitive {
|
||||
|
||||
DEFINE_VMAP()
|
||||
DEFINE_PRINT(SVD)
|
||||
auto state() const {
|
||||
return compute_uv_;
|
||||
}
|
||||
|
||||
private:
|
||||
bool compute_uv_;
|
||||
};
|
||||
|
||||
/* Matrix inversion primitive. */
|
||||
|
||||
Reference in New Issue
Block a user