/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "pipeline/static_analysis/param_validator.h"
#include "pipeline/static_analysis/prim.h"
#include "operator/ops.h"
#include "pipeline/static_analysis/utils.h"
#include "utils/symbolic.h"

// NOTE(review): the template arguments used below (CheckArg<...>, isa<...>, cast<...>, make_shared<...>,
// std::vector<...>) were reconstructed from the declared result types, comments and error messages of each
// function. Confirm them against the declarations pulled in by the headers above.

namespace mindspore {
namespace abstract {
// Inference for `return`: forwards the abstract value of the first argument unchanged.
// An argument count other than one is only logged at INFO level (the message suggests an attached default
// value may legitimately add arguments); an empty argument list is reported as an exception because there
// is nothing to return.
AbstractBasePtr InferImplReturn(const AnalysisEnginePtr &, const PrimitivePtr &,
                                const AbstractBasePtrList &args_spec_list) {
  // Inputs: a pointer to an AbstractBase object
  if (args_spec_list.empty()) {
    // Without this guard the element access below would read out of bounds.
    MS_LOG(EXCEPTION) << "Return evaluator requires 1 parameter, while the input size is 0.";
  }
  if (args_spec_list.size() != 1) {
    MS_LOG(INFO) << "Return evaluator requires 1 parameter, is this the default value attached? "
                    "while the input size is "
                 << args_spec_list.size() << ".";
  }
  return args_spec_list[0];
}

// Inference for `typeof(x)`: requires exactly one argument and returns an AbstractType wrapping the type
// built from that argument.
AbstractBasePtr InferImplTypeof(const AnalysisEnginePtr &, const PrimitivePtr &,
                                const AbstractBasePtrList &args_spec_list) {
  // Inputs: a pointer to an AbstractBase object
  if (args_spec_list.size() != 1) {
    MS_LOG(EXCEPTION) << "Typeof evaluator requires 1 parameter, while the input size is " << args_spec_list.size()
                      << ".";
  }
  AbstractBasePtr abs_base = args_spec_list[0];
  MS_EXCEPTION_IF_NULL(abs_base);
  TypePtr type = abs_base->BuildType();
  return std::make_shared<AbstractType>(type);
}

// Inference for `hastype(x, t)`: returns a boolean scalar holding IsSubtype(x, t).
// The second argument must be an AbstractType whose value track holds a Type.
AbstractBasePtr InferImplHasType(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                 const AbstractBasePtrList &args_spec_list) {
  // Inputs: a pointer to an AbstractBase object and a pointer to a Type
  const std::string op_name = primitive->name();
  CheckArgsSize(op_name, args_spec_list, 2);
  AbstractTypePtr abs_type = CheckArg<AbstractType>(op_name, args_spec_list, 1);

  auto mode_v = abs_type->GetValueTrack();
  MS_EXCEPTION_IF_NULL(mode_v);
  if (!mode_v->isa<Type>()) {
    MS_LOG(EXCEPTION) << "Get the type from AbstractType value failed.";
  }

  TypePtr mode_t = mode_v->cast<TypePtr>();
  MS_EXCEPTION_IF_NULL(args_spec_list[0]);
  bool v = IsSubtype(args_spec_list[0], mode_t);
  return std::make_shared<AbstractScalar>(std::make_shared<BoolImm>(v), kBool);
}

// Inference for `dot(x, y)`: both inputs must be 2-D tensors whose inner dimensions agree, unless either
// inner dimension is Shape::SHP_ANY. The result uses x's element abstract and the shape {x[0], y[1]}.
AbstractBasePtr InferImplDot(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                             const AbstractBasePtrList &args_spec_list) {
  // Inputs: two tensors.
  const std::string op_name = primitive->name();
  CheckArgsSize(op_name, args_spec_list, 2);
  AbstractTensorPtr input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
  AbstractTensorPtr input_y = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);

  ShapePtr x_shp = input_x->shape();
  MS_EXCEPTION_IF_NULL(x_shp);
  auto x_shp_value = x_shp->shape();
  ShapePtr y_shp = input_y->shape();
  MS_EXCEPTION_IF_NULL(y_shp);
  auto y_shp_value = y_shp->shape();
  // Should be matrix which shape size is 2.
  if (x_shp_value.size() != 2 || y_shp_value.size() != 2) {
    MS_LOG(EXCEPTION) << op_name << " evaluator requires input two 2D tensors, while the dimensions of two tensors are "
                      << x_shp_value.size() << ", " << y_shp_value.size() << " ";
  }
  if (x_shp_value[1] != y_shp_value[0] && x_shp_value[1] != Shape::SHP_ANY && y_shp_value[0] != Shape::SHP_ANY) {
    MS_LOG(EXCEPTION) << "Incompatible shapes in dot: {" << x_shp->ToString() << "} and {" << y_shp->ToString() << "}";
  }

  auto x_element = input_x->element();
  MS_EXCEPTION_IF_NULL(x_element);
  // The joined element is discarded on purpose; presumably Join reports incompatible element types
  // (TODO confirm against AbstractBase::Join).
  (void)x_element->Join(input_y->element());
  auto param = {x_shp_value[0], y_shp_value[1]};

  return std::make_shared<AbstractTensor>(input_x->element(), std::make_shared<Shape>(param));
}

// Inference for `switch(cond, tb, fb)`.
// - If the condition's value is not known or the condition is a tensor, both branches are kept by joining
//   them.
// - If the condition's value is a scalar, the true branch is returned when it IsOne(), else the false branch.
// - Any other condition value is reported as an exception.
AbstractBasePtr InferImplSwitch(const AnalysisEnginePtr &, const PrimitivePtr &,
                                const AbstractBasePtrList &args_spec_list) {
  // Inputs: condition, true branch, false branch
  if (args_spec_list.size() != 3) {
    MS_LOG(EXCEPTION) << "Switch evaluator requires 3 parameters, while the input size is " << args_spec_list.size()
                      << ".";
  }

  auto cond = args_spec_list[0];
  auto tb = args_spec_list[1];
  auto fb = args_spec_list[2];
  MS_EXCEPTION_IF_NULL(cond);

  ValuePtr v = cond->GetValueTrack();
  MS_EXCEPTION_IF_NULL(v);
  // for tensor as condition, keeps both true and false branch.
  if (v->isa<AnyValue>() || cond->isa<AbstractTensor>()) {
    MS_EXCEPTION_IF_NULL(tb);
    return tb->Join(fb);
  }

  if (v->isa<Scalar>()) {
    if (v->cast<ScalarPtr>()->IsOne()) {
      return tb;
    } else {
      return fb;
    }
  }

  MS_LOG(EXCEPTION) << "Invalid condition value for switch " << cond->ToString();
}

// Inference for `switch_layer(index, branches)`: `branches` must be a tuple holding between 1 and 1000
// function abstracts; the result is the join of all of them.
AbstractBasePtr InferImplSwitchLayer(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                     const AbstractBasePtrList &args_spec_list) {
  // Inputs: index, branch
  const std::string op_name = primitive->name();
  abstract::CheckArgsSize(op_name, args_spec_list, 2);
  (void)CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
  AbstractTuplePtr branches_abs = CheckArg<AbstractTuple>(op_name, args_spec_list, 1);
  const AbstractBasePtrList &branches = branches_abs->elements();
  const size_t maximum_layer_num = 1000;
  // An unsigned size can never be negative, so the lower bound has to be an emptiness check; it also
  // protects the branches[0] access below.
  if (branches.empty() || branches.size() > maximum_layer_num) {
    MS_EXCEPTION(ValueError) << op_name << " support at least 1 and at most " << maximum_layer_num << " but got "
                             << branches.size() << " branches.";
  }

  for (size_t i = 0; i < branches.size(); i++) {
    MS_EXCEPTION_IF_NULL(branches[i]);
    if (!branches[i]->isa<AbstractFunction>()) {
      MS_LOG(EXCEPTION) << op_name << " requires that the 2th arg be tuple of functions, but got "
                        << branches[i]->ToString() << " as the " << i << "th element.";
    }
  }

  auto b = branches[0];
  for (size_t i = 1; i < branches.size(); i++) {
    b = b->Join(branches[i]);
  }
  return b;
}

// Values accepted on the right-hand side of `is` / `is not`: None, False and True.
std::vector<ValuePtr> GetSupportedTargetValue() {
  std::vector<ValuePtr> list = {kNone, MakeValue(false), MakeValue(true)};
  return list;
}

// Returns true when `t` compares equal to one of the values from GetSupportedTargetValue().
bool SupportedIsTargetValue(const ValuePtr t) {
  MS_EXCEPTION_IF_NULL(t);
  auto list = GetSupportedTargetValue();
  auto match = std::any_of(list.begin(), list.end(), [&t](const ValuePtr &v) { return *v == *t; });
  return match;
}

// Inference for `x is t`: `t` must be None, False or True; the result is a scalar abstract holding whether
// the built values of `t` and `x` compare equal.
AbstractBasePtr InferImplIs_(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                             const AbstractBasePtrList &args_spec_list) {
  // statement: x is t
  // Inputs: x, t
  const std::string op_name = primitive->name();
  CheckArgsSize(op_name, args_spec_list, 2);
  MS_EXCEPTION_IF_NULL(args_spec_list[0]);
  MS_EXCEPTION_IF_NULL(args_spec_list[1]);
  ValuePtr t = args_spec_list[1]->BuildValue();
  MS_EXCEPTION_IF_NULL(t);
  if (!SupportedIsTargetValue(t)) {
    MS_LOG(EXCEPTION) << "Not supported type:" << t->ToString()
                      << " for statement is, supported list is:None, False, True ";
  }
  ValuePtr x = args_spec_list[0]->BuildValue();
  MS_EXCEPTION_IF_NULL(x);

  return std::make_shared<AbstractScalar>(*t == *x);
}

// Inference for `x is not t`: same validation as `is`, with the comparison result negated.
AbstractBasePtr InferImplIsNot(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                               const AbstractBasePtrList &args_spec_list) {
  // statement: x is not t
  // Inputs: x, t
  const std::string op_name = primitive->name();
  CheckArgsSize(op_name, args_spec_list, 2);
  MS_EXCEPTION_IF_NULL(args_spec_list[0]);
  MS_EXCEPTION_IF_NULL(args_spec_list[1]);
  ValuePtr t = args_spec_list[1]->BuildValue();
  MS_EXCEPTION_IF_NULL(t);
  if (!SupportedIsTargetValue(t)) {
    MS_LOG(EXCEPTION) << "Not supported type:" << t->ToString()
                      << " for statement is not, supported list is:None, False, True ";
  }
  ValuePtr x = args_spec_list[0]->BuildValue();
  MS_EXCEPTION_IF_NULL(x);

  return std::make_shared<AbstractScalar>(!(*t == *x));
}

// Shared implementation of `in` / `not in` for dictionaries: the first argument must build to a string
// value; returns whether the dictionary abstract (second argument) has an entry with that key.
bool IsInDict(const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list) {
  const std::string op_name = primitive->name();
  CheckArgsSize(op_name, args_spec_list, 2);
  auto key = CheckArg<AbstractScalar>(op_name, args_spec_list, 0);
  auto dict = CheckArg<AbstractDictionary>(op_name, args_spec_list, 1);

  ValuePtr key_value = key->BuildValue();
  MS_EXCEPTION_IF_NULL(key_value);
  if (!key_value->isa<StringImm>()) {
    MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << key_value->ToString();
  }
  const auto key_str = GetValue<std::string>(key_value);
  // Bind by const reference: the entries are only searched, never modified.
  const auto &dict_elems = dict->elements();
  auto it = std::find_if(dict_elems.begin(), dict_elems.end(),
                         [&key_str](const AbstractAttribute &item) { return item.first == key_str; });
  return it != dict_elems.end();
}

// Inference for `x in t` where `t` is a dictionary.
AbstractBasePtr InferImplInDict(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                const AbstractBasePtrList &args_spec_list) {
  // statement: x in t
  // Inputs: x, t
  return std::make_shared<AbstractScalar>(IsInDict(primitive, args_spec_list));
}

// Inference for `x not in t` where `t` is a dictionary.
AbstractBasePtr InferImplNotInDict(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                   const AbstractBasePtrList &args_spec_list) {
  // statement: x not in t
  // Inputs: x, t
  return std::make_shared<AbstractScalar>(!IsInDict(primitive, args_spec_list));
}

// Inference for `isconstant(x)`: requires exactly one argument and returns a scalar abstract that is true
// when the value built from `x` is not the unknown-value placeholder.
AbstractBasePtr InferImplIsConstant(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                    const AbstractBasePtrList &args_spec_list) {
  // statement: isconstant(x)
  // Inputs: x
  if (args_spec_list.size() != 1) {
    MS_LOG(EXCEPTION) << "IsConstant requires args input size = 1";
  }
  MS_EXCEPTION_IF_NULL(args_spec_list[0]);
  ValuePtr v = args_spec_list[0]->BuildValue();
  MS_EXCEPTION_IF_NULL(v);
  return std::make_shared<AbstractScalar>(!v->isa<AnyValue>());
}
}  // namespace abstract
}  // namespace mindspore