No quant reshape (#957)

* precise option on cpu

* remove print

* remove reshape in quant matmul

* no quant reshape
This commit is contained in:
Awni Hannun
2024-04-04 11:52:12 -07:00
committed by GitHub
parent d88d2124b5
commit 039da779d1
2 changed files with 34 additions and 18 deletions

View File

@@ -3057,7 +3057,6 @@ array quantized_matmul(
int bits /* = 4 */,
StreamOrDevice s /* = {} */) {
array x = in_x;
if (w.dtype() != uint32) {
std::ostringstream msg;
msg << "[quantized_matmul] The weight matrix should be uint32 "
@@ -3074,12 +3073,6 @@ array quantized_matmul(
// Keep x's batch dimensions to reshape it back after the matmul
auto original_shape = x.shape();
int x_inner_dims = original_shape.back();
original_shape.pop_back();
// Reshape x into a matrix if it isn't already one
if (x.ndim() != 2) {
x = reshape(x, {-1, x_inner_dims}, s);
}
if (scales.ndim() != 2 || scales.shape() != biases.shape()) {
std::ostringstream msg;
@@ -3122,9 +3115,10 @@ array quantized_matmul(
<< " and biases.dtype() == " << biases.dtype();
throw std::invalid_argument(msg.str());
}
auto out = array(
{x.shape(0), w_outer_dims},
std::vector<array> inputs;
original_shape.back() = w_outer_dims;
return array(
std::move(original_shape),
dtype,
std::make_shared<QuantizedMatmul>(
to_stream(s), group_size, bits, transpose),
@@ -3132,14 +3126,6 @@ array quantized_matmul(
w,
astype(scales, dtype, s),
astype(biases, dtype, s)});
// If needed reshape x to the original batch shape
if (original_shape.size() != 1) {
original_shape.push_back(w_outer_dims);
out = reshape(out, std::move(original_shape), s);
}
return out;
}
std::tuple<array, array, array> quantize(