Browse Source

!15548 fix ci alarm on master

From: @TFbunny
Reviewed-by: @tom__chen,@liangchenghui
Signed-off-by: @liangchenghui
pull/15548/MERGE
mindspore-ci-bot Gitee 4 years ago
parent
commit
15cd9c8997
1 changed files with 3 additions and 3 deletions
  1. +3
    -3
      mindspore/core/abstract/prim_arrays.cc

+ 3
- 3
mindspore/core/abstract/prim_arrays.cc View File

@@ -776,7 +776,7 @@ AbstractBasePtr InferImplTranspose(const AnalysisEnginePtr &, const PrimitivePtr
ShapeVector x_min_shp = input->shape()->min_shape(); ShapeVector x_min_shp = input->shape()->min_shape();
(void)CheckMinMaxShape(input_shp, &x_min_shp, &x_max_shp); (void)CheckMinMaxShape(input_shp, &x_min_shp, &x_max_shp);
for (size_t i = 0; i < perm_vec.size(); i++) { for (size_t i = 0; i < perm_vec.size(); i++) {
size_t idx = static_cast<size_t>(perm_vec[i]);
auto idx = static_cast<size_t>(perm_vec[i]);
result_shp.push_back(input_shp[idx]); result_shp.push_back(input_shp[idx]);
max_shp.push_back(x_max_shp[idx]); max_shp.push_back(x_max_shp[idx]);
min_shp.push_back(x_min_shp[idx]); min_shp.push_back(x_min_shp[idx]);
@@ -1096,7 +1096,7 @@ AbstractBasePtr InferImplArgMaxWithValue(const AnalysisEnginePtr &, const Primit
} }
// check axis convert negative to positive value // check axis convert negative to positive value
auto check_axis = [](int64_t &axis, const size_t dim) -> void { auto check_axis = [](int64_t &axis, const size_t dim) -> void {
int64_t dim_ = static_cast<int64_t>(dim);
auto dim_ = static_cast<int64_t>(dim);
if (axis < -dim_ || axis >= dim_) { if (axis < -dim_ || axis >= dim_) {
MS_LOG(EXCEPTION) << "axis should be in [" << -dim_ << ", " << dim_ << "). But got axis = " << axis << "."; MS_LOG(EXCEPTION) << "axis should be in [" << -dim_ << ", " << dim_ << "). But got axis = " << axis << ".";
} }
@@ -1108,7 +1108,7 @@ AbstractBasePtr InferImplArgMaxWithValue(const AnalysisEnginePtr &, const Primit
// main calculate shape func // main calculate shape func
auto cal_shape = [axis, keep_dims_value, check_axis](ShapeVector &shape, const ShapeVector &x_shape) -> void { auto cal_shape = [axis, keep_dims_value, check_axis](ShapeVector &shape, const ShapeVector &x_shape) -> void {
shape.insert(shape.end(), x_shape.begin(), x_shape.end()); shape.insert(shape.end(), x_shape.begin(), x_shape.end());
int64_t axis_value = GetValue<int64_t>(axis);
auto axis_value = GetValue<int64_t>(axis);
check_axis(axis_value, x_shape.size()); check_axis(axis_value, x_shape.size());
if (keep_dims_value) { if (keep_dims_value) {
shape[axis_value] = 1; shape[axis_value] = 1;


Loading…
Cancel
Save