|
|
@@ -776,7 +776,7 @@ AbstractBasePtr InferImplTranspose(const AnalysisEnginePtr &, const PrimitivePtr |
|
|
ShapeVector x_min_shp = input->shape()->min_shape(); |
|
|
ShapeVector x_min_shp = input->shape()->min_shape(); |
|
|
(void)CheckMinMaxShape(input_shp, &x_min_shp, &x_max_shp); |
|
|
(void)CheckMinMaxShape(input_shp, &x_min_shp, &x_max_shp); |
|
|
for (size_t i = 0; i < perm_vec.size(); i++) { |
|
|
for (size_t i = 0; i < perm_vec.size(); i++) { |
|
|
size_t idx = static_cast<size_t>(perm_vec[i]); |
|
|
|
|
|
|
|
|
auto idx = static_cast<size_t>(perm_vec[i]); |
|
|
result_shp.push_back(input_shp[idx]); |
|
|
result_shp.push_back(input_shp[idx]); |
|
|
max_shp.push_back(x_max_shp[idx]); |
|
|
max_shp.push_back(x_max_shp[idx]); |
|
|
min_shp.push_back(x_min_shp[idx]); |
|
|
min_shp.push_back(x_min_shp[idx]); |
|
|
@@ -1096,7 +1096,7 @@ AbstractBasePtr InferImplArgMaxWithValue(const AnalysisEnginePtr &, const Primit |
|
|
} |
|
|
} |
|
|
// check axis convert negative to positive value |
|
|
// check axis convert negative to positive value |
|
|
auto check_axis = [](int64_t &axis, const size_t dim) -> void { |
|
|
auto check_axis = [](int64_t &axis, const size_t dim) -> void { |
|
|
int64_t dim_ = static_cast<int64_t>(dim); |
|
|
|
|
|
|
|
|
auto dim_ = static_cast<int64_t>(dim); |
|
|
if (axis < -dim_ || axis >= dim_) { |
|
|
if (axis < -dim_ || axis >= dim_) { |
|
|
MS_LOG(EXCEPTION) << "axis should be in [" << -dim_ << ", " << dim_ << "). But got axis = " << axis << "."; |
|
|
MS_LOG(EXCEPTION) << "axis should be in [" << -dim_ << ", " << dim_ << "). But got axis = " << axis << "."; |
|
|
} |
|
|
} |
|
|
@@ -1108,7 +1108,7 @@ AbstractBasePtr InferImplArgMaxWithValue(const AnalysisEnginePtr &, const Primit |
|
|
// main calculate shape func |
|
|
// main calculate shape func |
|
|
auto cal_shape = [axis, keep_dims_value, check_axis](ShapeVector &shape, const ShapeVector &x_shape) -> void { |
|
|
auto cal_shape = [axis, keep_dims_value, check_axis](ShapeVector &shape, const ShapeVector &x_shape) -> void { |
|
|
shape.insert(shape.end(), x_shape.begin(), x_shape.end()); |
|
|
shape.insert(shape.end(), x_shape.begin(), x_shape.end()); |
|
|
int64_t axis_value = GetValue<int64_t>(axis); |
|
|
|
|
|
|
|
|
auto axis_value = GetValue<int64_t>(axis); |
|
|
check_axis(axis_value, x_shape.size()); |
|
|
check_axis(axis_value, x_shape.size()); |
|
|
if (keep_dims_value) { |
|
|
if (keep_dims_value) { |
|
|
shape[axis_value] = 1; |
|
|
shape[axis_value] = 1; |
|
|
|