Export / import functions to / from a file (#1642)

* export and import functions

* refactor + works for few primitives

* nit

* allow primitives with state

* nit

* nit

* simplify serialize / deserialize

* fix for constants

* python bindings

* maybe fix serialize failure case

* add example

* more primitives, training kind of works

* same result for python and c++

* some fixes

* fix export

* template it up

* some simplificatoin

* rebase

* allow kwargs and multiple functions

* exporter

* more primitives for exporting

* deal with endianness

* handle invalid stream

* add docstring
This commit is contained in:
Awni Hannun
2024-12-24 11:19:13 -08:00
committed by GitHub
parent 935c8c4bb1
commit 4ba0c24a8f
35 changed files with 2239 additions and 90 deletions

View File

@@ -697,8 +697,7 @@ array scaled_dot_product_attention(
return array(
std::move(out_shape),
final_type,
std::make_shared<ScaledDotProductAttention>(
stream, fallback, scale, false),
std::make_shared<ScaledDotProductAttention>(stream, fallback, scale),
{q, k, v});
}
@@ -712,7 +711,7 @@ array scaled_dot_product_attention(
bool ScaledDotProductAttention::is_equivalent(const Primitive& other) const {
const ScaledDotProductAttention& a_other =
static_cast<const ScaledDotProductAttention&>(other);
return needs_mask_ == a_other.needs_mask_ && scale_ == a_other.scale_;
return scale_ == a_other.scale_;
}
array pack_and_quantize(