diff --git a/mindspore/ccsrc/frontend/operator/cc_implementations.cc b/mindspore/ccsrc/frontend/operator/cc_implementations.cc index 2df68037e5..3526a001f9 100644 --- a/mindspore/ccsrc/frontend/operator/cc_implementations.cc +++ b/mindspore/ccsrc/frontend/operator/cc_implementations.cc @@ -393,40 +393,5 @@ ValuePtr BoolEq(const ValuePtrList &list) { MS_LOG(EXCEPTION) << "Unsported Value for BoolEq, x: " << x->ToString() << "."; } - -std::vector BroadcastShape_(std::vector shpx, std::vector shpy) { - int dlen = SizeToInt(shpx.size()) - SizeToInt(shpy.size()); - if (dlen < 0) { - for (int i = 0; i < -dlen; ++i) { - (void)shpx.insert(shpx.begin(), 1); - } - } else if (dlen > 0) { - for (int i = 0; i < dlen; i++) { - (void)shpy.insert(shpy.begin(), 1); - } - } - if (shpx.size() != shpy.size()) { - MS_LOG(EXCEPTION) << "Failure: shpx.size() != shpy.size()."; - } - std::vector shp; - for (size_t i = 0; i < shpx.size(); i++) { - auto a = shpx[i]; - auto b = shpy[i]; - if (a == 1) { - shp.push_back(b); - } else if (b == 1) { - shp.push_back(a); - } else if (a == -1) { - shp.push_back(b); - } else if (b == -1) { - shp.push_back(a); - } else if (a == b) { - shp.push_back(a); - } else { - return std::vector(); - } - } - return shp; -} } // namespace prim } // namespace mindspore diff --git a/mindspore/ccsrc/frontend/operator/cc_implementations.h b/mindspore/ccsrc/frontend/operator/cc_implementations.h index ffe75cb0c0..beb129645b 100644 --- a/mindspore/ccsrc/frontend/operator/cc_implementations.h +++ b/mindspore/ccsrc/frontend/operator/cc_implementations.h @@ -52,7 +52,6 @@ ValuePtr BoolNot(const ValuePtrList &list); ValuePtr BoolAnd(const ValuePtrList &list); ValuePtr BoolOr(const ValuePtrList &list); ValuePtr BoolEq(const ValuePtrList &list); -std::vector BroadcastShape_(std::vector s1, std::vector s2); } // namespace prim } // namespace mindspore diff --git a/mindspore/ccsrc/frontend/operator/ops.h b/mindspore/ccsrc/frontend/operator/ops.h index 80093049ca..3b49bfbaf1 100755 --- a/mindspore/ccsrc/frontend/operator/ops.h +++ b/mindspore/ccsrc/frontend/operator/ops.h @@ -42,28 +42,13 @@ inline const PrimitivePtr kPrimRefToEmbed = std::make_shared("RefToEm inline const PrimitivePtr kPrimCreateInstance = std::make_shared("create_instance"); // Other miscellaneous -inline const PrimitivePtr kPrimEnvSetItem = std::make_shared("env_setitem"); -inline const PrimitivePtr kPrimEnvGetItem = std::make_shared("env_getitem"); -inline const PrimitivePtr kPrimEnvAdd = std::make_shared("env_add"); -inline const PrimitivePtr kPrimMakeRefKey = std::make_shared("MakeRefKey"); -inline const PrimitivePtr kPrimGetRefKey = std::make_shared("get_ref_key"); -inline const PrimitivePtr kPrimGetRefValue = std::make_shared("get_ref_value"); inline const PrimitivePtr kPrimGetRefOrigin = std::make_shared("get_ref_origin"); inline const PrimitivePtr kPrimInsertGradientOf = std::make_shared("InsertGradientOf"); inline const PrimitivePtr kPrimCheckBprop = std::make_shared("CheckBprop"); -inline const PrimitivePtr kPrimMakeRef = std::make_shared("make_ref"); inline const PrimitivePtr kPrimMixedPrecisionCast = std::make_shared("mixed_precision_cast"); inline const PrimitivePtr kPrimMakeRecord = std::make_shared("make_record"); // Structures -inline const PrimitivePtr kPrimMakeList = std::make_shared("make_list"); -inline const PrimitivePtr kPrimMakeKeywordArg = std::make_shared("make_keyword_arg"); -inline const PrimitivePtr kPrimListGetItem = std::make_shared("list_getitem"); -inline const PrimitivePtr kPrimListSetItem = std::make_shared("list_setitem"); -inline const PrimitivePtr kPrimDictGetItem = std::make_shared("dict_getitem"); -inline const PrimitivePtr kPrimDictSetItem = std::make_shared("dict_setitem"); -inline const PrimitivePtr kPrimListAppend = std::make_shared("list_append"); -inline const PrimitivePtr kPrimListLen = std::make_shared("list_len"); inline const PrimitivePtr kPrimListMap = std::make_shared("list_map"); inline const PrimitivePtr kPrimListReduce = std::make_shared("list_reduce"); diff --git a/mindspore/ccsrc/frontend/operator/prim_structures.cc b/mindspore/ccsrc/frontend/operator/ops_front_infer_function.cc similarity index 60% rename from mindspore/ccsrc/frontend/operator/prim_structures.cc rename to mindspore/ccsrc/frontend/operator/ops_front_infer_function.cc index cc53f9aa22..90b9a6a5a2 100644 --- a/mindspore/ccsrc/frontend/operator/prim_structures.cc +++ b/mindspore/ccsrc/frontend/operator/ops_front_infer_function.cc @@ -1,6 +1,4 @@ /** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * * Copyright 2019 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -15,360 +13,266 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#include "frontend/operator/ops_front_infer_function.h" + +#include +#include +#include +#include +#include +#include "abstract/abstract_value.h" #include "pipeline/jit/static_analysis/prim.h" -#include "abstract/utils.h" #include "abstract/param_validator.h" -#include "frontend/operator/ops.h" -#include "utils/convert_utils.h" #include "utils/tensor_py.h" - -using mindspore::tensor::TensorPy; - +#include "frontend/operator/ops.h" +#include "abstract/infer_functions.h" namespace mindspore { namespace abstract { +enum State { + SAME, + X_ONE, + Y_ONE, +}; -AbstractBasePtr InferImplStringEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: two scalars whose value is a string. - const std::string op_name = primitive->name(); - CheckArgsSize(op_name, args_spec_list, 2); - AbstractScalarPtr scalar_x = CheckArg(op_name, args_spec_list, 0); - AbstractScalarPtr scalar_y = CheckArg(op_name, args_spec_list, 1); - - ValuePtr value_x = scalar_x->BuildValue(); - ValuePtr value_y = scalar_y->BuildValue(); - if (!value_x->isa() || !value_y->isa()) { - MS_LOG(EXCEPTION) << op_name << " requires 2 parameters are string, but got param0: " << value_x->ToString() - << ", param1: " << value_y->ToString(); - } - - bool ret = (value_x->cast()->value() == value_y->cast()->value()); - return std::make_shared(ret); -} +struct SlideInfo { + int start; + int step; + int stop; +}; -AbstractBasePtr InferImplStringConcat(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: two scalars whose value is a string. - const std::string op_name = primitive->name(); +template +AbstractBasePtr InferImplTupleOrListEqual(const std::string &op_name, const AbstractBasePtrList &args_spec_list) { + // Inputs: two tuples or two lists. CheckArgsSize(op_name, args_spec_list, 2); - AbstractScalarPtr scalar_x = CheckArg(op_name, args_spec_list, 0); - AbstractScalarPtr scalar_y = CheckArg(op_name, args_spec_list, 1); - - ValuePtr value_x = scalar_x->BuildValue(); - ValuePtr value_y = scalar_y->BuildValue(); - if (!value_x->isa() || !value_y->isa()) { - MS_LOG(EXCEPTION) << op_name << " requires 2 parameters are string, but got param0: " << value_x->ToString() - << ", param1: " << value_y->ToString(); - } - - std::string ret = (value_x->cast()->value() + value_y->cast()->value()); - return std::make_shared(ret); -} - -AbstractBasePtr InferImplMakeTuple(const AnalysisEnginePtr &, const PrimitivePtr &, - const AbstractBasePtrList &args_spec_list) { - return std::make_shared(args_spec_list); -} + auto input_x = CheckArg(op_name, args_spec_list, 0); + auto input_y = CheckArg(op_name, args_spec_list, 1); -AbstractBasePtr InferImplMakeList(const AnalysisEnginePtr &, const PrimitivePtr &, - const AbstractBasePtrList &args_spec_list) { - return std::make_shared(args_spec_list); + ValuePtr x_value = input_x->BuildValue(); + ValuePtr y_value = input_y->BuildValue(); + return std::make_shared(*x_value == *y_value); } -AbstractBasePtr InferImplMakeDict(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: two tuples. - const std::string op_name = primitive->name(); - CheckArgsSize(op_name, args_spec_list, 2); - AbstractTuplePtr keys = CheckArg(op_name, args_spec_list, 0); - AbstractTuplePtr values = CheckArg(op_name, args_spec_list, 1); - - size_t keys_size = keys->size(); - if (values->size() != keys_size) { - MS_LOG(EXCEPTION) << op_name << " evaluator keys' size is not equal with values' size"; - } - - std::vector key_value; - AbstractScalarPtr key; - AbstractBasePtrList key_list = keys->elements(); - AbstractBasePtrList value_list = values->elements(); - for (size_t index = 0; index < keys_size; index++) { - key = CheckArg(op_name + "key", key_list, index); - ValuePtr keyPtr = key->BuildValue(); - MS_EXCEPTION_IF_NULL(keyPtr); - if (!keyPtr->isa()) { - MS_LOG(EXCEPTION) << op_name << " evaluator keys should be string, but got " << keyPtr->ToString(); +void CalcSlidePara(const AbstractBasePtrList &args_spec_list, SlideInfo *slide) { + int arg1 = 0; + int arg2 = 0; + if (!args_spec_list.empty()) { + MS_EXCEPTION_IF_NULL(args_spec_list[0]); + auto arg_value = args_spec_list[0]->BuildValue(); + if (!arg_value->isa()) { + MS_LOG(EXCEPTION) << "Only supported input an int32 number."; } - std::string key_string = GetValue(keyPtr); - key_value.emplace_back(key_string, value_list[index]); + arg1 = GetValue(arg_value); } - return std::make_shared(key_value); -} -AbstractBasePtr InferImplMakeKwarg(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: a string and an object of a subclass of AbstractBase. - const std::string op_name = primitive->name(); - CheckArgsSize(op_name, args_spec_list, 2); - AbstractScalarPtr key = CheckArg(op_name, args_spec_list, 0); - - ValuePtr keyPtr = key->BuildValue(); - if (!keyPtr->isa()) { - MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << keyPtr->ToString(); + if (args_spec_list.size() >= 2) { + MS_EXCEPTION_IF_NULL(args_spec_list[1]); + auto arg_value = args_spec_list[1]->BuildValue(); + if (!arg_value->isa()) { + MS_LOG(EXCEPTION) << "Only supported input an int32 number."; + } + arg2 = GetValue(arg_value); } - std::string key_string = GetValue(keyPtr); - return std::make_shared(key_string, args_spec_list[1]); -} -AbstractBasePtr InferImplExtractKwarg(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: a string and a keyword. - const std::string op_name = primitive->name(); - CheckArgsSize(op_name, args_spec_list, 2); - AbstractScalarPtr key = CheckArg(op_name, args_spec_list, 0); - AbstractKeywordArgPtr kwarg = CheckArg(op_name, args_spec_list, 1); + if (args_spec_list.size() == 3) { + MS_EXCEPTION_IF_NULL(args_spec_list[2]); + auto arg_value = args_spec_list[2]->BuildValue(); + if (!arg_value->isa()) { + MS_LOG(EXCEPTION) << "Only supported input an int32 number."; + } + slide->step = GetValue(arg_value); + slide->start = arg1; + slide->stop = arg2; + } - ValuePtr key_value = key->BuildValue(); - if (!key_value->isa()) { - MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << key_value->ToString(); + if (args_spec_list.size() == 2) { + slide->start = arg1; + slide->stop = arg2; } - std::string key_input = GetValue(key_value); - std::string key_actual = kwarg->get_key(); - if (key_actual != key_input) { - MS_LOG(EXCEPTION) << op_name << " evaluator input key should be same as AbstractKeywordArg' key, but input is " - << key_input << ", AbstractKeywordArg' key is " << key_actual; + + if (args_spec_list.size() == 1) { + slide->stop = arg1; } - return kwarg->get_arg(); } -AbstractBasePtr InferImplMakeSlice(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: three scalars whose value is an int32 number. - CheckArgsSize(primitive->name(), args_spec_list, 3); - size_t args_size = args_spec_list.size(); - for (size_t index = 0; index < args_size; index++) { - MS_EXCEPTION_IF_NULL(args_spec_list[index]); - if (!args_spec_list[index]->isa() && !args_spec_list[index]->isa()) { - MS_EXCEPTION(TypeError) << "MakeSlice eval " << index << " parameter is neither AbstractScalar nor AbstractNone."; +void ComputeReduceIndex(const std::vector &reverse_x, const std::vector &reverse_y, + std::vector *grad_x_reduce_idx, std::vector *grad_y_reduce_idy) { + const size_t n = reverse_x.size(); + for (size_t i = 0; i < n; ++i) { + State curr; + const int32_t x_i = reverse_x[i]; + const int32_t y_i = reverse_y[i]; + const int reduce_idx = SizeToInt(n - 1 - i); + if (x_i == y_i) { + curr = SAME; + } else if (x_i == 1) { + grad_x_reduce_idx->push_back(reduce_idx); + curr = X_ONE; + } else if (y_i == 1) { + grad_y_reduce_idy->push_back(reduce_idx); + curr = Y_ONE; + } else { + MS_LOG(EXCEPTION) << "not compatible shape input for BroadcastGradientArgs"; } - if (args_spec_list[index]->isa() && - !dyn_cast(args_spec_list[index])->BuildValue()->isa()) { - MS_EXCEPTION(TypeError) << "MakeSlice eval " << index - << " parameter is an AbstractScalar, but is not an int32 number."; + if (curr == SAME && x_i == 1) { + grad_x_reduce_idx->push_back(reduce_idx); + grad_y_reduce_idy->push_back(reduce_idx); + continue; } } - // Slice: start, end, step - return std::make_shared(args_spec_list[0], args_spec_list[1], args_spec_list[2]); + + std::reverse(grad_x_reduce_idx->begin(), grad_x_reduce_idx->end()); + std::reverse(grad_y_reduce_idy->begin(), grad_y_reduce_idy->end()); } -// Eval the return type of make_record -AbstractBasePtr InferImplMakeRecord(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: at lease two objects of a subclass of AbstractBase. - if (args_spec_list.size() < 2) { - MS_LOG(EXCEPTION) << "Typeof evaluator requires more than 1 parameter, while the input size is " - << args_spec_list.size() << "."; - } +AbstractBasePtr BroadcastGradientArgsDiff(const std::vector &x_shape, const std::vector &y_shape) { + std::vector reverse_x; + std::vector reverse_y; - // args_spec_list[0] maybe AbstractScalarPtr or AbstractTypePtr - MS_EXCEPTION_IF_NULL(args_spec_list[0]); - TypePtr type = args_spec_list[0]->GetTypeTrack(); - MS_EXCEPTION_IF_NULL(type); - if (type->type_id() != kMetaTypeTypeType) { - MS_LOG(EXCEPTION) << "Can not make type(" << type->ToString() << ")not TypeType"; - } + (void)std::transform(x_shape.rbegin(), x_shape.rend(), std::back_inserter(reverse_x), + [](const ValuePtr &v) { return v->cast()->value(); }); + (void)std::transform(y_shape.rbegin(), y_shape.rend(), std::back_inserter(reverse_y), + [](const ValuePtr &v) { return v->cast()->value(); }); - ValuePtr value_track = args_spec_list[0]->GetValueTrack(); - MS_EXCEPTION_IF_NULL(value_track); - TypePtr type_ptr = value_track->cast(); - if (type_ptr == nullptr) { - MS_LOG(EXCEPTION) << "Value type error, not Me type:" << value_track->ToString(); + if (reverse_x.size() > reverse_y.size()) { + reverse_y.resize(reverse_x.size(), 1); + } else { + reverse_x.resize(reverse_y.size(), 1); } - auto cls = dyn_cast(type_ptr); - MS_EXCEPTION_IF_NULL(cls); - ClassAttrVector attributes = cls->GetAttributes(); - CheckArgsSize(primitive->name(), args_spec_list, attributes.size() + 1); + std::vector grad_x_reduce_idx; + std::vector grad_y_reduce_idy; + ComputeReduceIndex(reverse_x, reverse_y, &grad_x_reduce_idx, &grad_y_reduce_idy); - std::vector abs_attributes; - for (size_t i = 0; i < attributes.size(); i++) { - AbstractAttribute elem(attributes[i].first, args_spec_list[i + 1]); - abs_attributes.push_back(elem); - } + AbstractBasePtrList abs_list_x; + AbstractBasePtrList abs_list_y; + (void)std::transform(grad_x_reduce_idx.begin(), grad_x_reduce_idx.end(), std::back_inserter(abs_list_x), + [](int v) { return abstract::FromValue(v); }); + (void)std::transform(grad_y_reduce_idy.begin(), grad_y_reduce_idy.end(), std::back_inserter(abs_list_y), + [](int v) { return abstract::FromValue(v); }); + auto x_reduce_idx = std::make_shared(abs_list_x); + auto y_reduce_idx = std::make_shared(abs_list_y); + AbstractBasePtrList elem_list; + elem_list.push_back(x_reduce_idx); + elem_list.push_back(y_reduce_idx); - return std::make_shared(cls->tag(), abs_attributes, cls->methods()); + return std::make_shared(elem_list); } -template -AbstractBasePtr InferTupleOrListGetItem(const std::string &op_name, const AbstractBasePtrList &args_spec_list) { - // Inputs: a tuple or list and a scalar whose value is an int32 number. - CheckArgsSize(op_name, args_spec_list, 2); - auto queue = CheckArg(op_name, args_spec_list, 0); - AbstractScalarPtr index = CheckArg(op_name, args_spec_list, 1); - - ValuePtr index_value = index->BuildValue(); - if (!index_value->isa()) { - // when index_value is an AnyValue and args_spec_list[0] is a scalar, try to return the type of the first element - // and continue - if (dyn_cast(queue->elements()[0]) != nullptr) { - return std::make_shared(queue->elements()[0]->BuildType()); - } - MS_EXCEPTION(IndexError) << op_name << " evaluator index should be an int32 number, but got " - << index_value->ToString(); - } - int idx_v = GetValue(index_value); - std::size_t nelems = queue->elements().size(); - if (idx_v >= SizeToInt(nelems) || idx_v < -SizeToInt(nelems)) { - MS_EXCEPTION(IndexError) << op_name << " evaluator index should be in range[-" << SizeToInt(nelems) << ", " - << SizeToInt(nelems) << "), but got " << idx_v << "."; - } - - std::size_t uidx_v = 0; - if (idx_v >= 0) { - uidx_v = IntToSize(idx_v); - } else { - uidx_v = IntToSize(idx_v + SizeToInt(nelems)); +AbstractBasePtr InferImplTypeof(const AnalysisEnginePtr &, const PrimitivePtr &, + const AbstractBasePtrList &args_spec_list) { + // Inputs: a pointer to an AbstractBase object + if (args_spec_list.size() != 1) { + MS_LOG(EXCEPTION) << "Typeof evaluator requires 1 parameter, while the input size is " << args_spec_list.size() + << "."; } - return queue->elements()[uidx_v]; + AbstractBasePtr abs_base = args_spec_list[0]; + MS_EXCEPTION_IF_NULL(abs_base); + TypePtr type = abs_base->BuildType(); + return std::make_shared(type); } -template -AbstractBasePtr InferTupleOrListSetItem(const std::string &op_name, const AbstractBasePtrList &args_spec_list) { - // Inputs: a tuple or list, a scalar whose value is an int32 number and an object of a subclass of AbstractBase. - CheckArgsSize(op_name, args_spec_list, 3); - auto queue = CheckArg(op_name, args_spec_list, 0); - AbstractScalarPtr index = CheckArg(op_name, args_spec_list, 1); - - ValuePtr index_value = index->BuildValue(); - if (!index_value->isa()) { - MS_EXCEPTION(IndexError) << op_name << " evaluator index should be an int32 number, but got " - << index_value->ToString(); - } - int idx_v = GetValue(index_value); - if (idx_v < 0) { - MS_EXCEPTION(IndexError) << "The index of " << typeid(T).name() << " should be positive number, but got " << idx_v - << "."; - } +AbstractBasePtr InferImplHasType(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: a pointer to an AbstractBase object and a pointer to a Type + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 2); + AbstractTypePtr abs_type = CheckArg(op_name, args_spec_list, 1); - size_t uidx_v = IntToSize(idx_v); - AbstractBasePtrList elements = queue->elements(); - std::size_t nelems = elements.size(); - if (uidx_v >= nelems) { - MS_EXCEPTION(IndexError) << op_name << " evaluator the index: " << uidx_v << " to set out of range: " << nelems - 1 - << "."; + auto mode_v = abs_type->GetValueTrack(); + MS_EXCEPTION_IF_NULL(mode_v); + if (!mode_v->isa()) { + MS_LOG(EXCEPTION) << "Get the type from AbstractType value failed."; } - elements[uidx_v] = args_spec_list[2]; - return std::make_shared(elements); -} -AbstractBasePtr InferImplTupleGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - return InferTupleOrListGetItem(primitive->name(), args_spec_list); + TypePtr mode_t = mode_v->cast(); + MS_EXCEPTION_IF_NULL(args_spec_list[0]); + bool v = IsSubtype(args_spec_list[0], mode_t); + return std::make_shared(std::make_shared(v), kBool); } -AbstractBasePtr InferImplListGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - return InferTupleOrListGetItem(primitive->name(), args_spec_list); -} +bool CompareShape(const std::vector &x_shape, const std::vector &y_shape) { + if (x_shape.size() != y_shape.size()) { + return false; + } -AbstractBasePtr InferImplTupleSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - return InferTupleOrListSetItem(primitive->name(), args_spec_list); -} + for (size_t i = 0; i < x_shape.size(); ++i) { + if (GetValue(x_shape[i]) != GetValue(y_shape[i])) { + return false; + } + } -AbstractBasePtr InferImplListSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - return InferTupleOrListSetItem(primitive->name(), args_spec_list); + return true; } -AbstractBasePtr InferImplDictGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: a dict and a scalar whose value is a string. - const std::string op_name = primitive->name(); - CheckArgsSize(op_name, args_spec_list, 2); - AbstractDictionaryPtr dict = CheckArg(op_name, args_spec_list, 0); - AbstractScalarPtr key = CheckArg(op_name, args_spec_list, 1); - - ValuePtr key_value = key->BuildValue(); - if (!key_value->isa()) { - MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << key_value->ToString(); +AbstractBasePtr DoInferReduceShape(const AbstractTuplePtr &x_shape, const ValuePtr &x_shp_value, + const ValueTuplePtr &axis_value_ptr, const PrimitivePtr &primitive) { + size_t x_rank = x_shape->size(); + std::set axis_set; + auto axis_data = axis_value_ptr->value(); + if (axis_data.empty()) { + int size = 1; + AbstractBasePtrList values(x_rank, std::make_shared(size)); + return std::make_shared(values); } - auto key_str = GetValue(key_value); - std::vector dict_elems = dict->elements(); - auto it = std::find_if(dict_elems.begin(), dict_elems.end(), - [key_str](const AbstractAttribute &item) { return item.first == key_str; }); - if (it == dict_elems.end()) { - MS_LOG(EXCEPTION) << "The key " << key_str << " does not exist in the dict:" << args_spec_list[0]->ToString(); + for (auto &elem : axis_data) { + int e_value = CheckAxis(primitive->name(), elem, -SizeToInt(x_rank), SizeToInt(x_rank) - 1); + (void)axis_set.insert(e_value); } - return it->second; -} -AbstractBasePtr InferImplDictSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: a dict and a scalar whose value is a string and an object of a subclass of AbstractBase. - const std::string op_name = primitive->name(); - CheckArgsSize(op_name, args_spec_list, 3); - AbstractDictionaryPtr dict = CheckArg(op_name, args_spec_list, 0); - AbstractScalarPtr key = CheckArg(op_name, args_spec_list, 1); - - ValuePtr key_value = key->BuildValue(); - if (!key_value->isa()) { - MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << key_value->ToString(); - } - std::string key_str = GetValue(key_value); - std::vector dict_elems = dict->elements(); - auto it = std::find_if(dict_elems.begin(), dict_elems.end(), - [key_str](AbstractAttribute &item) { return item.first == key_str; }); - - MS_EXCEPTION_IF_NULL(args_spec_list[2]); - auto new_ele = std::make_pair(key_str, args_spec_list[2]); - if (it != dict_elems.end()) { - int index = it - dict_elems.begin(); - dict_elems[IntToSize(index)] = new_ele; - } else { - dict_elems.push_back(new_ele); + auto x_shp_data = x_shp_value->cast()->value(); + if (x_shp_data.size() < x_rank) { + MS_LOG(EXCEPTION) << "x_shape_data.size() " << x_shp_data.size() << " less than x_shape.size() " << x_rank; + } + AbstractBasePtrList values; + for (size_t i = 0; i < x_rank; i++) { + if (axis_set.count(SizeToInt(i)) || axis_set.count(SizeToInt(i) - SizeToInt(x_rank))) { + auto axis_v = MakeValue(1); + values.push_back(std::make_shared(axis_v, axis_v->type())); + } else { + int dim_value = x_shp_data[i]->cast()->value(); + auto dim = MakeValue(dim_value); + values.push_back(std::make_shared(dim, dim->type())); + } } - return std::make_shared(dict_elems); + + return std::make_shared(values); } -AbstractBasePtr InferImplListAppend(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: a list and an object of a subclass of AbstractBase. +AbstractBasePtr InferImplBroadcastGradientArgs(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // this primitive get the index that need to reduce + // input: x's shape and y's shape, inputs should be tuple + // output: tuple of x and y 's reduce index, reduce index should be a tuple const std::string op_name = primitive->name(); CheckArgsSize(op_name, args_spec_list, 2); - AbstractListPtr list = CheckArg(op_name, args_spec_list, 0); - (void)AbstractJoin(list->elements()); - return list; -} + auto arg_x = CheckArg(op_name, args_spec_list, 0); + auto arg_y = CheckArg(op_name, args_spec_list, 1); -template -AbstractBasePtr InferTupleOrListOrDictLen(const std::string &op_name, const AbstractBasePtrList &args_spec_list) { - // Inputs: a tuple or list or dict. - CheckArgsSize(op_name, args_spec_list, 1); - auto arg = CheckArg(op_name, args_spec_list, 0); - return std::make_shared(SizeToInt(arg->size())); -} + ValueTuplePtr arg_x_value = arg_x->BuildValue()->cast(); + MS_EXCEPTION_IF_NULL(arg_x_value); -AbstractBasePtr InferImplTupleLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - return InferTupleOrListOrDictLen(primitive->name(), args_spec_list); -} + ValueTuplePtr arg_y_value = arg_y->BuildValue()->cast(); + MS_EXCEPTION_IF_NULL(arg_y_value); -AbstractBasePtr InferImplListLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - return InferTupleOrListOrDictLen(primitive->name(), args_spec_list); -} + const std::vector x_shape = arg_x_value->value(); + const std::vector y_shape = arg_y_value->value(); + bool is_same_shape = CompareShape(x_shape, y_shape); + // if it is the same shape , do not need reduce , return empty tuple + if (is_same_shape) { + AbstractBasePtrList empty_list; + auto x_reduce_idx = std::make_shared(empty_list); + auto y_reduce_idx = std::make_shared(empty_list); -AbstractBasePtr InferImplDictLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - return InferTupleOrListOrDictLen(primitive->name(), args_spec_list); -} + AbstractBasePtrList elem_list; + elem_list.push_back(x_reduce_idx); + elem_list.push_back(y_reduce_idx); -AbstractBasePtr InferImplArrayLen(const AnalysisEnginePtr &, const PrimitivePtr &, - const AbstractBasePtrList &args_spec_list) { - return std::make_shared(kAnyValue, kInt32); + return std::make_shared(elem_list); + } + + return BroadcastGradientArgsDiff(x_shape, y_shape); } AbstractBasePtr InferImplListMap(const AnalysisEnginePtr &engine, const PrimitivePtr &primitive, @@ -430,41 +334,6 @@ AbstractBasePtr InferImplTupleReversed(const AnalysisEnginePtr &, const Primitiv return std::make_shared(elem_list); } -AbstractBasePtr DoInferReduceShape(const AbstractTuplePtr &x_shape, const ValuePtr &x_shp_value, - const ValueTuplePtr &axis_value_ptr, const PrimitivePtr &primitive) { - size_t x_rank = x_shape->size(); - std::set axis_set; - auto axis_data = axis_value_ptr->value(); - if (axis_data.empty()) { - int size = 1; - AbstractBasePtrList values(x_rank, std::make_shared(size)); - return std::make_shared(values); - } - - for (auto &elem : axis_data) { - int e_value = CheckAxis(primitive->name(), elem, -SizeToInt(x_rank), SizeToInt(x_rank) - 1); - (void)axis_set.insert(e_value); - } - - auto x_shp_data = x_shp_value->cast()->value(); - if (x_shp_data.size() < x_rank) { - MS_LOG(EXCEPTION) << "x_shape_data.size() " << x_shp_data.size() << " less than x_shape.size() " << x_rank; - } - AbstractBasePtrList values; - for (size_t i = 0; i < x_rank; i++) { - if (axis_set.count(SizeToInt(i)) || axis_set.count(SizeToInt(i) - SizeToInt(x_rank))) { - auto axis_v = MakeValue(1); - values.push_back(std::make_shared(axis_v, axis_v->type())); - } else { - int dim_value = x_shp_data[i]->cast()->value(); - auto dim = MakeValue(dim_value); - values.push_back(std::make_shared(dim, dim->type())); - } - } - - return std::make_shared(values); -} - AbstractBasePtr InferImplReduceShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list) { // Inputs: x_shape, axis @@ -563,7 +432,7 @@ AbstractBasePtr InferImplTuple2Array(const AnalysisEnginePtr &, const PrimitiveP py::tuple data_tuple = ValuePtrToPyData(input->BuildValue()); py::array data = py::array(data_tuple); - auto tensor = TensorPy::MakeTensor(data); + auto tensor = tensor::TensorPy::MakeTensor(data); auto ret = tensor->ToAbstract(); ret->set_value(tensor); MS_LOG(DEBUG) << "Tuple2arry result AbstractTensor: " << ret->ToString(); @@ -596,76 +465,6 @@ AbstractBasePtr InferImplShapeMul(const AnalysisEnginePtr &, const PrimitivePtr return std::make_shared(result_v, result_v->type()); } -template -AbstractBasePtr InferImplTupleOrListEqual(const std::string &op_name, const AbstractBasePtrList &args_spec_list) { - // Inputs: two tuples or two lists. - CheckArgsSize(op_name, args_spec_list, 2); - auto input_x = CheckArg(op_name, args_spec_list, 0); - auto input_y = CheckArg(op_name, args_spec_list, 1); - - ValuePtr x_value = input_x->BuildValue(); - ValuePtr y_value = input_y->BuildValue(); - return std::make_shared(*x_value == *y_value); -} - -AbstractBasePtr InferImplTupleEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - return InferImplTupleOrListEqual(primitive->name(), args_spec_list); -} - -AbstractBasePtr InferImplListEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - return InferImplTupleOrListEqual(primitive->name(), args_spec_list); -} - -struct SlideInfo { - int start; - int step; - int stop; -}; - -void CalcSlidePara(const AbstractBasePtrList &args_spec_list, SlideInfo *slide) { - int arg1 = 0; - int arg2 = 0; - if (!args_spec_list.empty()) { - MS_EXCEPTION_IF_NULL(args_spec_list[0]); - auto arg_value = args_spec_list[0]->BuildValue(); - if (!arg_value->isa()) { - MS_LOG(EXCEPTION) << "Only supported input an int32 number."; - } - arg1 = GetValue(arg_value); - } - - if (args_spec_list.size() >= 2) { - MS_EXCEPTION_IF_NULL(args_spec_list[1]); - auto arg_value = args_spec_list[1]->BuildValue(); - if (!arg_value->isa()) { - MS_LOG(EXCEPTION) << "Only supported input an int32 number."; - } - arg2 = GetValue(arg_value); - } - - if (args_spec_list.size() == 3) { - MS_EXCEPTION_IF_NULL(args_spec_list[2]); - auto arg_value = args_spec_list[2]->BuildValue(); - if (!arg_value->isa()) { - MS_LOG(EXCEPTION) << "Only supported input an int32 number."; - } - slide->step = GetValue(arg_value); - slide->start = arg1; - slide->stop = arg2; - } - - if (args_spec_list.size() == 2) { - slide->start = arg1; - slide->stop = arg2; - } - - if (args_spec_list.size() == 1) { - slide->stop = arg1; - } -} - AbstractBasePtr InferImplMakeRange(const AnalysisEnginePtr &, const PrimitivePtr &, const AbstractBasePtrList &args_spec_list) { if (args_spec_list.empty()) { @@ -709,5 +508,145 @@ AbstractBasePtr InferImplStopGradient(const AnalysisEnginePtr &, const Primitive CheckArgsSize(primitive->name(), args_spec_list, 1); return args_spec_list[0]->Clone(); } + +AbstractBasePtr InferImplTupleEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + return InferImplTupleOrListEqual(primitive->name(), args_spec_list); +} + +AbstractBasePtr InferImplListEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + return InferImplTupleOrListEqual(primitive->name(), args_spec_list); +} + +AbstractBasePtr InferImplStringEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: two scalars whose value is a string. + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 2); + AbstractScalarPtr scalar_x = CheckArg(op_name, args_spec_list, 0); + AbstractScalarPtr scalar_y = CheckArg(op_name, args_spec_list, 1); + + ValuePtr value_x = scalar_x->BuildValue(); + ValuePtr value_y = scalar_y->BuildValue(); + if (!value_x->isa() || !value_y->isa()) { + MS_LOG(EXCEPTION) << op_name << " requires 2 parameters are string, but got param0: " << value_x->ToString() + << ", param1: " << value_y->ToString(); + } + + bool ret = (value_x->cast()->value() == value_y->cast()->value()); + return std::make_shared(ret); +} + +AbstractBasePtr InferImplStringConcat(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: two scalars whose value is a string. + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 2); + AbstractScalarPtr scalar_x = CheckArg(op_name, args_spec_list, 0); + AbstractScalarPtr scalar_y = CheckArg(op_name, args_spec_list, 1); + + ValuePtr value_x = scalar_x->BuildValue(); + ValuePtr value_y = scalar_y->BuildValue(); + if (!value_x->isa() || !value_y->isa()) { + MS_LOG(EXCEPTION) << op_name << " requires 2 parameters are string, but got param0: " << value_x->ToString() + << ", param1: " << value_y->ToString(); + } + + std::string ret = (value_x->cast()->value() + value_y->cast()->value()); + return std::make_shared(ret); +} + +AbstractBasePtr InferImplDictLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + return InferTupleOrListOrDictLen(primitive->name(), args_spec_list); +} + +AbstractBasePtr InferImplJ(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // args: An object of AbstractFunction. + CheckArgsSize(primitive->name(), args_spec_list, 1); + MS_LOG(DEBUG) << "evaluate J: " << args_spec_list[0]->ToString(); + + AbstractFunctionPtr x = dyn_cast(args_spec_list[0]); + if (x == nullptr) { + return std::make_shared(args_spec_list[0]); + } + + AbstractFuncAtomPtrList jv; + auto build_jv = [&jv](const AbstractFuncAtomPtr &func) { + auto j_closure = std::make_shared(func); + jv.push_back(j_closure); + }; + x->Visit(build_jv); + + return AbstractFunction::MakeAbstractFunction(jv); +} + +AbstractBasePtr InferImplFakeBprop(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: a tensor. + CheckArgsSize(primitive->name(), args_spec_list, 1); + return args_spec_list[0]->Broaden(); +} + +// Eval the return type of make_record +AbstractBasePtr InferImplMakeRecord(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: at lease two objects of a subclass of AbstractBase. + if (args_spec_list.size() < 2) { + MS_LOG(EXCEPTION) << "Typeof evaluator requires more than 1 parameter, while the input size is " + << args_spec_list.size() << "."; + } + + // args_spec_list[0] maybe AbstractScalarPtr or AbstractTypePtr + MS_EXCEPTION_IF_NULL(args_spec_list[0]); + TypePtr type = args_spec_list[0]->GetTypeTrack(); + MS_EXCEPTION_IF_NULL(type); + if (type->type_id() != kMetaTypeTypeType) { + MS_LOG(EXCEPTION) << "Can not make type(" << type->ToString() << ")not TypeType"; + } + + ValuePtr value_track = args_spec_list[0]->GetValueTrack(); + MS_EXCEPTION_IF_NULL(value_track); + TypePtr type_ptr = value_track->cast(); + if (type_ptr == nullptr) { + MS_LOG(EXCEPTION) << "Value type error, not Me type:" << value_track->ToString(); + } + + auto cls = dyn_cast(type_ptr); + MS_EXCEPTION_IF_NULL(cls); + ClassAttrVector attributes = cls->GetAttributes(); + CheckArgsSize(primitive->name(), args_spec_list, attributes.size() + 1); + + std::vector abs_attributes; + for (size_t i = 0; i < attributes.size(); i++) { + AbstractAttribute elem(attributes[i].first, args_spec_list[i + 1]); + abs_attributes.push_back(elem); + } + + return std::make_shared(cls->tag(), abs_attributes, cls->methods()); +} +REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(TypeOf, prim::kPrimTypeOf, InferImplTypeof); +REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(HasType, prim::kPrimHasType, InferImplHasType); +REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(MakeRecord, prim::kPrimMakeRecord, InferImplMakeRecord); +REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(ListMap, prim::kPrimListMap, InferImplListMap); +REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(ListReduce, prim::kPrimListReduce, InferImplListReduce); +REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(TupleReversed, prim::kPrimTupleReversed, InferImplTupleReversed); +REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(ReducedShape, prim::kPrimReducedShape, InferImplReduceShape); +REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(TupleDiv, prim::kPrimTupleDiv, InferImplTupleDiv); +REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(TupleToArray, prim::kPrimTupleToArray, InferImplTuple2Array); +REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(ShapeMul, prim::kPrimShapeMul, InferImplShapeMul); +REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(TupleEqual, prim::kPrimTupleEqual, InferImplTupleEqual); +REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(ListEqual, prim::kPrimListEqual, InferImplListEqual); +REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(MakeRange, prim::kPrimMakeRange, InferImplMakeRange); +REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(StopGradient, prim::kPrimStopGradient, InferImplStopGradient); +REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(StringEqual, prim::kPrimStringEqual, InferImplStringEqual); +REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(StringConcat, prim::kPrimStringConcat, InferImplStringConcat); +REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(DictLen, prim::kPrimDictLen, InferImplDictLen); +REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(FakeBprop, prim::kPrimFakeBprop, InferImplFakeBprop); +REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(J, prim::kPrimJ, InferImplJ); +REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(BroadcastGradientArgs, prim::kPrimBroadcastGradientArgs, + InferImplBroadcastGradientArgs); } // namespace abstract } // namespace mindspore diff --git a/mindspore/ccsrc/frontend/operator/ops_front_infer_function.h b/mindspore/ccsrc/frontend/operator/ops_front_infer_function.h new file mode 100644 index 0000000000..6ddacd199d --- /dev/null +++ b/mindspore/ccsrc/frontend/operator/ops_front_infer_function.h @@ -0,0 +1,77 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_FRONTEND_OPERATE_OPS_FRONT_INFER_FUNCTION_H_ +#define MINDSPORE_CCSRC_FRONTEND_OPERATE_OPS_FRONT_INFER_FUNCTION_H_ +#include "abstract/abstract_value.h" +#include "abstract/primitive_infer_map.h" +namespace mindspore { +namespace abstract { +AbstractBasePtr InferImplTypeof(const AnalysisEnginePtr &, const PrimitivePtr &, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplHasType(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplBroadcastGradientArgs(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplListMap(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplListReduce(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplTupleReversed(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplReduceShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplTupleDiv(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplTuple2Array(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplShapeMul(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplTupleEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplListEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplMakeRange(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplStopGradient(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplStringEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplStringConcat(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplDictLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplJ(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplFakeBprop(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplMakeRecord(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); + +class RegisterFrontendPrimitiveEvalHelper { + public: + RegisterFrontendPrimitiveEvalHelper(const PrimitivePtr &primitive, const StandardPrimitiveEvalImpl &impl) { + const StandardPrimitiveImplReg impl_reg{impl, false}; + RegisterStandardPrimitiveImpl(primitive, impl_reg); + } + ~RegisterFrontendPrimitiveEvalHelper() = default; +}; + +#define REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(name, primitive, impl) \ + static auto helper_##name = RegisterFrontendPrimitiveEvalHelper(primitive, impl) +} // namespace abstract +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_FRONTEND_OPERATE_OPS_FRONT_INFER_FUNCTION_H_ diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc index e35d5e7614..b25fcdd38b 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc @@ -36,115 +36,12 @@ #include "utils/convert_utils.h" #include "utils/ms_context.h" #include "pipeline/jit/parse/data_converter.h" +#include "abstract/primitive_infer_map.h" #include "abstract/param_validator.h" #include "utils/ms_utils.h" namespace mindspore { namespace abstract { -PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { - static PrimitiveEvalImplMap prim_eval_implement_map = { - // Statements - {prim::kPrimReturn, {InferImplReturn, true}}, - {prim::kPrimTypeOf, {InferImplTypeof, false}}, - {prim::kPrimHasType, {InferImplHasType, false}}, - {prim::kPrimDot, {InferImplDot, true}}, - {prim::kPrimSwitch, {InferImplSwitch, true}}, - {prim::kPrimSwitchLayer, {InferImplSwitchLayer, true}}, - {prim::kPrimIs_, {InferImplIs_, true}}, - {prim::kPrimIsNot, {InferImplIsNot, true}}, - {prim::kPrimInDict, {InferImplInDict, true}}, - {prim::kPrimNotInDict, {InferImplNotInDict, true}}, - {prim::kPrimIsConsant, {InferImplIsConstant, true}}, - // Maths - {prim::kPrimMaximumGrad, {InferImplMinOrMaxGrad, true}}, - {prim::kPrimMinimumGrad, {InferImplMinOrMaxGrad, true}}, - // Array - {prim::kPrimScalarToArray, {InferImplScalarToArray, true}}, - {prim::kPrimArrayToScalar, {InferImplArrayToScalar, true}}, - {prim::kPrimBroadcastShape, {InferImplBroadCastShape, true}}, - {prim::kPrimPack, {InferImplPack, true}}, - {prim::kPrimUnique, {InferImplUnique, true}}, - {prim::kPrimUniqueGrad, {InferImplUniqueGrad, true}}, - // Structure - {prim::kPrimMakeTuple, {InferImplMakeTuple, true}}, - {prim::kPrimMakeList, {InferImplMakeList, true}}, - {prim::kPrimMakeDict, {InferImplMakeDict, true}}, - {prim::kPrimMakeSlice, {InferImplMakeSlice, true}}, - {prim::kPrimMakeKeywordArg, {InferImplMakeKwarg, true}}, - {prim::kPrimExtractKeywordArg, {InferImplExtractKwarg, true}}, - {prim::kPrimMakeRecord, {InferImplMakeRecord, false}}, - {prim::kPrimTupleGetItem, {InferImplTupleGetItem, true}}, - {prim::kPrimListGetItem, {InferImplListGetItem, true}}, - {prim::kPrimTupleSetItem, {InferImplTupleSetItem, true}}, - {prim::kPrimListSetItem, {InferImplListSetItem, true}}, - {prim::kPrimDictGetItem, {InferImplDictGetItem, true}}, - {prim::kPrimDictSetItem, {InferImplDictSetItem, true}}, - {prim::kPrimListAppend, {InferImplListAppend, true}}, - {prim::kPrimTupleLen, {InferImplTupleLen, true}}, - {prim::kPrimListLen, {InferImplListLen, true}}, - {prim::kPrimArrayLen, {InferImplArrayLen, true}}, - {prim::kPrimListMap, {InferImplListMap, false}}, - {prim::kPrimListReduce, {InferImplListReduce, false}}, - {prim::kPrimTupleReversed, {InferImplTupleReversed, false}}, - {prim::kPrimReducedShape, {InferImplReduceShape, false}}, - {prim::kPrimTupleDiv, {InferImplTupleDiv, false}}, - {prim::kPrimTupleToArray, {InferImplTuple2Array, false}}, - {prim::kPrimShapeMul, {InferImplShapeMul, false}}, - {prim::kPrimTupleEqual, {InferImplTupleEqual, false}}, - {prim::kPrimListEqual, {InferImplListEqual, false}}, - {prim::kPrimMakeRange, {InferImplMakeRange, false}}, - {prim::kPrimStopGradient, {InferImplStopGradient, false}}, - {prim::kPrimStringEqual, {InferImplStringEqual, false}}, - {prim::kPrimStringConcat, {InferImplStringConcat, false}}, - {prim::kPrimDictLen, {InferImplDictLen, false}}, - // NN - {prim::kPrimPooling, {InferImplPooling, true}}, - {prim::kPrimPoolingGrad, {InferImplPoolingGrad, true}}, - {prim::kPrimFusedBatchNorm, {InferImplFusedBatchNorm, true}}, - {prim::kPrimFusedBatchNormGrad, {InferImplFusedBatchNormGrad, true}}, - {prim::kPrimReluGrad, {InferImplReluGrad, true}}, - {prim::kPrimConv2DBackpropInput, {InferImplConv2DBackpropInput, true}}, - {prim::kPrimConv2DBackpropFilter, {InferImplConv2DBackpropFilter, true}}, - {prim::kPrimBiasAddGrad, {InferImplBiasAddGrad, true}}, - {prim::kPrimRelu, {InferImplRelu, true}}, - {prim::kPrimFakeBprop, {InferImplFakeBprop, false}}, - {prim::kPrimZerosLike, {InferImplZerosLike, true}}, - {prim::kPrimBpropCut, {InferImplBpropCut, true}}, - {prim::kPrimLayerNorm, {InferImplLayerNorm, true}}, - {prim::kPrimLayerNormGrad, {InferImplLayerNormGrad, true}}, - {prim::kPrimDropoutGenMask, {InferImplDropoutGenMask, true}}, - // Others - {prim::kPrimIdentity, {InferImplIdentity, true}}, - // Set impl to null as it will use PartialEvaluator; - {prim::kPrimPartial, {nullptr, true}}, - {prim::kPrimJ, {InferImplJ, false}}, - {prim::kPrimEnvGetItem, {InferImplEnvGetItem, true}}, - {prim::kPrimEnvSetItem, {InferImplEnvSetItem, true}}, - {prim::kPrimEnvAdd, {InferImplEnvAdd, true}}, - {prim::kPrimMakeRefKey, {InferImplMakeRefKey, true}}, - {prim::kPrimMakeRef, {InferImplMakeRef, true}}, - {prim::kPrimGetRefKey, {InferImplGetRefKey, true}}, - {prim::kPrimGetRefValue, {InferImplGetRefValue, true}}, - {prim::kPrimStateSetItem, {InferImplStateSetItem, true}}, - {prim::kPrimDepend, {InferImplDepend, true}}, - {prim::kPrimBroadcastGradientArgs, {InferImplBroadcastGradientArgs, false}}, - {prim::kPrimControlDepend, {InferImplControlDepend, true}}, - // Debug - {prim::kPrimDebug, {InferImplDebug, true}}, - // RowTensor - {prim::kPrimMakeRowTensor, {InferImplMakeRowTensor, true}}, - {prim::kPrimRowTensorGetValues, {InferImplRowTensorGetValues, true}}, - {prim::kPrimRowTensorGetIndices, {InferImplRowTensorGetIndices, true}}, - {prim::kPrimRowTensorGetDenseShape, {InferImplRowTensorGetDenseShape, true}}, - // SparseTensor - {prim::kPrimMakeSparseTensor, {InferImplMakeSparseTensor, true}}, - {prim::kPrimSparseTensorGetValues, {InferImplSparseTensorGetValues, true}}, - {prim::kPrimSparseTensorGetIndices, {InferImplSparseTensorGetIndices, true}}, - {prim::kPrimSparseTensorGetDenseShape, {InferImplSparseTensorGetDenseShape, true}}, - }; - return prim_eval_implement_map; -} - using mindspore::parse::PyObjectWrapper; std::unordered_set prims_to_skip_undetermined_infer{"make_tuple", "make_list", "switch", "env_setitem", diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.h b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.h index 48bb0e990c..8d5aff9305 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.h +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.h @@ -26,19 +26,10 @@ #include #include "pipeline/jit/static_analysis/evaluator.h" +#include "abstract/primitive_infer_map.h" namespace mindspore { namespace abstract { -using StandardPrimitiveEvalImpl = AbstractBasePtr (*)(const AnalysisEnginePtr &, const PrimitivePtr &, - const AbstractBasePtrList &); -struct StandartPrimitiveImplReg { - StandardPrimitiveEvalImpl impl_; // Implement function of Primitive. - bool in_white_list_; // true if this Primitive in white list, else false. -}; - -using PrimitiveEvalImplMap = - std::unordered_map; - class StandardPrimEvaluator : public TrivialPrimEvaluator { public: StandardPrimEvaluator(const PrimitivePtr primitive, StandardPrimitiveEvalImpl eval_impl) @@ -179,191 +170,6 @@ bool IsSubtype(const AbstractBasePtr x, const TypePtr model); void ClearPrimEvaluatorMap(); py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base); - -AbstractBasePtr InferImplReturn(const AnalysisEnginePtr &, const PrimitivePtr &, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplTypeof(const AnalysisEnginePtr &, const PrimitivePtr &, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplHasType(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplDot(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplSwitch(const AnalysisEnginePtr &, const PrimitivePtr &, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplSwitchLayer(const AnalysisEnginePtr &, const PrimitivePtr &, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplIs_(const AnalysisEnginePtr &, const PrimitivePtr &, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplIsNot(const AnalysisEnginePtr &, const PrimitivePtr &, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplInDict(const AnalysisEnginePtr &, const PrimitivePtr &, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplNotInDict(const AnalysisEnginePtr &, const PrimitivePtr &, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplIsConstant(const AnalysisEnginePtr &, const PrimitivePtr &, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplPooling(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplPoolingGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplFusedBatchNorm(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplFusedBatchNormGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplReluGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplConv2DBackpropInput(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplConv2DBackpropFilter(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplBiasAddGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplRelu(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplFakeBprop(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplZerosLike(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplBpropCut(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplLayerNorm(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplLayerNormGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplDropoutGenMask(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); - -AbstractBasePtr InferImplMinOrMaxGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); - -AbstractBasePtr InferImplScalarToArray(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplArrayToScalar(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplBroadCastShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplPack(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplUnique(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplUniqueGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); - -AbstractBasePtr InferImplMakeTuple(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplMakeList(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplMakeDict(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplMakeSlice(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplMakeKwarg(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplExtractKwarg(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplMakeRecord(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplTupleGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplListGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplTupleSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplListSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplDictGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplDictSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplListAppend(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplTupleLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplListLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplArrayLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplListMap(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplListReduce(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplTupleReversed(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplReduceShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplTupleDiv(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplTuple2Array(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplShapeMul(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplGenShapeIndex(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplGenInverseIndex(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplTupleEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplListEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplMakeRange(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplStopGradient(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplStringEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplStringConcat(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplDictLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); - -AbstractBasePtr InferImplIdentity(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplJ(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplEnvGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplEnvSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplEnvAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplMakeRefKey(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplMakeRef(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplGetRefKey(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplGetRefValue(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplGetRefOrigin(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplStateSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplDepend(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplBroadcastGradientArgs(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplControlDepend(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); - -AbstractBasePtr InferImplDebug(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); - -AbstractBasePtr InferImplMakeRowTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplRowTensorGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplRowTensorGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplRowTensorGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplMakeSparseTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplSparseTensorGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplSparseTensorGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplSparseTensorGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); } // namespace abstract } // namespace mindspore diff --git a/mindspore/core/abstract/infer_functions.h b/mindspore/core/abstract/infer_functions.h new file mode 100644 index 0000000000..5c0625a7e1 --- /dev/null +++ b/mindspore/core/abstract/infer_functions.h @@ -0,0 +1,187 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_ABSTRACT_INFER_FUNCTIONS_H_ +#define MINDSPORE_CORE_ABSTRACT_INFER_FUNCTIONS_H_ +#include +#include +#include "abstract/abstract_value.h" +#include "abstract/param_validator.h" +#include "base/core_ops.h" +namespace mindspore { +namespace abstract { +AbstractBasePtr InferImplReturn(const AnalysisEnginePtr &, const PrimitivePtr &, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplDot(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplSwitch(const AnalysisEnginePtr &, const PrimitivePtr &, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplSwitchLayer(const AnalysisEnginePtr &, const PrimitivePtr &, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplIs_(const AnalysisEnginePtr &, const PrimitivePtr &, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplIsNot(const AnalysisEnginePtr &, const PrimitivePtr &, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplInDict(const AnalysisEnginePtr &, const PrimitivePtr &, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplNotInDict(const AnalysisEnginePtr &, const PrimitivePtr &, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplIsConstant(const AnalysisEnginePtr &, const PrimitivePtr &, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplPooling(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplPoolingGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplFusedBatchNorm(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplFusedBatchNormGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplReluGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplConv2DBackpropInput(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplConv2DBackpropFilter(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplBiasAddGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplGelu(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplGeluGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplRelu(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplZerosLike(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplBpropCut(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplLayerNorm(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplLayerNormGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplDropoutGenMask(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); + +AbstractBasePtr InferImplMinOrMaxGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); + +AbstractBasePtr InferImplScalarToArray(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplArrayToScalar(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplBroadCastShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplPack(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); + +AbstractBasePtr InferImplMakeTuple(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplMakeList(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplMakeDict(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplMakeSlice(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplMakeKwarg(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplExtractKwarg(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplMakeRecord(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplTupleGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplListGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplTupleSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplListSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplDictGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplDictSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplListAppend(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplTupleLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplListLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplArrayLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplGenShapeIndex(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplGenInverseIndex(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplIdentity(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); + +AbstractBasePtr InferImplEnvGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplEnvSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplEnvAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplMakeRefKey(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplMakeRef(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplGetRefKey(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplGetRefValue(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplGetRefOrigin(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplStateSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplDepend(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplControlDepend(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplDebug(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplMakeSparseTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplSparseTensorGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplSparseTensorGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplSparseTensorGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); + +AbstractBasePtr InferImplMakeRowTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplRowTensorGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplRowTensorGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplRowTensorGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplUniqueGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplUnique(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); + +template +AbstractBasePtr InferTupleOrListOrDictLen(const std::string &op_name, const AbstractBasePtrList &args_spec_list) { + // Inputs: a tuple or list or dict. + CheckArgsSize(op_name, args_spec_list, 1); + auto arg = CheckArg(op_name, args_spec_list, 0); + return std::make_shared(SizeToInt(arg->size())); +} +} // namespace abstract +} // namespace mindspore +#endif // MINDSPORE_CORE_ABSTRACT_INFER_FUNCTIONS_H_ diff --git a/mindspore/ccsrc/frontend/operator/prim_arrays.cc b/mindspore/core/abstract/prim_arrays.cc similarity index 89% rename from mindspore/ccsrc/frontend/operator/prim_arrays.cc rename to mindspore/core/abstract/prim_arrays.cc index ea0725ae6e..40dfbc02fe 100644 --- a/mindspore/ccsrc/frontend/operator/prim_arrays.cc +++ b/mindspore/core/abstract/prim_arrays.cc @@ -14,13 +14,48 @@ * limitations under the License. */ -#include "pipeline/jit/static_analysis/prim.h" +#include "abstract/infer_functions.h" #include "abstract/utils.h" -#include "frontend/operator/cc_implementations.h" #include "abstract/param_validator.h" namespace mindspore { namespace abstract { +namespace { +std::vector BroadcastShape(std::vector shpx, std::vector shpy) { + int dlen = SizeToInt(shpx.size()) - SizeToInt(shpy.size()); + if (dlen < 0) { + for (int i = 0; i < -dlen; ++i) { + (void)shpx.insert(shpx.begin(), 1); + } + } else if (dlen > 0) { + for (int i = 0; i < dlen; i++) { + (void)shpy.insert(shpy.begin(), 1); + } + } + if (shpx.size() != shpy.size()) { + MS_LOG(EXCEPTION) << "Failure: shpx.size() != shpy.size()."; + } + std::vector shp; + for (size_t i = 0; i < shpx.size(); i++) { + auto a = shpx[i]; + auto b = shpy[i]; + if (a == 1) { + shp.push_back(b); + } else if (b == 1) { + shp.push_back(a); + } else if (a == -1) { + shp.push_back(b); + } else if (b == -1) { + shp.push_back(a); + } else if (a == b) { + shp.push_back(a); + } else { + return std::vector(); + } + } + return shp; +} +} // namespace AbstractBasePtr InferImplScalarToArray(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list) { // Inputs: a scalar. @@ -65,7 +100,7 @@ AbstractBasePtr InferImplBroadCastShape(const AnalysisEnginePtr &, const Primiti (void)std::transform(std::begin(shp_tuple_y), std::end(shp_tuple_y), std::back_inserter(shp_y), [](const ValuePtr &e) -> int { return GetValue(e); }); - std::vector res = prim::BroadcastShape_(shp_x, shp_y); + std::vector res = BroadcastShape(shp_x, shp_y); if (res.empty()) { MS_LOG(EXCEPTION) << "BroadcastShape fail: " << args_spec_list[0]->ToString() << "," << args_spec_list[1]->ToString(); diff --git a/mindspore/ccsrc/frontend/operator/prim_debug.cc b/mindspore/core/abstract/prim_debug.cc similarity index 94% rename from mindspore/ccsrc/frontend/operator/prim_debug.cc rename to mindspore/core/abstract/prim_debug.cc index 718dadf5c1..4b8cde34d8 100644 --- a/mindspore/ccsrc/frontend/operator/prim_debug.cc +++ b/mindspore/core/abstract/prim_debug.cc @@ -15,8 +15,7 @@ */ #include "abstract/param_validator.h" -#include "pipeline/jit/static_analysis/prim.h" -#include "frontend/operator/ops.h" +#include "abstract/infer_functions.h" #include "abstract/utils.h" #include "utils/symbolic.h" diff --git a/mindspore/ccsrc/frontend/operator/prim_maths.cc b/mindspore/core/abstract/prim_maths.cc similarity index 95% rename from mindspore/ccsrc/frontend/operator/prim_maths.cc rename to mindspore/core/abstract/prim_maths.cc index 5d06fb8603..f0da9535b0 100644 --- a/mindspore/ccsrc/frontend/operator/prim_maths.cc +++ b/mindspore/core/abstract/prim_maths.cc @@ -14,8 +14,7 @@ * limitations under the License. */ -#include "pipeline/jit/static_analysis/prim.h" -#include "frontend/operator/ops.h" +#include "abstract/infer_functions.h" #include "abstract/utils.h" #include "abstract/param_validator.h" #include "utils/ms_utils.h" diff --git a/mindspore/ccsrc/frontend/operator/prim_nn.cc b/mindspore/core/abstract/prim_nn.cc similarity index 79% rename from mindspore/ccsrc/frontend/operator/prim_nn.cc rename to mindspore/core/abstract/prim_nn.cc index 67c23307e5..fbd5b3be76 100644 --- a/mindspore/ccsrc/frontend/operator/prim_nn.cc +++ b/mindspore/core/abstract/prim_nn.cc @@ -14,10 +14,12 @@ * limitations under the License. */ -#include "pipeline/jit/static_analysis/prim.h" -#include "frontend/operator/ops.h" +#include "abstract/infer_functions.h" #include "abstract/utils.h" #include "abstract/param_validator.h" +#include "utils/check_convert_utils.h" +#include "c_ops/conv2d.h" +#include "abstract/primitive_infer_map.h" namespace mindspore { namespace abstract { @@ -278,13 +280,6 @@ AbstractBasePtr InferImplZerosLike(const AnalysisEnginePtr &, const PrimitivePtr return args_spec_list[0]->Broaden(); } -AbstractBasePtr InferImplFakeBprop(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: a tensor. - CheckArgsSize(primitive->name(), args_spec_list, 1); - return args_spec_list[0]->Broaden(); -} - AbstractBasePtr InferImplBpropCut(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list) { // Inputs: a tensor. @@ -433,5 +428,91 @@ AbstractBasePtr InferImplDropoutGenMask(const AnalysisEnginePtr &, const Primiti return std::make_shared(std::make_shared(kAnyValue, kUInt8), std::make_shared(std::vector{shape_y})); } + +abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::vector &input_args) { + MS_EXCEPTION_IF_NULL(primitive); + auto conv_prim = primitive->cast(); + MS_EXCEPTION_IF_NULL(conv_prim); + auto prim_name = conv_prim->name(); + CheckAndConvertUtils::CheckInRange("Conv2d Infer", input_args.size(), kIncludeLeft, {2, 3}, prim_name); + auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShape("w_shape", input_args[0]->GetShapeTrack(), prim_name); + auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[1]->GetShapeTrack(), prim_name); + + CheckAndConvertUtils::CheckInteger("weight rank", w_shape.size(), kEqual, 4, prim_name); + CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kEqual, 4, prim_name); + CheckAndConvertUtils::Check("x_shape[1] / group", x_shape[1] / conv_prim->GetGroup(), kEqual, "w_shape[1]", + w_shape[1], conv_prim->name()); + auto out_channel = conv_prim->GetOutputChannel(); + CheckAndConvertUtils::Check("out_channel", out_channel, kEqual, "w_shape[0]", w_shape[0], conv_prim->name()); + std::vector temp_w; + std::copy(w_shape.begin() + 2, w_shape.end(), std::back_inserter(temp_w)); + CheckAndConvertUtils::Check("kernel_size", conv_prim->GetKernelSize(), kEqual, "w_shape[2:4]", temp_w, + conv_prim->name()); + + auto kernel_size_h = w_shape[2]; + auto kernel_size_w = w_shape[3]; + auto stride = conv_prim->GetStride(); + auto dilation = conv_prim->GetDilation(); + auto stride_h = stride[2]; + auto stride_w = stride[3]; + auto dilation_h = dilation[2]; + auto dilation_w = dilation[3]; + int h_out = -1; + int w_out = -1; + std::vector pad_list(4, 0); + auto pad_mode = conv_prim->GetPadMode(); + if (pad_mode == "valid") { + h_out = ceil((x_shape[2] - dilation_h * (kernel_size_h - 1)) / stride_h); + w_out = ceil((x_shape[3] - dilation_w * (kernel_size_w - 1)) / stride_w); + } else if (pad_mode == "same") { + h_out = ceil(x_shape[2] / stride_h); + w_out = ceil(x_shape[3] / stride_w); + + auto pad_needed_h = std::max(0, (h_out - 1) * stride_h + dilation_h * (kernel_size_h - 1) + 1 - x_shape[2]); + pad_list.emplace_back(floor(pad_needed_h / 2)); + pad_list.emplace_back(pad_needed_h / 2); + auto pad_needed_w = std::max(0, (w_out - 1) * stride_w + dilation_w * (kernel_size_w - 1) + 1 - x_shape[3]); + auto pad_left = floor(pad_needed_w / 2); + pad_list.emplace_back(pad_left); + pad_list.emplace_back(pad_needed_h - pad_left); + } else if (pad_mode == "pad") { + std::copy(conv_prim->GetPad().begin(), conv_prim->GetPad().end(), std::back_inserter(pad_list)); + auto pad_top = conv_prim->GetPad()[0]; + auto pad_bottom = conv_prim->GetPad()[1]; + auto pad_right = conv_prim->GetPad()[2]; + auto pad_left = conv_prim->GetPad()[3]; + + h_out = 1 + (x_shape[2] + pad_top + pad_bottom - kernel_size_h - (kernel_size_h - 1) * (dilation_h - 1)) / stride_h; + w_out = 1 + (x_shape[3] + pad_left + pad_right - kernel_size_w - (kernel_size_w - 1) * (dilation_w - 1)) / stride_w; + h_out = floor(h_out); + w_out = floor(w_out); + } + conv_prim->SetPadList(pad_list); + std::vector out_shape = {x_shape[0], out_channel, h_out, w_out}; + return std::make_shared(out_shape); +} + +TypePtr Conv2dInferType(const PrimitivePtr &prim, const std::vector &input_args) { + CheckAndConvertUtils::CheckInRange("", input_args.size(), kIncludeLeft, {2, 3}, prim->name()); + for (const auto &item : input_args) { + MS_EXCEPTION_IF_NULL(item); + } + auto x_type = CheckAndConvertUtils::ConvertTypePtrToTypeId("x_dtype", input_args[0]->GetTypeTrack(), prim->name()); + const std::set valid_types = {kNumberTypeInt8, kNumberTypeInt32, kNumberTypeFloat16, kNumberTypeFloat32}; + std::map types; + types.emplace("x", input_args[0]->GetTypeTrack()); + types.emplace("w", input_args[1]->GetTypeTrack()); + CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name()); + if (x_type == kNumberTypeInt8) { + return std::make_shared(TypeIdToType(kNumberTypeInt32)); + } + return std::make_shared(TypeIdToType(x_type)); +} +AbstractBasePtr Conv2dInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args) { + return std::make_shared(Conv2dInferType(primitive, input_args), + Conv2dInferShape(primitive, input_args)->shape()); +} +REGISTER_PRIMITIVE_EVAL_IMPL(Conv2D, prim::kPrimConv2D, Conv2dInfer); } // namespace abstract } // namespace mindspore diff --git a/mindspore/ccsrc/frontend/operator/prim_others.cc b/mindspore/core/abstract/prim_others.cc similarity index 76% rename from mindspore/ccsrc/frontend/operator/prim_others.cc rename to mindspore/core/abstract/prim_others.cc index 7707dd5a8f..358ed75849 100644 --- a/mindspore/ccsrc/frontend/operator/prim_others.cc +++ b/mindspore/core/abstract/prim_others.cc @@ -19,9 +19,9 @@ #include "ir/dtype.h" #include "utils/ms_utils.h" -#include "frontend/operator/ops.h" +#include "base/core_ops.h" #include "abstract/param_validator.h" -#include "pipeline/jit/static_analysis/prim.h" +#include "abstract/infer_functions.h" #include "abstract/utils.h" #include "utils/ms_context.h" #include "utils/symbolic.h" @@ -35,27 +35,6 @@ AbstractBasePtr InferImplIdentity(const AnalysisEnginePtr &, const PrimitivePtr return args_spec_list[0]; } -AbstractBasePtr InferImplJ(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // args: An object of AbstractFunction. - CheckArgsSize(primitive->name(), args_spec_list, 1); - MS_LOG(DEBUG) << "evaluate J: " << args_spec_list[0]->ToString(); - - AbstractFunctionPtr x = dyn_cast(args_spec_list[0]); - if (x == nullptr) { - return std::make_shared(args_spec_list[0]); - } - - AbstractFuncAtomPtrList jv; - auto build_jv = [&jv](const AbstractFuncAtomPtr &func) { - auto j_closure = std::make_shared(func); - jv.push_back(j_closure); - }; - x->Visit(build_jv); - - return AbstractFunction::MakeAbstractFunction(jv); -} - AbstractBasePtr InferImplEnvGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list) { MS_EXCEPTION_IF_NULL(primitive); @@ -196,125 +175,6 @@ AbstractBasePtr InferImplDepend(const AnalysisEnginePtr &, const PrimitivePtr &p return depends; } -bool CompareShape(const std::vector &x_shape, const std::vector &y_shape) { - if (x_shape.size() != y_shape.size()) { - return false; - } - - for (size_t i = 0; i < x_shape.size(); ++i) { - if (GetValue(x_shape[i]) != GetValue(y_shape[i])) { - return false; - } - } - - return true; -} - -enum State { - SAME, - X_ONE, - Y_ONE, -}; - -void ComputeReduceIndex(const std::vector &reverse_x, const std::vector &reverse_y, - std::vector *grad_x_reduce_idx, std::vector *grad_y_reduce_idy) { - const size_t n = reverse_x.size(); - for (size_t i = 0; i < n; ++i) { - State curr; - const int32_t x_i = reverse_x[i]; - const int32_t y_i = reverse_y[i]; - const int reduce_idx = SizeToInt(n - 1 - i); - if (x_i == y_i) { - curr = SAME; - } else if (x_i == 1) { - grad_x_reduce_idx->push_back(reduce_idx); - curr = X_ONE; - } else if (y_i == 1) { - grad_y_reduce_idy->push_back(reduce_idx); - curr = Y_ONE; - } else { - MS_LOG(EXCEPTION) << "not compatible shape input for BroadcastGradientArgs"; - } - if (curr == SAME && x_i == 1) { - grad_x_reduce_idx->push_back(reduce_idx); - grad_y_reduce_idy->push_back(reduce_idx); - continue; - } - } - - std::reverse(grad_x_reduce_idx->begin(), grad_x_reduce_idx->end()); - std::reverse(grad_y_reduce_idy->begin(), grad_y_reduce_idy->end()); -} - -AbstractBasePtr BroadcastGradientArgsDiff(const std::vector &x_shape, const std::vector &y_shape) { - std::vector reverse_x; - std::vector reverse_y; - - (void)std::transform(x_shape.rbegin(), x_shape.rend(), std::back_inserter(reverse_x), - [](const ValuePtr &v) { return v->cast()->value(); }); - (void)std::transform(y_shape.rbegin(), y_shape.rend(), std::back_inserter(reverse_y), - [](const ValuePtr &v) { return v->cast()->value(); }); - - if (reverse_x.size() > reverse_y.size()) { - reverse_y.resize(reverse_x.size(), 1); - } else { - reverse_x.resize(reverse_y.size(), 1); - } - - std::vector grad_x_reduce_idx; - std::vector grad_y_reduce_idy; - ComputeReduceIndex(reverse_x, reverse_y, &grad_x_reduce_idx, &grad_y_reduce_idy); - - AbstractBasePtrList abs_list_x; - AbstractBasePtrList abs_list_y; - (void)std::transform(grad_x_reduce_idx.begin(), grad_x_reduce_idx.end(), std::back_inserter(abs_list_x), - [](int v) { return abstract::FromValue(v); }); - (void)std::transform(grad_y_reduce_idy.begin(), grad_y_reduce_idy.end(), std::back_inserter(abs_list_y), - [](int v) { return abstract::FromValue(v); }); - auto x_reduce_idx = std::make_shared(abs_list_x); - auto y_reduce_idx = std::make_shared(abs_list_y); - AbstractBasePtrList elem_list; - elem_list.push_back(x_reduce_idx); - elem_list.push_back(y_reduce_idx); - - return std::make_shared(elem_list); -} - -AbstractBasePtr InferImplBroadcastGradientArgs(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // this primitive get the index that need to reduce - // input: x's shape and y's shape, inputs should be tuple - // output: tuple of x and y 's reduce index, reduce index should be a tuple - const std::string op_name = primitive->name(); - CheckArgsSize(op_name, args_spec_list, 2); - auto arg_x = CheckArg(op_name, args_spec_list, 0); - auto arg_y = CheckArg(op_name, args_spec_list, 1); - - ValueTuplePtr arg_x_value = arg_x->BuildValue()->cast(); - MS_EXCEPTION_IF_NULL(arg_x_value); - - ValueTuplePtr arg_y_value = arg_y->BuildValue()->cast(); - MS_EXCEPTION_IF_NULL(arg_y_value); - - const std::vector x_shape = arg_x_value->value(); - const std::vector y_shape = arg_y_value->value(); - bool is_same_shape = CompareShape(x_shape, y_shape); - // if it is the same shape , do not need reduce , return empty tuple - if (is_same_shape) { - AbstractBasePtrList empty_list; - auto x_reduce_idx = std::make_shared(empty_list); - auto y_reduce_idx = std::make_shared(empty_list); - - AbstractBasePtrList elem_list; - elem_list.push_back(x_reduce_idx); - elem_list.push_back(y_reduce_idx); - - return std::make_shared(elem_list); - } - - return BroadcastGradientArgsDiff(x_shape, y_shape); -} - AbstractBasePtr InferImplControlDepend(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list) { // args: Two objects of a subclass of AbstractBase diff --git a/mindspore/ccsrc/frontend/operator/prim_statement.cc b/mindspore/core/abstract/prim_statement.cc similarity index 85% rename from mindspore/ccsrc/frontend/operator/prim_statement.cc rename to mindspore/core/abstract/prim_statement.cc index 6a7f54007b..24a95709f9 100644 --- a/mindspore/ccsrc/frontend/operator/prim_statement.cc +++ b/mindspore/core/abstract/prim_statement.cc @@ -15,8 +15,7 @@ */ #include "abstract/param_validator.h" -#include "pipeline/jit/static_analysis/prim.h" -#include "frontend/operator/ops.h" +#include "abstract/infer_functions.h" #include "abstract/utils.h" #include "utils/symbolic.h" @@ -34,38 +33,6 @@ AbstractBasePtr InferImplReturn(const AnalysisEnginePtr &, const PrimitivePtr &, return abs_base; } -AbstractBasePtr InferImplTypeof(const AnalysisEnginePtr &, const PrimitivePtr &, - const AbstractBasePtrList &args_spec_list) { - // Inputs: a pointer to an AbstractBase object - if (args_spec_list.size() != 1) { - MS_LOG(EXCEPTION) << "Typeof evaluator requires 1 parameter, while the input size is " << args_spec_list.size() - << "."; - } - AbstractBasePtr abs_base = args_spec_list[0]; - MS_EXCEPTION_IF_NULL(abs_base); - TypePtr type = abs_base->BuildType(); - return std::make_shared(type); -} - -AbstractBasePtr InferImplHasType(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: a pointer to an AbstractBase object and a pointer to a Type - const std::string op_name = primitive->name(); - CheckArgsSize(op_name, args_spec_list, 2); - AbstractTypePtr abs_type = CheckArg(op_name, args_spec_list, 1); - - auto mode_v = abs_type->GetValueTrack(); - MS_EXCEPTION_IF_NULL(mode_v); - if (!mode_v->isa()) { - MS_LOG(EXCEPTION) << "Get the type from AbstractType value failed."; - } - - TypePtr mode_t = mode_v->cast(); - MS_EXCEPTION_IF_NULL(args_spec_list[0]); - bool v = IsSubtype(args_spec_list[0], mode_t); - return std::make_shared(std::make_shared(v), kBool); -} - AbstractBasePtr InferImplDot(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list) { // Inputs: two tensors. diff --git a/mindspore/core/abstract/prim_structures.cc b/mindspore/core/abstract/prim_structures.cc new file mode 100644 index 0000000000..8cdfc5d20f --- /dev/null +++ b/mindspore/core/abstract/prim_structures.cc @@ -0,0 +1,278 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "abstract/infer_functions.h" +#include "abstract/utils.h" +#include "abstract/param_validator.h" +namespace mindspore { +namespace abstract { +AbstractBasePtr InferImplMakeTuple(const AnalysisEnginePtr &, const PrimitivePtr &, + const AbstractBasePtrList &args_spec_list) { + return std::make_shared(args_spec_list); +} + +AbstractBasePtr InferImplMakeList(const AnalysisEnginePtr &, const PrimitivePtr &, + const AbstractBasePtrList &args_spec_list) { + return std::make_shared(args_spec_list); +} + +AbstractBasePtr InferImplMakeDict(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: two tuples. + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 2); + AbstractTuplePtr keys = CheckArg(op_name, args_spec_list, 0); + AbstractTuplePtr values = CheckArg(op_name, args_spec_list, 1); + + size_t keys_size = keys->size(); + if (values->size() != keys_size) { + MS_LOG(EXCEPTION) << op_name << " evaluator keys' size is not equal with values' size"; + } + + std::vector key_value; + AbstractScalarPtr key; + AbstractBasePtrList key_list = keys->elements(); + AbstractBasePtrList value_list = values->elements(); + for (size_t index = 0; index < keys_size; index++) { + key = CheckArg(op_name + "key", key_list, index); + ValuePtr keyPtr = key->BuildValue(); + MS_EXCEPTION_IF_NULL(keyPtr); + if (!keyPtr->isa()) { + MS_LOG(EXCEPTION) << op_name << " evaluator keys should be string, but got " << keyPtr->ToString(); + } + std::string key_string = GetValue(keyPtr); + key_value.emplace_back(key_string, value_list[index]); + } + return std::make_shared(key_value); +} + +AbstractBasePtr InferImplMakeKwarg(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: a string and an object of a subclass of AbstractBase. + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 2); + AbstractScalarPtr key = CheckArg(op_name, args_spec_list, 0); + + ValuePtr keyPtr = key->BuildValue(); + if (!keyPtr->isa()) { + MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << keyPtr->ToString(); + } + std::string key_string = GetValue(keyPtr); + return std::make_shared(key_string, args_spec_list[1]); +} + +AbstractBasePtr InferImplExtractKwarg(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: a string and a keyword. + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 2); + AbstractScalarPtr key = CheckArg(op_name, args_spec_list, 0); + AbstractKeywordArgPtr kwarg = CheckArg(op_name, args_spec_list, 1); + + ValuePtr key_value = key->BuildValue(); + if (!key_value->isa()) { + MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << key_value->ToString(); + } + std::string key_input = GetValue(key_value); + std::string key_actual = kwarg->get_key(); + if (key_actual != key_input) { + MS_LOG(EXCEPTION) << op_name << " evaluator input key should be same as AbstractKeywordArg' key, but input is " + << key_input << ", AbstractKeywordArg' key is " << key_actual; + } + return kwarg->get_arg(); +} + +AbstractBasePtr InferImplMakeSlice(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: three scalars whose value is an int32 number. + CheckArgsSize(primitive->name(), args_spec_list, 3); + size_t args_size = args_spec_list.size(); + for (size_t index = 0; index < args_size; index++) { + MS_EXCEPTION_IF_NULL(args_spec_list[index]); + if (!args_spec_list[index]->isa() && !args_spec_list[index]->isa()) { + MS_EXCEPTION(TypeError) << "MakeSlice eval " << index << " parameter is neither AbstractScalar nor AbstractNone."; + } + if (args_spec_list[index]->isa() && + !dyn_cast(args_spec_list[index])->BuildValue()->isa()) { + MS_EXCEPTION(TypeError) << "MakeSlice eval " << index + << " parameter is an AbstractScalar, but is not an int32 number."; + } + } + // Slice: start, end, step + return std::make_shared(args_spec_list[0], args_spec_list[1], args_spec_list[2]); +} + +template +AbstractBasePtr InferTupleOrListGetItem(const std::string &op_name, const AbstractBasePtrList &args_spec_list) { + // Inputs: a tuple or list and a scalar whose value is an int32 number. + CheckArgsSize(op_name, args_spec_list, 2); + auto queue = CheckArg(op_name, args_spec_list, 0); + AbstractScalarPtr index = CheckArg(op_name, args_spec_list, 1); + + ValuePtr index_value = index->BuildValue(); + if (!index_value->isa()) { + // when index_value is an AnyValue and args_spec_list[0] is a scalar, try to return the type of the first element + // and continue + if (dyn_cast(queue->elements()[0]) != nullptr) { + return std::make_shared(queue->elements()[0]->BuildType()); + } + MS_EXCEPTION(IndexError) << op_name << " evaluator index should be an int32 number, but got " + << index_value->ToString(); + } + int idx_v = GetValue(index_value); + std::size_t nelems = queue->elements().size(); + if (idx_v >= SizeToInt(nelems) || idx_v < -SizeToInt(nelems)) { + MS_EXCEPTION(IndexError) << op_name << " evaluator index should be in range[-" << SizeToInt(nelems) << ", " + << SizeToInt(nelems) << "), but got " << idx_v << "."; + } + + std::size_t uidx_v = 0; + if (idx_v >= 0) { + uidx_v = IntToSize(idx_v); + } else { + uidx_v = IntToSize(idx_v + SizeToInt(nelems)); + } + return queue->elements()[uidx_v]; +} + +template +AbstractBasePtr InferTupleOrListSetItem(const std::string &op_name, const AbstractBasePtrList &args_spec_list) { + // Inputs: a tuple or list, a scalar whose value is an int32 number and an object of a subclass of AbstractBase. + CheckArgsSize(op_name, args_spec_list, 3); + auto queue = CheckArg(op_name, args_spec_list, 0); + AbstractScalarPtr index = CheckArg(op_name, args_spec_list, 1); + + ValuePtr index_value = index->BuildValue(); + if (!index_value->isa()) { + MS_EXCEPTION(IndexError) << op_name << " evaluator index should be an int32 number, but got " + << index_value->ToString(); + } + int idx_v = GetValue(index_value); + if (idx_v < 0) { + MS_EXCEPTION(IndexError) << "The index of " << typeid(T).name() << " should be positive number, but got " << idx_v + << "."; + } + + size_t uidx_v = IntToSize(idx_v); + AbstractBasePtrList elements = queue->elements(); + std::size_t nelems = elements.size(); + if (uidx_v >= nelems) { + MS_EXCEPTION(IndexError) << op_name << " evaluator the index: " << uidx_v << " to set out of range: " << nelems - 1 + << "."; + } + elements[uidx_v] = args_spec_list[2]; + return std::make_shared(elements); +} + +AbstractBasePtr InferImplTupleGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + return InferTupleOrListGetItem(primitive->name(), args_spec_list); +} + +AbstractBasePtr InferImplListGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + return InferTupleOrListGetItem(primitive->name(), args_spec_list); +} + +AbstractBasePtr InferImplTupleSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + return InferTupleOrListSetItem(primitive->name(), args_spec_list); +} + +AbstractBasePtr InferImplListSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + return InferTupleOrListSetItem(primitive->name(), args_spec_list); +} + +AbstractBasePtr InferImplDictGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: a dict and a scalar whose value is a string. + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 2); + AbstractDictionaryPtr dict = CheckArg(op_name, args_spec_list, 0); + AbstractScalarPtr key = CheckArg(op_name, args_spec_list, 1); + + ValuePtr key_value = key->BuildValue(); + if (!key_value->isa()) { + MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << key_value->ToString(); + } + auto key_str = GetValue(key_value); + std::vector dict_elems = dict->elements(); + auto it = std::find_if(dict_elems.begin(), dict_elems.end(), + [key_str](const AbstractAttribute &item) { return item.first == key_str; }); + + if (it == dict_elems.end()) { + MS_LOG(EXCEPTION) << "The key " << key_str << " does not exist in the dict:" << args_spec_list[0]->ToString(); + } + return it->second; +} + +AbstractBasePtr InferImplDictSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: a dict and a scalar whose value is a string and an object of a subclass of AbstractBase. + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 3); + AbstractDictionaryPtr dict = CheckArg(op_name, args_spec_list, 0); + AbstractScalarPtr key = CheckArg(op_name, args_spec_list, 1); + + ValuePtr key_value = key->BuildValue(); + if (!key_value->isa()) { + MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << key_value->ToString(); + } + std::string key_str = GetValue(key_value); + std::vector dict_elems = dict->elements(); + auto it = std::find_if(dict_elems.begin(), dict_elems.end(), + [key_str](const AbstractAttribute &item) { return item.first == key_str; }); + + MS_EXCEPTION_IF_NULL(args_spec_list[2]); + auto new_ele = std::make_pair(key_str, args_spec_list[2]); + if (it != dict_elems.end()) { + int index = it - dict_elems.begin(); + dict_elems[IntToSize(index)] = new_ele; + } else { + dict_elems.push_back(new_ele); + } + return std::make_shared(dict_elems); +} + +AbstractBasePtr InferImplListAppend(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: a list and an object of a subclass of AbstractBase. + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 2); + AbstractListPtr list = CheckArg(op_name, args_spec_list, 0); + (void)AbstractJoin(list->elements()); + return list; +} + +AbstractBasePtr InferImplTupleLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + return InferTupleOrListOrDictLen(primitive->name(), args_spec_list); +} + +AbstractBasePtr InferImplListLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + return InferTupleOrListOrDictLen(primitive->name(), args_spec_list); +} + +AbstractBasePtr InferImplArrayLen(const AnalysisEnginePtr &, const PrimitivePtr &, + const AbstractBasePtrList &args_spec_list) { + return std::make_shared(kAnyValue, kInt32); +} +} // namespace abstract +} // namespace mindspore diff --git a/mindspore/core/abstract/primitive_infer_map.cc b/mindspore/core/abstract/primitive_infer_map.cc new file mode 100644 index 0000000000..0b34985d67 --- /dev/null +++ b/mindspore/core/abstract/primitive_infer_map.cc @@ -0,0 +1,114 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "abstract/primitive_infer_map.h" +#include "abstract/abstract_function.h" +#include "abstract/infer_functions.h" + +namespace mindspore { +namespace abstract { +PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { + static PrimitiveEvalImplMap prim_eval_implement_map = { + // Statements + {prim::kPrimReturn, {InferImplReturn, true}}, + {prim::kPrimDot, {InferImplDot, true}}, + {prim::kPrimSwitch, {InferImplSwitch, true}}, + {prim::kPrimSwitchLayer, {InferImplSwitchLayer, true}}, + {prim::kPrimIs_, {InferImplIs_, true}}, + {prim::kPrimIsNot, {InferImplIsNot, true}}, + {prim::kPrimInDict, {InferImplInDict, true}}, + {prim::kPrimNotInDict, {InferImplNotInDict, true}}, + {prim::kPrimIsConsant, {InferImplIsConstant, true}}, + // Maths + {prim::kPrimMaximumGrad, {InferImplMinOrMaxGrad, true}}, + {prim::kPrimMinimumGrad, {InferImplMinOrMaxGrad, true}}, + // Array + {prim::kPrimScalarToArray, {InferImplScalarToArray, true}}, + {prim::kPrimArrayToScalar, {InferImplArrayToScalar, true}}, + {prim::kPrimBroadcastShape, {InferImplBroadCastShape, true}}, + {prim::kPrimPack, {InferImplPack, true}}, + {prim::kPrimUnique, {InferImplUnique, true}}, + {prim::kPrimUniqueGrad, {InferImplUniqueGrad, true}}, + // Structure + {prim::kPrimMakeTuple, {InferImplMakeTuple, true}}, + {prim::kPrimMakeList, {InferImplMakeList, true}}, + {prim::kPrimMakeDict, {InferImplMakeDict, true}}, + {prim::kPrimMakeSlice, {InferImplMakeSlice, true}}, + {prim::kPrimMakeKeywordArg, {InferImplMakeKwarg, true}}, + {prim::kPrimExtractKeywordArg, {InferImplExtractKwarg, true}}, + {prim::kPrimTupleGetItem, {InferImplTupleGetItem, true}}, + {prim::kPrimListGetItem, {InferImplListGetItem, true}}, + {prim::kPrimTupleSetItem, {InferImplTupleSetItem, true}}, + {prim::kPrimListSetItem, {InferImplListSetItem, true}}, + {prim::kPrimDictGetItem, {InferImplDictGetItem, true}}, + {prim::kPrimDictSetItem, {InferImplDictSetItem, true}}, + {prim::kPrimListAppend, {InferImplListAppend, true}}, + {prim::kPrimTupleLen, {InferImplTupleLen, true}}, + {prim::kPrimListLen, {InferImplListLen, true}}, + {prim::kPrimArrayLen, {InferImplArrayLen, true}}, + // NN + {prim::kPrimPooling, {InferImplPooling, true}}, + {prim::kPrimPoolingGrad, {InferImplPoolingGrad, true}}, + {prim::kPrimFusedBatchNorm, {InferImplFusedBatchNorm, true}}, + {prim::kPrimFusedBatchNormGrad, {InferImplFusedBatchNormGrad, true}}, + {prim::kPrimReluGrad, {InferImplReluGrad, true}}, + {prim::kPrimConv2DBackpropInput, {InferImplConv2DBackpropInput, true}}, + {prim::kPrimConv2DBackpropFilter, {InferImplConv2DBackpropFilter, true}}, + {prim::kPrimBiasAddGrad, {InferImplBiasAddGrad, true}}, + {prim::kPrimRelu, {InferImplRelu, true}}, + {prim::kPrimZerosLike, {InferImplZerosLike, true}}, + {prim::kPrimBpropCut, {InferImplBpropCut, true}}, + {prim::kPrimLayerNorm, {InferImplLayerNorm, true}}, + {prim::kPrimLayerNormGrad, {InferImplLayerNormGrad, true}}, + {prim::kPrimDropoutGenMask, {InferImplDropoutGenMask, true}}, + // Others + {prim::kPrimIdentity, {InferImplIdentity, true}}, + // Set impl to null as it will use PartialEvaluator; + {prim::kPrimPartial, {nullptr, true}}, + {prim::kPrimEnvGetItem, {InferImplEnvGetItem, true}}, + {prim::kPrimEnvSetItem, {InferImplEnvSetItem, true}}, + {prim::kPrimEnvAdd, {InferImplEnvAdd, true}}, + {prim::kPrimMakeRefKey, {InferImplMakeRefKey, true}}, + {prim::kPrimMakeRef, {InferImplMakeRef, true}}, + {prim::kPrimGetRefKey, {InferImplGetRefKey, true}}, + {prim::kPrimGetRefValue, {InferImplGetRefValue, true}}, + {prim::kPrimStateSetItem, {InferImplStateSetItem, true}}, + {prim::kPrimDepend, {InferImplDepend, true}}, + {prim::kPrimControlDepend, {InferImplControlDepend, true}}, + // Debug + {prim::kPrimDebug, {InferImplDebug, true}}, + // SparseTensor + {prim::kPrimMakeSparseTensor, {InferImplMakeSparseTensor, true}}, + {prim::kPrimSparseTensorGetValues, {InferImplSparseTensorGetValues, true}}, + {prim::kPrimSparseTensorGetIndices, {InferImplSparseTensorGetIndices, true}}, + {prim::kPrimSparseTensorGetDenseShape, {InferImplSparseTensorGetDenseShape, true}}, + // RowTensor + {prim::kPrimMakeRowTensor, {InferImplMakeRowTensor, true}}, + {prim::kPrimRowTensorGetValues, {InferImplRowTensorGetValues, true}}, + {prim::kPrimRowTensorGetIndices, {InferImplRowTensorGetIndices, true}}, + {prim::kPrimRowTensorGetDenseShape, {InferImplRowTensorGetDenseShape, true}}, + }; + return prim_eval_implement_map; +} + +void RegisterStandardPrimitiveImpl(const PrimitivePtr &primitive, const StandardPrimitiveImplReg &impl_reg) { + auto &prim_eval_map = GetPrimitiveToEvalImplMap(); + prim_eval_map[primitive] = impl_reg; +} +} // namespace abstract +} // namespace mindspore diff --git a/mindspore/core/abstract/primitive_infer_map.h b/mindspore/core/abstract/primitive_infer_map.h new file mode 100644 index 0000000000..87380c60e4 --- /dev/null +++ b/mindspore/core/abstract/primitive_infer_map.h @@ -0,0 +1,53 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CORE_ABSTRACT_PRIMITIVE_INFER_MAP_H_ +#define MINDSPORE_CORE_ABSTRACT_PRIMITIVE_INFER_MAP_H_ +#include +#include "ir/primitive.h" +#include "base/core_ops.h" +#include "abstract/abstract_value.h" +namespace mindspore { +namespace abstract { +using StandardPrimitiveEvalImpl = AbstractBasePtr (*)(const abstract::AnalysisEnginePtr &, const PrimitivePtr &, + const AbstractBasePtrList &); +struct StandardPrimitiveImplReg { + StandardPrimitiveEvalImpl impl_; // Implement function of Primitive. + bool in_white_list_; // true if this Primitive in white list, else false. +}; + +using PrimitiveEvalImplMap = + std::unordered_map; + +PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap(); + +void RegisterStandardPrimitiveImpl(const PrimitivePtr &primitive, const StandardPrimitiveImplReg &impl_reg); + +class RegisterStandardPrimitiveEvalHelper { + public: + RegisterStandardPrimitiveEvalHelper(const PrimitivePtr &primitive, const StandardPrimitiveEvalImpl &impl) { + const StandardPrimitiveImplReg impl_reg{impl, true}; + RegisterStandardPrimitiveImpl(primitive, impl_reg); + } + ~RegisterStandardPrimitiveEvalHelper() = default; +}; + +#define REGISTER_PRIMITIVE_EVAL_IMPL(name, primitive, impl) \ + static auto helper_##name = RegisterStandardPrimitiveEvalHelper(primitive, impl) +} // namespace abstract +} // namespace mindspore +#endif // MINDSPORE_CORE_ABSTRACT_PRIMITIVE_INFER_MAP_H_ diff --git a/mindspore/core/base/core_ops.h b/mindspore/core/base/core_ops.h index ddd490c9d4..ba64ff28d5 100755 --- a/mindspore/core/base/core_ops.h +++ b/mindspore/core/base/core_ops.h @@ -246,6 +246,25 @@ inline const PrimitivePtr kPrimNotInDict = std::make_shared("not_in_d inline const PrimitivePtr kPrimIsConsant = std::make_shared("is_constant"); inline const PrimitivePtr kPrimEquivFormat = std::make_shared("EquivFormat"); +// Structures +inline const PrimitivePtr kPrimMakeList = std::make_shared("make_list"); +inline const PrimitivePtr kPrimMakeKeywordArg = std::make_shared("make_keyword_arg"); +inline const PrimitivePtr kPrimListGetItem = std::make_shared("list_getitem"); +inline const PrimitivePtr kPrimListSetItem = std::make_shared("list_setitem"); +inline const PrimitivePtr kPrimDictGetItem = std::make_shared("dict_getitem"); +inline const PrimitivePtr kPrimDictSetItem = std::make_shared("dict_setitem"); +inline const PrimitivePtr kPrimListAppend = std::make_shared("list_append"); +inline const PrimitivePtr kPrimListLen = std::make_shared("list_len"); + +// Other miscellaneous +inline const PrimitivePtr kPrimEnvSetItem = std::make_shared("env_setitem"); +inline const PrimitivePtr kPrimEnvGetItem = std::make_shared("env_getitem"); +inline const PrimitivePtr kPrimEnvAdd = std::make_shared("env_add"); +inline const PrimitivePtr kPrimMakeRefKey = std::make_shared("MakeRefKey"); +inline const PrimitivePtr kPrimGetRefKey = std::make_shared("get_ref_key"); +inline const PrimitivePtr kPrimMakeRef = std::make_shared("make_ref"); +inline const PrimitivePtr kPrimGetRefValue = std::make_shared("get_ref_value"); + // Other primitve not used by backend but used in core; inline const PrimitivePtr kPrimStateSetItem = std::make_shared("state_setitem"); inline const PrimitivePtr kPrimJ = std::make_shared("J"); diff --git a/mindspore/core/c_ops/conv2d.cc b/mindspore/core/c_ops/conv2d.cc index 347ab79c1a..77c6a61438 100644 --- a/mindspore/core/c_ops/conv2d.cc +++ b/mindspore/core/c_ops/conv2d.cc @@ -26,87 +26,19 @@ namespace mindspore { namespace { -using PrimConv2dPtr = std::shared_ptr; -abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) { - MS_EXCEPTION_IF_NULL(primitive); - auto conv_prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(conv_prim); - auto prim_name = conv_prim->name(); - CheckAndConvertUtils::CheckInRange("Conv2d Infer", input_args.size(), kIncludeLeft, {2, 3}, prim_name); - auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShape("w_shape", input_args[0]->GetShapeTrack(), prim_name); - auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[1]->GetShapeTrack(), prim_name); - - CheckAndConvertUtils::CheckInteger("weight rank", w_shape.size(), kEqual, 4, prim_name); - CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kEqual, 4, prim_name); - CheckAndConvertUtils::Check("x_shape[1] / group", x_shape[1] / conv_prim->GetGroup(), kEqual, "w_shape[1]", - w_shape[1], conv_prim->name()); - auto out_channel = conv_prim->GetOutputChannel(); - CheckAndConvertUtils::Check("out_channel", out_channel, kEqual, "w_shape[0]", w_shape[0], conv_prim->name()); - std::vector temp_w; - std::copy(w_shape.begin() + 2, w_shape.end(), std::back_inserter(temp_w)); - CheckAndConvertUtils::Check("kernel_size", conv_prim->GetKernelSize(), kEqual, "w_shape[2:4]", temp_w, - conv_prim->name()); - - auto kernel_size_h = w_shape[2]; - auto kernel_size_w = w_shape[3]; - auto stride = conv_prim->GetStride(); - auto dilation = conv_prim->GetDilation(); - auto stride_h = stride[2]; - auto stride_w = stride[3]; - auto dilation_h = dilation[2]; - auto dilation_w = dilation[3]; - int h_out = -1; - int w_out = -1; - std::vector pad_list(4, 0); - auto pad_mode = conv_prim->GetPadMode(); - if (pad_mode == "valid") { - h_out = ceil((x_shape[2] - dilation_h * (kernel_size_h - 1)) / stride_h); - w_out = ceil((x_shape[3] - dilation_w * (kernel_size_w - 1)) / stride_w); - } else if (pad_mode == "same") { - h_out = ceil(x_shape[2] / stride_h); - w_out = ceil(x_shape[3] / stride_w); - - auto pad_needed_h = std::max(0, (h_out - 1) * stride_h + dilation_h * (kernel_size_h - 1) + 1 - x_shape[2]); - pad_list.emplace_back(floor(pad_needed_h / 2)); - pad_list.emplace_back(pad_needed_h / 2); - auto pad_needed_w = std::max(0, (w_out - 1) * stride_w + dilation_w * (kernel_size_w - 1) + 1 - x_shape[3]); - auto pad_left = floor(pad_needed_w / 2); - pad_list.emplace_back(pad_left); - pad_list.emplace_back(pad_needed_h - pad_left); - } else if (pad_mode == "pad") { - std::copy(conv_prim->GetPad().begin(), conv_prim->GetPad().end(), std::back_inserter(pad_list)); - auto pad_top = conv_prim->GetPad()[0]; - auto pad_bottom = conv_prim->GetPad()[1]; - auto pad_right = conv_prim->GetPad()[2]; - auto pad_left = conv_prim->GetPad()[3]; - - h_out = 1 + (x_shape[2] + pad_top + pad_bottom - kernel_size_h - (kernel_size_h - 1) * (dilation_h - 1)) / stride_h; - w_out = 1 + (x_shape[3] + pad_left + pad_right - kernel_size_w - (kernel_size_w - 1) * (dilation_w - 1)) / stride_w; - h_out = floor(h_out); - w_out = floor(w_out); - } - conv_prim->SetPadList(pad_list); - std::vector out_shape = {x_shape[0], out_channel, h_out, w_out}; - return std::make_shared(out_shape); -} - -TypePtr InferType(const PrimitivePtr &prim, const std::vector &input_args) { - CheckAndConvertUtils::CheckInRange("", input_args.size(), kIncludeLeft, {2, 3}, prim->name()); - for (const auto &item : input_args) { - MS_EXCEPTION_IF_NULL(item); - } - auto x_type = CheckAndConvertUtils::ConvertTypePtrToTypeId("x_dtype", input_args[0]->GetTypeTrack(), prim->name()); - const std::set valid_types = {kNumberTypeInt8, kNumberTypeInt32, kNumberTypeFloat16, kNumberTypeFloat32}; - std::map types; - types.emplace("x", input_args[0]->GetTypeTrack()); - types.emplace("w", input_args[1]->GetTypeTrack()); - CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name()); - if (x_type == kNumberTypeInt8) { - return std::make_shared(TypeIdToType(kNumberTypeInt32)); - } - return std::make_shared(TypeIdToType(x_type)); -} +constexpr auto kKernelSize = "kernel_size"; +constexpr auto kStride = "stride"; +constexpr auto kDilation = "dilation"; +constexpr auto kPadMode = "pad_mode"; +constexpr auto kPad = "pad"; +constexpr auto kMode = "mode"; +constexpr auto kGroup = "group"; +constexpr auto kOutputChannel = "output channel"; +constexpr auto kPadList = "pad_list"; +constexpr auto kConv2DName = "Conv2D"; } // namespace +Conv2d::Conv2d() : PrimitiveC(kConv2DName) { InitIOName({"x", "w"}, {"output"}); } + void Conv2d::Init(int out_channel, const std::vector &kernel_size, int mode, const std::string &pad_mode, const std::vector &pad, const std::vector &stride, const std::vector &dilation, int group) { @@ -130,10 +62,47 @@ void Conv2d::Init(int out_channel, const std::vector &kernel_size, int mode this->SetOutChannel(CheckAndConvertUtils::CheckInteger("out_channel", out_channel, kGreaterThan, 0, prim_name)); this->SetGroup(CheckAndConvertUtils::CheckInteger("group", group, kGreaterThan, 0, prim_name)); } +std::vector Conv2d::GetKernelSize() const { + auto value_ptr = GetAttr(kKernelSize); + return GetValue>(value_ptr); +} +std::vector Conv2d::GetStride() const { + auto value_ptr = GetAttr(kStride); + return GetValue>(value_ptr); +} +std::vector Conv2d::GetDilation() const { + auto value_ptr = GetAttr(kDilation); + return GetValue>(value_ptr); +} +std::string Conv2d::GetPadMode() const { + auto value_ptr = this->GetAttr(kPadMode); + return GetValue(value_ptr); +} +std::vector Conv2d::GetPad() const { + auto value_ptr = this->GetAttr(kPad); + return GetValue>(value_ptr); +} +int Conv2d::GetMode() const { + auto value_ptr = this->GetAttr(kMode); + return GetValue(value_ptr); +} -AbstractBasePtr Conv2dInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, - const std::vector &input_args) { - return std::make_shared(InferType(primitive, input_args), - InferShape(primitive, input_args)->shape()); +int Conv2d::GetGroup() const { + auto value_ptr = this->GetAttr(kGroup); + return GetValue(value_ptr); } +int Conv2d::GetOutputChannel() const { + auto value_ptr = this->GetAttr(kOutputChannel); + return GetValue(value_ptr); +} + +void Conv2d::SetKernelSize(const std::vector &kernel_size) { this->AddAttr(kKernelSize, MakeValue(kernel_size)); } +void Conv2d::SetStride(const std::vector &stride) { this->AddAttr(kStride, MakeValue(stride)); } +void Conv2d::SetDilation(const std::vector &dilation) { this->AddAttr(kDilation, MakeValue(dilation)); } +void Conv2d::SetPadMode(const std::string &pad_mode) { this->AddAttr(kPadMode, MakeValue(pad_mode)); } +void Conv2d::SetPad(const std::vector &pad) { this->AddAttr(kPad, MakeValue(pad)); } +void Conv2d::SetMode(int mode) { this->AddAttr(kMode, MakeValue(mode)); } +void Conv2d::SetGroup(int group) { this->AddAttr(kGroup, MakeValue(group)); } +void Conv2d::SetOutChannel(int output_channel) { this->AddAttr(kOutputChannel, MakeValue(output_channel)); } +void Conv2d::SetPadList(const std::vector &pad_list) { this->AddAttr(kPadList, MakeValue(pad_list)); } } // namespace mindspore diff --git a/mindspore/core/c_ops/conv2d.h b/mindspore/core/c_ops/conv2d.h index cbba5fa068..910fad18af 100644 --- a/mindspore/core/c_ops/conv2d.h +++ b/mindspore/core/c_ops/conv2d.h @@ -16,79 +16,44 @@ * limitations under the License. */ -#ifndef MINDSPORE_CORE_C_OPS_CONV2D_H -#define MINDSPORE_CORE_C_OPS_CONV2D_H +#ifndef MINDSPORE_CORE_C_OPS_CONV2D_H_ +#define MINDSPORE_CORE_C_OPS_CONV2D_H_ #include #include #include +#include + #include "c_ops/primitive_c.h" #include "abstract/abstract_value.h" #include "utils/check_convert_utils.h" namespace mindspore { class Conv2d : public PrimitiveC { public: - Conv2d() : PrimitiveC(kConv2DName) { InitIOName({"x", "w"}, {"output"}); } + Conv2d(); void Init(int out_channel, const std::vector &kernel_size, int mode = 1, const std::string &pad_mode = "valid", const std::vector &pad = {0, 0, 0, 0}, const std::vector &stride = {1, 1, 1, 1}, const std::vector &dilation = {1, 1, 1, 1}, int group = 1); - std::vector GetKernelSize() const { - auto value_ptr = this->GetAttr(kKernelSize); - return GetValue>(value_ptr); - } - std::vector GetStride() const { - auto value_ptr = GetAttr(kStride); - return GetValue>(value_ptr); - } - std::vector GetDilation() const { - auto value_ptr = GetAttr(kDilation); - return GetValue>(value_ptr); - } - std::string GetPadMode() const { - auto value_ptr = this->GetAttr(kPadMode); - return GetValue(value_ptr); - } - std::vector GetPad() const { - auto value_ptr = this->GetAttr(kPad); - return GetValue>(value_ptr); - } - int GetMode() const { - auto value_ptr = this->GetAttr(kMode); - return GetValue(value_ptr); - } - - int GetGroup() const { - auto value_ptr = this->GetAttr(kGroup); - return GetValue(value_ptr); - } - int GetOutputChannel() const { - auto value_ptr = this->GetAttr(kOutputChannel); - return GetValue(value_ptr); - } - - void SetKernelSize(const std::vector &kernel_size) { this->AddAttr(kKernelSize, MakeValue(kernel_size)); } - void SetStride(const std::vector &stride) { this->AddAttr(kStride, MakeValue(stride)); } - void SetDilation(const std::vector &dilation) { this->AddAttr(kDilation, MakeValue(dilation)); } - void SetPadMode(const std::string &pad_mode) { this->AddAttr(kPadMode, MakeValue(pad_mode)); } - void SetPad(const std::vector &pad) { this->AddAttr(kPad, MakeValue(pad)); } - void SetMode(int mode) { this->AddAttr(kMode, MakeValue(mode)); } - void SetGroup(int group) { this->AddAttr(kGroup, MakeValue(group)); } - void SetOutChannel(int output_channel) { this->AddAttr(kOutputChannel, MakeValue(output_channel)); } - void SetPadList(const std::vector &pad_list) { this->AddAttr(kPadList, MakeValue(pad_list)); } - - private: - inline static const string kKernelSize = "kernel_size"; - inline static const string kStride = "stride"; - inline static const string kDilation = "dilation"; - inline static const string kPadMode = "pad_mode"; - inline static const string kPad = "pad"; - inline static const string kMode = "mode"; - inline static const string kGroup = "group"; - inline static const string kOutputChannel = "output channel"; - inline static const string kPadList = "pad_list"; - inline static const string kConv2DName = "Conv2D"; + std::vector GetKernelSize() const; + std::vector GetStride() const; + std::vector GetDilation() const; + std::string GetPadMode() const; + std::vector GetPad() const; + int GetMode() const; + int GetGroup() const; + int GetOutputChannel() const; + void SetKernelSize(const std::vector &kernel_size); + void SetStride(const std::vector &stride); + void SetDilation(const std::vector &dilation); + void SetPadMode(const std::string &pad_mode); + void SetPad(const std::vector &pad); + void SetMode(int mode); + void SetGroup(int group); + void SetOutChannel(int output_channel); + void SetPadList(const std::vector &pad_list); }; AbstractBasePtr Conv2dInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args); +using PrimConv2dPtr = std::shared_ptr; } // namespace mindspore -#endif // MINDSPORE_CORE_C_OPS_CONV2D_H +#endif // MINDSPORE_CORE_C_OPS_CONV2D_H_ diff --git a/mindspore/core/c_ops/primitive_c.h b/mindspore/core/c_ops/primitive_c.h index 69c85b31ae..501f32f964 100644 --- a/mindspore/core/c_ops/primitive_c.h +++ b/mindspore/core/c_ops/primitive_c.h @@ -16,8 +16,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CORE_C_OPS_PRIMITIVE_C_H -#define MINDSPORE_CORE_C_OPS_PRIMITIVE_C_H +#ifndef MINDSPORE_CORE_C_OPS_PRIMITIVE_C_H_ +#define MINDSPORE_CORE_C_OPS_PRIMITIVE_C_H_ #include #include #include "ir/primitive.h" @@ -25,7 +25,7 @@ namespace mindspore { class PrimitiveC : public Primitive { public: - explicit PrimitiveC(const std::string &name) : Primitive(name) { attrs_ = {}; } + explicit PrimitiveC(const std::string &name) : Primitive(name) {} protected: void InitIOName(const std::vector &inputs_name, const std::vector &outputs_name) { @@ -34,4 +34,4 @@ class PrimitiveC : public Primitive { } }; } // namespace mindspore -#endif // MINDSPORE_CORE_C_OPS_PRIMITIVE_C_H +#endif // MINDSPORE_CORE_C_OPS_PRIMITIVE_C_H_ diff --git a/tests/ut/cpp/CMakeLists.txt b/tests/ut/cpp/CMakeLists.txt index 3d2c29d067..81d2b8a14b 100644 --- a/tests/ut/cpp/CMakeLists.txt +++ b/tests/ut/cpp/CMakeLists.txt @@ -61,6 +61,7 @@ file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "../../../mindspore/core/abstract/*.cc" "../../../mindspore/core/ir/*.cc" "../../../mindspore/core/utils/*.cc" + "../../../mindspore/core/c_ops/*.cc" "../../../mindspore/ccsrc/common/*.cc" "../../../mindspore/ccsrc/utils/*.cc" "../../../mindspore/ccsrc/pipeline/jit/parse/*.cc"