From 981b7339cf5439538530798674f3682497fd9259 Mon Sep 17 00:00:00 2001 From: chenfei Date: Mon, 18 Jan 2021 11:32:19 +0800 Subject: [PATCH] debug const tensor simplify --- mindspore/core/ir/pattern_matcher.h | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/mindspore/core/ir/pattern_matcher.h b/mindspore/core/ir/pattern_matcher.h index f14b4f98ec..778ab5185e 100644 --- a/mindspore/core/ir/pattern_matcher.h +++ b/mindspore/core/ir/pattern_matcher.h @@ -21,6 +21,7 @@ #include #include #include +#include #include "base/core_ops.h" #include "ir/visitor.h" @@ -840,15 +841,25 @@ class PConstant : public PBase > { } } else { if (in_data_2_size < out_data_size) { - MS_EXCEPTION(ValueError) << "in_data_2_size is smaller than out_data_size."; + MS_LOG(INFO) << "in_data_2_size:" << in_data_2_size << " is smaller than out_data_size:" << out_data_size + << ".in_data2 will be broadcast."; } - for (int i = 0; i < out_data_size; i++) { + auto min_size = std::min(in_data_2_size, out_data_size); + for (int i = 0; i < min_size; i++) { if (bin_operator == ADD) { data_out[i] += data_2[i]; } else { data_out[i] *= data_2[i]; } } + // In case of in_data2_size < out_data_size + for (int i = min_size; i < out_data_size; i++) { + if (bin_operator != ADD) { + // if operator is MUL, data_out[i] *= 0, => data_out[i] = 0. + data_out[i] = 0; + } + // if operator is ADD, data_out[i] += 0, => data_out[i] = data_out[i], => NoChange. + } } *out_data = reinterpret_cast(data_out); return;