mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-15 09:29:26 +08:00
Reduce a little overhead (#871)
* some small overhead improvements * use result_type in rms_norm * remove release force * fix + use non-vector version * revert compile change * fix ops * a little more overhead * a little more cleanup and overhead
This commit is contained in:
@@ -12,16 +12,6 @@ namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
std::pair<size_t, std::vector<size_t>> cum_prod(const std::vector<int>& shape) {
|
||||
std::vector<size_t> strides(shape.size());
|
||||
size_t cum_prod = 1;
|
||||
for (int i = shape.size() - 1; i >= 0; --i) {
|
||||
strides[i] = cum_prod;
|
||||
cum_prod *= shape[i];
|
||||
}
|
||||
return {cum_prod, strides};
|
||||
}
|
||||
|
||||
/** Return true if we are currently performing a function transformation in
|
||||
* order to keep the graph when evaluating tracer arrays. */
|
||||
bool in_tracing() {
|
||||
@@ -171,9 +161,21 @@ void array::move_shared_buffer(array other) {
|
||||
move_shared_buffer(other, other.strides(), other.flags(), other.data_size());
|
||||
}
|
||||
|
||||
void array::ArrayDesc::init() {
|
||||
strides.resize(shape.size());
|
||||
size = 1;
|
||||
for (int i = shape.size() - 1; i >= 0; --i) {
|
||||
strides[i] = size;
|
||||
size *= shape[i];
|
||||
}
|
||||
for (auto& in : inputs) {
|
||||
is_tracer |= in.is_tracer();
|
||||
}
|
||||
}
|
||||
|
||||
array::ArrayDesc::ArrayDesc(std::vector<int> shape, Dtype dtype)
|
||||
: shape(std::move(shape)), dtype(dtype) {
|
||||
std::tie(size, strides) = cum_prod(this->shape);
|
||||
init();
|
||||
}
|
||||
|
||||
array::ArrayDesc::ArrayDesc(
|
||||
@@ -185,10 +187,7 @@ array::ArrayDesc::ArrayDesc(
|
||||
dtype(dtype),
|
||||
primitive(std::move(primitive)),
|
||||
inputs(std::move(inputs)) {
|
||||
std::tie(size, strides) = cum_prod(this->shape);
|
||||
for (auto& in : this->inputs) {
|
||||
is_tracer |= in.is_tracer();
|
||||
}
|
||||
init();
|
||||
}
|
||||
|
||||
array::ArrayIterator::ArrayIterator(const array& arr, int idx)
|
||||
|
||||
Reference in New Issue
Block a user