Reduce a little overhead (#871)
* some small overhead improvements
* use result_type in rms_norm
* remove release force
* fix + use non-vector version
* revert compile change
* fix ops
* a little more overhead
* a little more cleanup and overhead
@@ -131,8 +131,8 @@ class Module(dict):
         return value

     def __getattr__(self, key: str):
-        if key in self:
-            return self[key]
+        if (value := self.get(key, None)) is not None:
+            return value
         else:
             super(Module, self).__getattribute__(key)

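The __getattr__ change replaces a membership test followed by a separate lookup with a single dict.get guarded by a walrus assignment, so the common hit path hashes the key once instead of twice. A standalone sketch of the pattern (hypothetical two_lookups/one_lookup helpers, not the real Module class):

    import timeit

    d = {"weight": object()}

    def two_lookups(key):
        # old pattern: hash and probe the dict twice on a hit
        if key in d:
            return d[key]

    def one_lookup(key):
        # new pattern: one probe via dict.get, bound with a walrus
        if (value := d.get(key, None)) is not None:
            return value

    print(timeit.timeit(lambda: two_lookups("weight"), number=1_000_000))
    print(timeit.timeit(lambda: one_lookup("weight"), number=1_000_000))

One nuance: the new form treats a stored None like a missing key, which is harmless here since parameters are always arrays.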
@@ -64,9 +64,9 @@ class Linear(Module):

     def __call__(self, x: mx.array) -> mx.array:
         if "bias" in self:
-            x = mx.addmm(self.bias, x, self.weight.T)
+            x = mx.addmm(self["bias"], x, self["weight"].T)
         else:
-            x = x @ self.weight.T
+            x = x @ self["weight"].T
         return x

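Since Module subclasses dict, self["weight"] is a direct dict lookup, while self.weight first misses the instance __dict__ and then falls into the __getattr__ hook above; indexing in the hot path of __call__ skips that extra Python-level call. A rough micro-benchmark of the idea, using a hypothetical Toy class rather than the real Module:

    import timeit

    class Toy(dict):
        def __getattr__(self, key):
            # same fallback shape as Module.__getattr__
            if (value := self.get(key, None)) is not None:
                return value
            raise AttributeError(key)

    t = Toy(weight=1.0)

    # attribute access pays for the failed __dict__ probe plus the
    # __getattr__ call; indexing goes straight to dict.__getitem__
    print(timeit.timeit(lambda: t.weight, number=1_000_000))
    print(timeit.timeit(lambda: t["weight"], number=1_000_000))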
@@ -140,7 +140,7 @@ class RMSNorm(Module):
         return f"{self.weight.shape[0]}, eps={self.eps}"

     def __call__(self, x):
-        return mx.fast.rms_norm(x, self.weight, self.eps)
+        return mx.fast.rms_norm(x, self["weight"], self.eps)


 class GroupNorm(Module):
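mx.fast.rms_norm fuses the whole normalization into a single kernel; the change here only swaps the weight lookup to direct indexing. For orientation, a sketch of what the op computes, assuming the usual RMSNorm definition (rms_norm_reference is an illustrative helper, not part of mlx):

    import mlx.core as mx

    def rms_norm_reference(x, weight, eps):
        # scale by the reciprocal root-mean-square over the last axis,
        # then apply the learned per-channel weight
        return weight * x * mx.rsqrt(x.square().mean(-1, keepdims=True) + eps)

    x = mx.random.normal((4, 128))
    w = mx.ones((128,))
    print(mx.allclose(rms_norm_reference(x, w, 1e-5),
                      mx.fast.rms_norm(x, w, 1e-5), atol=1e-4))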
@@ -81,15 +81,15 @@ class QuantizedLinear(Module):
     def __call__(self, x):
         x = mx.quantized_matmul(
             x,
-            self.weight,
-            scales=self.scales,
-            biases=self.biases,
+            self["weight"],
+            scales=self["scales"],
+            biases=self["biases"],
             transpose=True,
             group_size=self.group_size,
             bits=self.bits,
         )
         if "bias" in self:
-            x = x + self.bias
+            x = x + self["bias"]
         return x

     @classmethod
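QuantizedLinear's forward does four parameter lookups per call, so weight, scales, biases, and the optional bias all switch to indexing. The underlying op can also be exercised directly; a small sketch that uses mx.quantize to produce matching packed inputs (shapes are illustrative):

    import mlx.core as mx

    group_size, bits = 64, 4
    w = mx.random.normal((256, 512))

    # quantize returns the packed weight plus per-group scales and biases
    w_q, scales, biases = mx.quantize(w, group_size=group_size, bits=bits)

    x = mx.random.normal((1, 512))
    y = mx.quantized_matmul(
        x,
        w_q,
        scales=scales,
        biases=biases,
        transpose=True,
        group_size=group_size,
        bits=bits,
    )
    print(y.shape)  # (1, 256)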
@@ -17,12 +17,7 @@ void init_fast(nb::module_& parent_module) {

   m.def(
       "rms_norm",
-      [](const array& x,
-         const array& weight,
-         float eps,
-         const StreamOrDevice& s /* = {} */) {
-        return fast::rms_norm(x, weight, eps, s);
-      },
+      &fast::rms_norm,
       "x"_a,
       "weight"_a,
       "eps"_a,
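On the binding side, each lambda did nothing but forward its arguments to the matching fast:: function, so nanobind can bind the function pointer directly; that drops one C++ call frame per invocation. The Python-visible signature is unchanged:

    import mlx.core as mx

    x = mx.random.normal((2, 64))
    w = mx.ones((64,))

    # eps remains a named positional argument, so both spellings work
    a = mx.fast.rms_norm(x, w, 1e-5)
    b = mx.fast.rms_norm(x, w, eps=1e-5)
    print(mx.allclose(a, b))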
@@ -48,13 +43,7 @@ void init_fast(nb::module_& parent_module) {

   m.def(
       "layer_norm",
-      [](const array& x,
-         const std::optional<array>& weight,
-         const std::optional<array>& bias,
-         float eps,
-         const StreamOrDevice& s /* = {} */) {
-        return fast::layer_norm(x, weight, bias, eps, s);
-      },
+      &fast::layer_norm,
       "x"_a,
       "weight"_a.none(),
       "bias"_a.none(),
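layer_norm gets the same treatment, and the .none() annotations keep weight and bias optional. A quick sketch of the call with and without the affine parameters:

    import mlx.core as mx

    x = mx.random.normal((2, 64))
    w = mx.ones((64,))
    b = mx.zeros((64,))

    # weight and bias accept None, matching "weight"_a.none() / "bias"_a.none()
    y_affine = mx.fast.layer_norm(x, w, b, 1e-5)
    y_plain = mx.fast.layer_norm(x, None, None, 1e-5)
    print(y_affine.shape, y_plain.shape)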
@@ -84,15 +73,7 @@ void init_fast(nb::module_& parent_module) {

   m.def(
       "rope",
-      [](const array& a,
-         int dims,
-         bool traditional,
-         float base,
-         float scale,
-         int offset,
-         const StreamOrDevice& s /* = {} */) {
-        return fast::rope(a, dims, traditional, base, scale, offset, s);
-      },
+      &fast::rope,
       "a"_a,
       "dims"_a,
       nb::kw_only(),
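rope follows suit; note nb::kw_only() makes everything after dims keyword-only. An illustrative call (the 4-D batch/heads/sequence/head-dim shape is an assumption for the example):

    import mlx.core as mx

    x = mx.random.normal((1, 8, 16, 64))

    y = mx.fast.rope(
        x,
        64,                 # rotate all 64 dimensions of each head
        traditional=False,
        base=10000.0,
        scale=1.0,
        offset=0,           # position of the first token, e.g. a KV-cache length
    )
    print(y.shape)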
@@ -123,14 +104,7 @@ void init_fast(nb::module_& parent_module) {

   m.def(
       "scaled_dot_product_attention",
-      [](const array& q,
-         const array& k,
-         const array& v,
-         const float scale,
-         const std::optional<array>& mask,
-         const StreamOrDevice& s) {
-        return fast::scaled_dot_product_attention(q, k, v, scale, mask, s);
-      },
+      &fast::scaled_dot_product_attention,
       "q"_a,
       "k"_a,
       "v"_a,
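Finally scaled_dot_product_attention, where the lambda was not even filling in a default stream, so the pointer swap is a pure simplification. A usage sketch with conventional multi-head shapes:

    import math

    import mlx.core as mx

    B, H, T, D = 1, 8, 16, 64
    q = mx.random.normal((B, H, T, D))
    k = mx.random.normal((B, H, T, D))
    v = mx.random.normal((B, H, T, D))

    # scale is required; mask is optional
    out = mx.fast.scaled_dot_product_attention(
        q, k, v, scale=1.0 / math.sqrt(D), mask=None
    )
    print(out.shape)  # (1, 8, 16, 64)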