|
|
|
@@ -80,24 +80,31 @@ AbstractBasePtr InferImplEqual(const AnalysisEnginePtr &, const PrimitivePtr &pr |
|
|
|
const AbstractBasePtrList &args_spec_list) { |
|
|
|
const std::string op_name = primitive->name(); |
|
|
|
CheckArgsSize(op_name, args_spec_list, 2); |
|
|
|
auto input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0); |
|
|
|
MS_EXCEPTION_IF_NULL(input_x); |
|
|
|
MS_EXCEPTION_IF_NULL(input_x->shape()); |
|
|
|
auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0); |
|
|
|
MS_EXCEPTION_IF_NULL(x); |
|
|
|
MS_EXCEPTION_IF_NULL(x->shape()); |
|
|
|
ShapeVector x_shape = x->shape()->shape(); |
|
|
|
ShapeVector x_shape_min = x->shape()->min_shape().empty() ? x_shape : x->shape()->min_shape(); |
|
|
|
ShapeVector x_shape_max = x->shape()->max_shape().empty() ? x_shape : x->shape()->max_shape(); |
|
|
|
|
|
|
|
auto y = CheckArg<AbstractTensor>(op_name, args_spec_list, 1); |
|
|
|
MS_EXCEPTION_IF_NULL(y); |
|
|
|
MS_EXCEPTION_IF_NULL(y->shape()); |
|
|
|
ShapeVector y_shape = y->shape()->shape(); |
|
|
|
ShapeVector y_shape_min = y->shape()->min_shape().empty() ? y_shape : y->shape()->min_shape(); |
|
|
|
ShapeVector y_shape_max = y->shape()->max_shape().empty() ? y_shape : y->shape()->max_shape(); |
|
|
|
|
|
|
|
auto input_y = CheckArg<AbstractTensor>(op_name, args_spec_list, 1); |
|
|
|
MS_EXCEPTION_IF_NULL(input_y); |
|
|
|
MS_EXCEPTION_IF_NULL(input_y->shape()); |
|
|
|
|
|
|
|
auto x_shape = input_x->shape()->shape(); |
|
|
|
auto y_shape = input_y->shape()->shape(); |
|
|
|
auto out_shape = BroadcastShape(x_shape, y_shape); |
|
|
|
if (out_shape.empty()) { |
|
|
|
MS_LOG(EXCEPTION) << "BroadcastShape fail: " << args_spec_list[0]->ToString() << "," |
|
|
|
<< args_spec_list[1]->ToString(); |
|
|
|
} |
|
|
|
auto out_shape_min = BroadcastShape(x_shape_min, y_shape_min); |
|
|
|
auto out_shape_max = BroadcastShape(x_shape_max, y_shape_max); |
|
|
|
|
|
|
|
auto output_type = std::make_shared<Bool>(); |
|
|
|
auto ret = std::make_shared<AbstractTensor>(output_type, out_shape); |
|
|
|
auto ret = |
|
|
|
std::make_shared<AbstractTensor>(output_type, std::make_shared<Shape>(out_shape, out_shape_min, out_shape_max)); |
|
|
|
return ret; |
|
|
|
} |
|
|
|
|
|
|
|
|