|
|
|
@@ -60,27 +60,6 @@ AbstractBasePtr InferImplSqrtGrad(const AnalysisEnginePtr &, const PrimitivePtr |
|
|
|
return out->Broaden(); |
|
|
|
} |
|
|
|
|
|
|
|
AbstractBasePtr InferImplAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive, |
|
|
|
const AbstractBasePtrList &args_spec_list) { |
|
|
|
// Inputs: two tensors. |
|
|
|
const std::string op_name = primitive->name(); |
|
|
|
CheckArgsSize(op_name, args_spec_list, 2); |
|
|
|
ShapePtr shape_x = dyn_cast<Shape>(args_spec_list[0]->GetShapeTrack()); |
|
|
|
MS_EXCEPTION_IF_NULL(shape_x); |
|
|
|
std::vector<int64_t> x_dims = shape_x->shape(); |
|
|
|
ShapePtr shape_y = dyn_cast<Shape>(args_spec_list[1]->GetShapeTrack()); |
|
|
|
MS_EXCEPTION_IF_NULL(shape_y); |
|
|
|
std::vector<int64_t> y_dims = shape_y->shape(); |
|
|
|
auto broadcast_shape = BroadcastShape(x_dims, y_dims); |
|
|
|
if (broadcast_shape.empty()) { |
|
|
|
MS_LOG(EXCEPTION) << "BroadcastShape fail: " << args_spec_list[0]->ToString() << "," |
|
|
|
<< args_spec_list[1]->ToString(); |
|
|
|
} |
|
|
|
auto out = args_spec_list[0]->Broaden(); |
|
|
|
out->set_shape(std::make_shared<Shape>(broadcast_shape)); |
|
|
|
return out; |
|
|
|
} |
|
|
|
|
|
|
|
AbstractBasePtr InferImplSquare(const AnalysisEnginePtr &, const PrimitivePtr &primitive, |
|
|
|
const AbstractBasePtrList &args_spec_list) { |
|
|
|
// Inputs: one tensor. |
|
|
|
@@ -272,6 +251,11 @@ AbstractBasePtr InferImplMul(const AnalysisEnginePtr &engine_ptr, const Primitiv |
|
|
|
return InferImplBinaryBase(engine_ptr, primitive, args_spec_list); |
|
|
|
} |
|
|
|
|
|
|
|
AbstractBasePtr InferImplAdd(const AnalysisEnginePtr &engine_ptr, const PrimitivePtr &primitive, |
|
|
|
const AbstractBasePtrList &args_spec_list) { |
|
|
|
return InferImplBinaryBase(engine_ptr, primitive, args_spec_list); |
|
|
|
} |
|
|
|
|
|
|
|
AbstractBasePtr InferImplSub(const AnalysisEnginePtr &engine_ptr, const PrimitivePtr &primitive, |
|
|
|
const AbstractBasePtrList &args_spec_list) { |
|
|
|
return InferImplBinaryBase(engine_ptr, primitive, args_spec_list); |
|
|
|
|