|
|
|
@@ -1,5 +1,5 @@ |
|
|
|
/** |
|
|
|
* Copyright 2019 Huawei Technologies Co., Ltd |
|
|
|
* Copyright 2019-2021 Huawei Technologies Co., Ltd |
|
|
|
* |
|
|
|
* Licensed under the Apache License, Version 2.0 (the "License"); |
|
|
|
* you may not use this file except in compliance with the License. |
|
|
|
@@ -34,35 +34,6 @@ AbstractBasePtr InferImplReturn(const AnalysisEnginePtr &, const PrimitivePtr &, |
|
|
|
return abs_base; |
|
|
|
} |
|
|
|
|
|
|
|
AbstractBasePtr InferImplDot(const AnalysisEnginePtr &, const PrimitivePtr &primitive, |
|
|
|
const AbstractBasePtrList &args_spec_list) { |
|
|
|
// Inputs: two tensors. |
|
|
|
const std::string op_name = primitive->name(); |
|
|
|
CheckArgsSize(op_name, args_spec_list, 2); |
|
|
|
AbstractTensorPtr input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0); |
|
|
|
AbstractTensorPtr input_y = CheckArg<AbstractTensor>(op_name, args_spec_list, 1); |
|
|
|
|
|
|
|
ShapePtr x_shp = input_x->shape(); |
|
|
|
auto x_shp_value = x_shp->shape(); |
|
|
|
ShapePtr y_shp = input_y->shape(); |
|
|
|
auto y_shp_value = y_shp->shape(); |
|
|
|
// Should be matrix which shape size is 2. |
|
|
|
if (x_shp_value.size() != 2 || y_shp_value.size() != 2) { |
|
|
|
MS_LOG(EXCEPTION) << op_name << " evaluator requires input two 2D tensors, while the dimensions of two tensors are " |
|
|
|
<< x_shp_value.size() << ", " << y_shp_value.size() << " "; |
|
|
|
} |
|
|
|
if (x_shp_value[1] != y_shp_value[0] && x_shp_value[1] != Shape::SHP_ANY && y_shp_value[0] != Shape::SHP_ANY) { |
|
|
|
MS_LOG(EXCEPTION) << "Incompatible shapes in dot: {" << x_shp->ToString() << "} and {" << y_shp->ToString() << "}"; |
|
|
|
} |
|
|
|
|
|
|
|
auto x_element = input_x->element(); |
|
|
|
MS_EXCEPTION_IF_NULL(x_element); |
|
|
|
(void)x_element->Join(input_y->element()); |
|
|
|
auto param = {x_shp_value[0], y_shp_value[1]}; |
|
|
|
|
|
|
|
return std::make_shared<AbstractTensor>(input_x->element(), std::make_shared<Shape>(param)); |
|
|
|
} |
|
|
|
|
|
|
|
AbstractBasePtr InferImplSwitch(const AnalysisEnginePtr &, const PrimitivePtr &prim, |
|
|
|
const AbstractBasePtrList &args_spec_list) { |
|
|
|
// Inputs: condition, true branch, false branch |
|
|
|
|