Fast Hadamard Transform (#1249)

* Working hadamard for powers of 2

* working for m*2^k

* add scale and check contiguity

* add size check

* clean up

* fix test

* add grads + vmap

* gpu only

* skip on linux

* test typo

* add cpu impl

* remove gpu only tests

* fix linux build + add is_equivalent
This commit is contained in:
Alex Barron
2024-07-09 20:39:01 -07:00
committed by GitHub
parent 03cf033f82
commit a3c287354f
22 changed files with 878 additions and 11 deletions

View File

@@ -1064,6 +1064,27 @@ class GreaterEqual : public UnaryPrimitive {
void eval(const std::vector<array>& inputs, array& out);
};
class Hadamard : public UnaryPrimitive {
public:
explicit Hadamard(Stream stream, float scale)
: UnaryPrimitive(stream), scale_(scale) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
DEFINE_GRADS()
DEFINE_PRINT(Hadamard)
DEFINE_INPUT_OUTPUT_SHAPE()
bool is_equivalent(const Primitive& other) const override;
private:
float scale_;
void eval(const std::vector<array>& inputs, array& out);
};
class Less : public UnaryPrimitive {
public:
explicit Less(Stream stream) : UnaryPrimitive(stream) {}