Browse Source

!11261 [ME]Check whether tensor shape is same to value node abstract shape when arithmatic simplify

From: @chenfei52
Reviewed-by: @ginfung
Signed-off-by:
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 4 years ago
parent
commit
93c216ce7f
1 changed files with 13 additions and 2 deletions
  1. +13
    -2
      mindspore/core/ir/pattern_matcher.h

+ 13
- 2
mindspore/core/ir/pattern_matcher.h View File

@@ -21,6 +21,7 @@
#include <memory>
#include <tuple>
#include <vector>
#include <algorithm>

#include "base/core_ops.h"
#include "ir/visitor.h"
@@ -840,15 +841,25 @@ class PConstant : public PBase<PConstant<T> > {
}
} else {
if (in_data_2_size < out_data_size) {
MS_EXCEPTION(ValueError) << "in_data_2_size is smaller than out_data_size.";
MS_LOG(INFO) << "in_data_2_size:" << in_data_2_size << " is smaller than out_data_size:" << out_data_size
<< ".in_data2 will be broadcast.";
}
for (int i = 0; i < out_data_size; i++) {
auto min_size = std::min<int>(in_data_2_size, out_data_size);
for (int i = 0; i < min_size; i++) {
if (bin_operator == ADD) {
data_out[i] += data_2[i];
} else {
data_out[i] *= data_2[i];
}
}
// In case of in_data2_size < out_data_size
for (int i = min_size; i < out_data_size; i++) {
if (bin_operator != ADD) {
// if operator is MUL, data_out[i] *= 0, => data_out[i] = 0.
data_out[i] = 0;
}
// if operator is ADD, data_out[i] += 0, => data_out[i] = data_out[i], => NoChange.
}
}
*out_data = reinterpret_cast<void *>(data_out);
return;


Loading…
Cancel
Save