mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Export / import functions to / from a file (#1642)
* export and import functions * refactor + works for few primitives * nit * allow primitives with state * nit * nit * simplify serialize / deserialize * fix for constants * python bindings * maybe fix serialize failure case * add example * more primitives, training kind of works * same result for python and c++ * some fixes * fix export * template it up * some simplification * rebase * allow kwargs and multiple functions * exporter * more primitives for exporting * deal with endianness * handle invalid stream * add docstring
This commit is contained in:
@@ -697,8 +697,7 @@ array scaled_dot_product_attention(
|
||||
return array(
|
||||
std::move(out_shape),
|
||||
final_type,
|
||||
std::make_shared<ScaledDotProductAttention>(
|
||||
stream, fallback, scale, false),
|
||||
std::make_shared<ScaledDotProductAttention>(stream, fallback, scale),
|
||||
{q, k, v});
|
||||
}
|
||||
|
||||
@@ -712,7 +711,7 @@ array scaled_dot_product_attention(
|
||||
bool ScaledDotProductAttention::is_equivalent(const Primitive& other) const {
|
||||
const ScaledDotProductAttention& a_other =
|
||||
static_cast<const ScaledDotProductAttention&>(other);
|
||||
return needs_mask_ == a_other.needs_mask_ && scale_ == a_other.scale_;
|
||||
return scale_ == a_other.scale_;
|
||||
}
|
||||
|
||||
array pack_and_quantize(
|
||||
|
||||
Reference in New Issue
Block a user