Reduce a little overhead (#871)

* some small overhead improvements

* use result_type in rms_norm

* remove release force

* fix + use non-vector version

* revert compile change

* fix ops

* a little more overhead

* a little more cleanup and overhead
Author: Awni Hannun
Date: 2024-03-22 17:29:36 -07:00 (committed by GitHub)
Parent: 6ee1112f30
Commit: be98f4ab6b
13 changed files with 239 additions and 240 deletions


@@ -131,8 +131,8 @@ class Module(dict):
         return value

     def __getattr__(self, key: str):
-        if key in self:
-            return self[key]
+        if (value := self.get(key, None)) is not None:
+            return value
         else:
             super(Module, self).__getattribute__(key)

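A minimal sketch (not the MLX source) of the pattern the hunk above tunes: Module subclasses dict, so a parameter name that is not a real attribute falls through to __getattr__, and the walrus-assigned dict.get resolves the key in a single lookup where the old "key in self" plus "self[key]" pair needed two.

class Params(dict):
    # Toy stand-in for nn.Module: __getattr__ only runs after the normal
    # attribute lookup has already failed.
    def __getattr__(self, key):
        # One dict lookup via get(); fall back to the regular attribute
        # machinery (which raises AttributeError) when the key is absent.
        if (value := self.get(key, None)) is not None:
            return value
        return super().__getattribute__(key)


p = Params(weight=[1.0, 2.0])
print(p.weight)     # resolved through __getattr__
print(p["weight"])  # plain dict item access; __getattr__ never runs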


@@ -64,9 +64,9 @@ class Linear(Module):
     def __call__(self, x: mx.array) -> mx.array:
         if "bias" in self:
-            x = mx.addmm(self.bias, x, self.weight.T)
+            x = mx.addmm(self["bias"], x, self["weight"].T)
         else:
-            x = x @ self.weight.T
+            x = x @ self["weight"].T
         return x
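The self["weight"] and self["bias"] forms above skip the Python-level __getattr__ hop entirely, since item access on a dict subclass goes straight to the built-in dict. A rough micro-benchmark sketch of the gap being shaved off (illustrative only, not part of this PR; absolute numbers vary by machine):

import timeit


class Params(dict):
    def __getattr__(self, key):
        if (value := self.get(key, None)) is not None:
            return value
        return super().__getattribute__(key)


p = Params(weight=1.0)

# Attribute access pays for a failed regular lookup plus a __getattr__ call;
# item access is a single C-level dict lookup.
print("p.weight    :", timeit.timeit(lambda: p.weight, number=1_000_000))
print('p["weight"] :', timeit.timeit(lambda: p["weight"], number=1_000_000))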


@@ -140,7 +140,7 @@ class RMSNorm(Module):
         return f"{self.weight.shape[0]}, eps={self.eps}"

     def __call__(self, x):
-        return mx.fast.rms_norm(x, self.weight, self.eps)
+        return mx.fast.rms_norm(x, self["weight"], self.eps)


 class GroupNorm(Module):


@@ -81,15 +81,15 @@ class QuantizedLinear(Module):
     def __call__(self, x):
         x = mx.quantized_matmul(
             x,
-            self.weight,
-            scales=self.scales,
-            biases=self.biases,
+            self["weight"],
+            scales=self["scales"],
+            biases=self["biases"],
             transpose=True,
             group_size=self.group_size,
             bits=self.bits,
         )
         if "bias" in self:
-            x = x + self.bias
+            x = x + self["bias"]
         return x

     @classmethod


@@ -17,12 +17,7 @@ void init_fast(nb::module_& parent_module) {
   m.def(
       "rms_norm",
-      [](const array& x,
-         const array& weight,
-         float eps,
-         const StreamOrDevice& s /* = {} */) {
-        return fast::rms_norm(x, weight, eps, s);
-      },
+      &fast::rms_norm,
       "x"_a,
       "weight"_a,
       "eps"_a,
@@ -48,13 +43,7 @@ void init_fast(nb::module_& parent_module) {
   m.def(
       "layer_norm",
-      [](const array& x,
-         const std::optional<array>& weight,
-         const std::optional<array>& bias,
-         float eps,
-         const StreamOrDevice& s /* = {} */) {
-        return fast::layer_norm(x, weight, bias, eps, s);
-      },
+      &fast::layer_norm,
       "x"_a,
       "weight"_a.none(),
       "bias"_a.none(),
@@ -84,15 +73,7 @@ void init_fast(nb::module_& parent_module) {
   m.def(
       "rope",
-      [](const array& a,
-         int dims,
-         bool traditional,
-         float base,
-         float scale,
-         int offset,
-         const StreamOrDevice& s /* = {} */) {
-        return fast::rope(a, dims, traditional, base, scale, offset, s);
-      },
+      &fast::rope,
       "a"_a,
       "dims"_a,
       nb::kw_only(),
@@ -123,14 +104,7 @@ void init_fast(nb::module_& parent_module) {
   m.def(
       "scaled_dot_product_attention",
-      [](const array& q,
-         const array& k,
-         const array& v,
-         const float scale,
-         const std::optional<array>& mask,
-         const StreamOrDevice& s) {
-        return fast::scaled_dot_product_attention(q, k, v, scale, mask, s);
-      },
+      &fast::scaled_dot_product_attention,
       "q"_a,
       "k"_a,
       "v"_a,