Mirror of https://github.com/ml-explore/mlx.git, synced 2025-12-16 01:49:05 +08:00.
Commit: "fix warnings showing up with -Wall (#2692)". The diff below shows this commit's changes.
@@ -996,131 +996,6 @@ void explicit_gemm_conv_1D_cpu(
|
||||
encoder.add_temporaries(std::move(temps));
|
||||
}
|
||||
|
||||
// Explicit-GEMM 2D convolution on CPU.
//
// Strategy (im2col-style): zero-pad the input, build a strided view of the
// padded buffer whose rows enumerate every (wH x wW x C) receptive-field
// patch, materialize that view as an (N*oH*oW, wH*wW*C) matrix, then compute
// the convolution as a single sgemm against the flattened weights.
//
// Layout (grounded in the shape indexing below): `in` is NHWC, `wt` is
// (O, wH, wW, C), `out` is (N, oH, oW, O).
//
// NOTE(review): `wt_dilation` is accepted but never read in this body —
// presumably callers only take this path when dilation == 1; confirm at the
// call site.
void explicit_gemm_conv_2D_cpu(
    const array& in,
    const array& wt,
    array out,
    const std::vector<int>& padding_lo,
    const std::vector<int>& padding_hi,
    const std::vector<int>& wt_strides,
    const std::vector<int>& wt_dilation,
    Stream stream) {
  const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
  const int iH = in.shape(1); // Input spatial dim
  const int iW = in.shape(2); // Input spatial dim
  const int oH = out.shape(1); // Output spatial dim
  const int oW = out.shape(2); // Output spatial dim
  const int O = wt.shape(0); // Out channels
  const int C = wt.shape(3); // In channels
  const int wH = wt.shape(1); // Weight spatial dim
  const int wW = wt.shape(2); // Weight spatial dim

  auto conv_dtype = out.dtype();
  auto& encoder = cpu::get_command_encoder(stream);

  // Pad input
  // padding_lo/padding_hi are per spatial axis: [0] -> H, [1] -> W.
  Shape padded_shape = {
      N,
      iH + padding_lo[0] + padding_hi[0],
      iW + padding_lo[1] + padding_hi[1],
      C};
  array in_padded(padded_shape, conv_dtype, nullptr, {});

  // Fill with zeros
  // A scalar-0 temporary is broadcast-copied over the whole padded buffer so
  // the halo regions are zero before the real input is written in.
  std::vector<array> temps;
  temps.push_back(array(0, conv_dtype));
  copy_cpu(temps.back(), in_padded, CopyType::Scalar, stream);

  // Pick input slice from padded
  // Offset skips the low-side padding in H and W; the slice aliases
  // in_padded's buffer (copy_shared_buffer), no data is copied here.
  size_t data_offset = padding_lo[0] * in_padded.strides()[1] +
      padding_lo[1] * in_padded.strides()[2];
  array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
  in_padded_slice.copy_shared_buffer(
      in_padded,
      in_padded.strides(),
      in_padded.flags(),
      in_padded_slice.size(),
      data_offset);
  temps.push_back(in_padded_slice);

  // Copy input values into the slice
  copy_cpu_inplace(in, in_padded_slice, CopyType::GeneralGeneral, stream);

  // Make strided view
  // Axes: (batch, out-row, out-col, kernel-row, kernel-col, channel).
  Shape strided_shape = {N, oH, oW, wH, wW, C};

  // Output axes step by the conv stride; kernel axes step by one input
  // element, so the same input locations appear in multiple rows (im2col).
  Strides strided_strides = {
      in_padded.strides()[0],
      in_padded.strides()[1] * wt_strides[0],
      in_padded.strides()[2] * wt_strides[1],
      in_padded.strides()[1],
      in_padded.strides()[2],
      in_padded.strides()[3]};
  auto flags = in_padded.flags();

  array in_strided_view(strided_shape, in_padded.dtype(), nullptr, {});
  in_strided_view.copy_shared_buffer(
      in_padded, strided_strides, flags, in_strided_view.size(), 0);

  // Materialize strided view
  // General copy gathers the overlapping patches into a dense
  // (N*oH*oW, wH*wW*C) matrix suitable for GEMM.
  Shape strided_reshape = {N * oH * oW, wH * wW * C};
  array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
  copy_cpu(in_strided_view, in_strided, CopyType::General, stream);
  temps.push_back(in_strided);

  // Check wt dtype and prepare
  // cblas_sgemm works on contiguous float32, so convert/compact the weights
  // and allocate a float32 staging output if needed.
  auto gemm_wt = wt;
  auto gemm_out = out;

  if (wt.dtype() != float32 || !wt.flags().row_contiguous) {
    auto ctype =
        wt.flags().row_contiguous ? CopyType::Vector : CopyType::General;
    gemm_wt = array(wt.shape(), float32, nullptr, {});
    copy_cpu(wt, gemm_wt, ctype, stream);
    temps.push_back(gemm_wt);
  }

  if (out.dtype() != float32) {
    gemm_out = array(out.shape(), float32, nullptr, {});
    gemm_out.set_data(allocator::malloc(gemm_out.nbytes()));
    temps.push_back(gemm_out);
  }

  encoder.set_input_array(in_strided);
  encoder.set_input_array(gemm_wt);
  encoder.set_output_array(gemm_out);

  // Capture raw pointers (and the GEMM dims) by value so the dispatched task
  // does not depend on the array wrappers outliving this function.
  encoder.dispatch([in_strided_ptr = in_strided.data<float>(),
                    gemm_wt_ptr = gemm_wt.data<float>(),
                    gemm_out_ptr = gemm_out.data<float>(),
                    strided_reshape = std::move(strided_reshape),
                    O]() {
    // Perform gemm
    // out[M x N] = patches[M x K] * wt[N x K]^T, with
    // M = N*oH*oW, N = O, K = wH*wW*C.
    cblas_sgemm(
        CblasRowMajor,
        CblasNoTrans, // no trans A
        CblasTrans, // transB
        strided_reshape[0], // M
        O, // N
        strided_reshape[1], // K
        1.0f, // alpha
        in_strided_ptr,
        strided_reshape[1], // lda
        gemm_wt_ptr,
        strided_reshape[1], // ldb
        0.0f, // beta
        gemm_out_ptr,
        O // ldc
    );
  });

  // Copy results if needed
  // If we staged into a float32 buffer, cast back to the requested out dtype.
  if (out.dtype() != float32) {
    copy_cpu_inplace(gemm_out, out, CopyType::Vector, stream);
  }
  // Hand all intermediates to the encoder so they stay alive until the
  // dispatched work has run.
  encoder.add_temporaries(std::move(temps));
}
|
||||
|
||||
void explicit_gemm_conv_ND_cpu(
|
||||
const array& in,
|
||||
const array& wt,
|
||||
|
||||
@@ -46,7 +46,6 @@ void eig_impl(
|
||||
int info;
|
||||
{
|
||||
T work;
|
||||
int iwork;
|
||||
geev<T>(
|
||||
&jobl,
|
||||
&jobr,
|
||||
|
||||
@@ -215,18 +215,18 @@ void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
const void* a_mask_ptr;
|
||||
const void* b_mask_ptr;
|
||||
const void* out_mask_ptr;
|
||||
const void* a_mask_ptr = nullptr;
|
||||
const void* b_mask_ptr = nullptr;
|
||||
const void* out_mask_ptr = nullptr;
|
||||
Shape a_mask_shape;
|
||||
Shape b_mask_shape;
|
||||
Shape out_mask_shape;
|
||||
Strides a_mask_strides;
|
||||
Strides b_mask_strides;
|
||||
Strides out_mask_strides;
|
||||
bool a_mask_bool;
|
||||
bool b_mask_bool;
|
||||
bool out_mask_bool;
|
||||
bool a_mask_bool = false;
|
||||
bool b_mask_bool = false;
|
||||
bool out_mask_bool = false;
|
||||
if (has_op_mask) {
|
||||
auto& a_mask = inputs[inputs.size() - 2];
|
||||
auto& b_mask = inputs[inputs.size() - 1];
|
||||
@@ -423,7 +423,6 @@ void GatherMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& rhs_indices = inputs[3];
|
||||
|
||||
auto batch_shape = get_batch_dims(out.shape());
|
||||
int batch_ndim = batch_shape.size();
|
||||
|
||||
auto batch_shape_A = get_batch_dims(a.shape());
|
||||
auto batch_strides_A = get_batch_dims(a.strides());
|
||||
|
||||
@@ -91,7 +91,6 @@ void matmul_general(
|
||||
auto [b_transposed, ldb, b] = check_transpose(b_pre);
|
||||
size_t M = a.shape(-2);
|
||||
size_t N = b.shape(-1);
|
||||
size_t K = a.shape(-1);
|
||||
if (M == 0 || N == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -445,7 +445,6 @@ void mxfp4_qmm(
|
||||
int K) {
|
||||
constexpr int group_size = 32;
|
||||
constexpr int pack_factor = get_pack_factor(4, 8);
|
||||
constexpr int bytes_per_pack = get_bytes_per_pack(4);
|
||||
constexpr int packs_in_group = group_size / pack_factor;
|
||||
|
||||
for (int m = 0; m < M; m++) {
|
||||
@@ -487,7 +486,6 @@ void mxfp4_qmm_t(
|
||||
int K) {
|
||||
constexpr int group_size = 32;
|
||||
constexpr int pack_factor = get_pack_factor(4, 8);
|
||||
constexpr int bytes_per_pack = get_bytes_per_pack(4);
|
||||
constexpr int packs_in_group = group_size / pack_factor;
|
||||
|
||||
for (int m = 0; m < M; m++) {
|
||||
|
||||
@@ -39,7 +39,7 @@ struct StridedIterator {
|
||||
StridedIterator() = default;
|
||||
|
||||
explicit StridedIterator(T* ptr, int64_t stride, difference_type offset = 0)
|
||||
: ptr_(ptr + offset * stride), stride_(stride) {}
|
||||
: stride_(stride), ptr_(ptr + offset * stride) {}
|
||||
|
||||
explicit StridedIterator(array& arr, int axis, difference_type offset = 0)
|
||||
: StridedIterator(arr.data<T>(), arr.strides()[axis], offset) {}
|
||||
|
||||
@@ -83,8 +83,6 @@ void svd_impl(
|
||||
|
||||
auto jobz = (u_ptr) ? "A" : "N";
|
||||
|
||||
// Will contain the number of singular values after the call has returned.
|
||||
int ns = 0;
|
||||
T workspace_dimension = 0;
|
||||
|
||||
// Will contain the indices of eigenvectors that failed to converge (not
|
||||
|
||||
Reference in New Issue
Block a user