mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 00:04:41 +08:00
CPU binary reduction + Nits (#1242)
* very minor nits * reduce binary * fix test
This commit is contained in:
@@ -157,8 +157,8 @@ class Module(dict):
|
||||
|
||||
Args:
|
||||
file_or_weights (str or list(tuple(str, mx.array))): The path to
|
||||
the weights ``.npz`` file (``.npz`` or ``.safetensors``) or a list of pairs of parameter names
|
||||
and arrays.
|
||||
the weights ``.npz`` file (``.npz`` or ``.safetensors``) or a list
|
||||
of pairs of parameter names and arrays.
|
||||
strict (bool, optional): If ``True`` then checks that the provided
|
||||
weights exactly match the parameters of the model. Otherwise,
|
||||
only the weights actually contained in the model are loaded and
|
||||
@@ -222,7 +222,7 @@ class Module(dict):
|
||||
if v_new.shape != v.shape:
|
||||
raise ValueError(
|
||||
f"Expected shape {v.shape} but received "
|
||||
f" shape {v_new.shape} for parameter {k}"
|
||||
f"shape {v_new.shape} for parameter {k}"
|
||||
)
|
||||
|
||||
self.update(tree_unflatten(weights))
|
||||
|
@@ -83,7 +83,7 @@ void init_fast(nb::module_& parent_module) {
|
||||
"offset"_a,
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def rope(a: array, dims: int, *, traditinoal: bool, base: float, scale: float, offset: int, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
"def rope(a: array, dims: int, *, traditional: bool, base: float, scale: float, offset: int, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Apply rotary positional encoding to the input.
|
||||
|
||||
|
Reference in New Issue
Block a user