Merge pull request !4255 from lianliguang/unify-primitivetags/v0.7.0-beta
| @@ -393,40 +393,5 @@ ValuePtr BoolEq(const ValuePtrList &list) { | |||
| MS_LOG(EXCEPTION) << "Unsported Value for BoolEq, x: " << x->ToString() << "."; | |||
| } | |||
| std::vector<int> BroadcastShape_(std::vector<int> shpx, std::vector<int> shpy) { | |||
| int dlen = SizeToInt(shpx.size()) - SizeToInt(shpy.size()); | |||
| if (dlen < 0) { | |||
| for (int i = 0; i < -dlen; ++i) { | |||
| (void)shpx.insert(shpx.begin(), 1); | |||
| } | |||
| } else if (dlen > 0) { | |||
| for (int i = 0; i < dlen; i++) { | |||
| (void)shpy.insert(shpy.begin(), 1); | |||
| } | |||
| } | |||
| if (shpx.size() != shpy.size()) { | |||
| MS_LOG(EXCEPTION) << "Failure: shpx.size() != shpy.size()."; | |||
| } | |||
| std::vector<int> shp; | |||
| for (size_t i = 0; i < shpx.size(); i++) { | |||
| auto a = shpx[i]; | |||
| auto b = shpy[i]; | |||
| if (a == 1) { | |||
| shp.push_back(b); | |||
| } else if (b == 1) { | |||
| shp.push_back(a); | |||
| } else if (a == -1) { | |||
| shp.push_back(b); | |||
| } else if (b == -1) { | |||
| shp.push_back(a); | |||
| } else if (a == b) { | |||
| shp.push_back(a); | |||
| } else { | |||
| return std::vector<int>(); | |||
| } | |||
| } | |||
| return shp; | |||
| } | |||
| } // namespace prim | |||
| } // namespace mindspore | |||
| @@ -52,7 +52,6 @@ ValuePtr BoolNot(const ValuePtrList &list); | |||
| ValuePtr BoolAnd(const ValuePtrList &list); | |||
| ValuePtr BoolOr(const ValuePtrList &list); | |||
| ValuePtr BoolEq(const ValuePtrList &list); | |||
| std::vector<int> BroadcastShape_(std::vector<int> s1, std::vector<int> s2); | |||
| } // namespace prim | |||
| } // namespace mindspore | |||
| @@ -42,28 +42,13 @@ inline const PrimitivePtr kPrimRefToEmbed = std::make_shared<Primitive>("RefToEm | |||
| inline const PrimitivePtr kPrimCreateInstance = std::make_shared<Primitive>("create_instance"); | |||
| // Other miscellaneous | |||
| inline const PrimitivePtr kPrimEnvSetItem = std::make_shared<Primitive>("env_setitem"); | |||
| inline const PrimitivePtr kPrimEnvGetItem = std::make_shared<Primitive>("env_getitem"); | |||
| inline const PrimitivePtr kPrimEnvAdd = std::make_shared<Primitive>("env_add"); | |||
| inline const PrimitivePtr kPrimMakeRefKey = std::make_shared<Primitive>("MakeRefKey"); | |||
| inline const PrimitivePtr kPrimGetRefKey = std::make_shared<Primitive>("get_ref_key"); | |||
| inline const PrimitivePtr kPrimGetRefValue = std::make_shared<Primitive>("get_ref_value"); | |||
| inline const PrimitivePtr kPrimGetRefOrigin = std::make_shared<Primitive>("get_ref_origin"); | |||
| inline const PrimitivePtr kPrimInsertGradientOf = std::make_shared<Primitive>("InsertGradientOf"); | |||
| inline const PrimitivePtr kPrimCheckBprop = std::make_shared<Primitive>("CheckBprop"); | |||
| inline const PrimitivePtr kPrimMakeRef = std::make_shared<Primitive>("make_ref"); | |||
| inline const PrimitivePtr kPrimMixedPrecisionCast = std::make_shared<Primitive>("mixed_precision_cast"); | |||
| inline const PrimitivePtr kPrimMakeRecord = std::make_shared<Primitive>("make_record"); | |||
| // Structures | |||
| inline const PrimitivePtr kPrimMakeList = std::make_shared<Primitive>("make_list"); | |||
| inline const PrimitivePtr kPrimMakeKeywordArg = std::make_shared<Primitive>("make_keyword_arg"); | |||
| inline const PrimitivePtr kPrimListGetItem = std::make_shared<Primitive>("list_getitem"); | |||
| inline const PrimitivePtr kPrimListSetItem = std::make_shared<Primitive>("list_setitem"); | |||
| inline const PrimitivePtr kPrimDictGetItem = std::make_shared<Primitive>("dict_getitem"); | |||
| inline const PrimitivePtr kPrimDictSetItem = std::make_shared<Primitive>("dict_setitem"); | |||
| inline const PrimitivePtr kPrimListAppend = std::make_shared<Primitive>("list_append"); | |||
| inline const PrimitivePtr kPrimListLen = std::make_shared<Primitive>("list_len"); | |||
| inline const PrimitivePtr kPrimListMap = std::make_shared<Primitive>("list_map"); | |||
| inline const PrimitivePtr kPrimListReduce = std::make_shared<Primitive>("list_reduce"); | |||
| @@ -1,6 +1,4 @@ | |||
| /** | |||
| * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). | |||
| * | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| @@ -15,360 +13,266 @@ | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "frontend/operator/ops_front_infer_function.h" | |||
| #include <set> | |||
| #include <string> | |||
| #include <vector> | |||
| #include <memory> | |||
| #include <algorithm> | |||
| #include "abstract/abstract_value.h" | |||
| #include "pipeline/jit/static_analysis/prim.h" | |||
| #include "abstract/utils.h" | |||
| #include "abstract/param_validator.h" | |||
| #include "frontend/operator/ops.h" | |||
| #include "utils/convert_utils.h" | |||
| #include "utils/tensor_py.h" | |||
| using mindspore::tensor::TensorPy; | |||
| #include "frontend/operator/ops.h" | |||
| #include "abstract/infer_functions.h" | |||
| namespace mindspore { | |||
| namespace abstract { | |||
| enum State { | |||
| SAME, | |||
| X_ONE, | |||
| Y_ONE, | |||
| }; | |||
| AbstractBasePtr InferImplStringEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| // Inputs: two scalars whose value is a string. | |||
| const std::string op_name = primitive->name(); | |||
| CheckArgsSize(op_name, args_spec_list, 2); | |||
| AbstractScalarPtr scalar_x = CheckArg<AbstractScalar>(op_name, args_spec_list, 0); | |||
| AbstractScalarPtr scalar_y = CheckArg<AbstractScalar>(op_name, args_spec_list, 1); | |||
| ValuePtr value_x = scalar_x->BuildValue(); | |||
| ValuePtr value_y = scalar_y->BuildValue(); | |||
| if (!value_x->isa<StringImm>() || !value_y->isa<StringImm>()) { | |||
| MS_LOG(EXCEPTION) << op_name << " requires 2 parameters are string, but got param0: " << value_x->ToString() | |||
| << ", param1: " << value_y->ToString(); | |||
| } | |||
| bool ret = (value_x->cast<StringImmPtr>()->value() == value_y->cast<StringImmPtr>()->value()); | |||
| return std::make_shared<AbstractScalar>(ret); | |||
| } | |||
| struct SlideInfo { | |||
| int start; | |||
| int step; | |||
| int stop; | |||
| }; | |||
| AbstractBasePtr InferImplStringConcat(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| // Inputs: two scalars whose value is a string. | |||
| const std::string op_name = primitive->name(); | |||
| template <typename T> | |||
| AbstractBasePtr InferImplTupleOrListEqual(const std::string &op_name, const AbstractBasePtrList &args_spec_list) { | |||
| // Inputs: two tuples or two lists. | |||
| CheckArgsSize(op_name, args_spec_list, 2); | |||
| AbstractScalarPtr scalar_x = CheckArg<AbstractScalar>(op_name, args_spec_list, 0); | |||
| AbstractScalarPtr scalar_y = CheckArg<AbstractScalar>(op_name, args_spec_list, 1); | |||
| ValuePtr value_x = scalar_x->BuildValue(); | |||
| ValuePtr value_y = scalar_y->BuildValue(); | |||
| if (!value_x->isa<StringImm>() || !value_y->isa<StringImm>()) { | |||
| MS_LOG(EXCEPTION) << op_name << " requires 2 parameters are string, but got param0: " << value_x->ToString() | |||
| << ", param1: " << value_y->ToString(); | |||
| } | |||
| std::string ret = (value_x->cast<StringImmPtr>()->value() + value_y->cast<StringImmPtr>()->value()); | |||
| return std::make_shared<AbstractScalar>(ret); | |||
| } | |||
| AbstractBasePtr InferImplMakeTuple(const AnalysisEnginePtr &, const PrimitivePtr &, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| return std::make_shared<AbstractTuple>(args_spec_list); | |||
| } | |||
| auto input_x = CheckArg<T>(op_name, args_spec_list, 0); | |||
| auto input_y = CheckArg<T>(op_name, args_spec_list, 1); | |||
| AbstractBasePtr InferImplMakeList(const AnalysisEnginePtr &, const PrimitivePtr &, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| return std::make_shared<AbstractList>(args_spec_list); | |||
| ValuePtr x_value = input_x->BuildValue(); | |||
| ValuePtr y_value = input_y->BuildValue(); | |||
| return std::make_shared<AbstractScalar>(*x_value == *y_value); | |||
| } | |||
| AbstractBasePtr InferImplMakeDict(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| // Inputs: two tuples. | |||
| const std::string op_name = primitive->name(); | |||
| CheckArgsSize(op_name, args_spec_list, 2); | |||
| AbstractTuplePtr keys = CheckArg<AbstractTuple>(op_name, args_spec_list, 0); | |||
| AbstractTuplePtr values = CheckArg<AbstractTuple>(op_name, args_spec_list, 1); | |||
| size_t keys_size = keys->size(); | |||
| if (values->size() != keys_size) { | |||
| MS_LOG(EXCEPTION) << op_name << " evaluator keys' size is not equal with values' size"; | |||
| } | |||
| std::vector<AbstractAttribute> key_value; | |||
| AbstractScalarPtr key; | |||
| AbstractBasePtrList key_list = keys->elements(); | |||
| AbstractBasePtrList value_list = values->elements(); | |||
| for (size_t index = 0; index < keys_size; index++) { | |||
| key = CheckArg<AbstractScalar>(op_name + "key", key_list, index); | |||
| ValuePtr keyPtr = key->BuildValue(); | |||
| MS_EXCEPTION_IF_NULL(keyPtr); | |||
| if (!keyPtr->isa<StringImm>()) { | |||
| MS_LOG(EXCEPTION) << op_name << " evaluator keys should be string, but got " << keyPtr->ToString(); | |||
| void CalcSlidePara(const AbstractBasePtrList &args_spec_list, SlideInfo *slide) { | |||
| int arg1 = 0; | |||
| int arg2 = 0; | |||
| if (!args_spec_list.empty()) { | |||
| MS_EXCEPTION_IF_NULL(args_spec_list[0]); | |||
| auto arg_value = args_spec_list[0]->BuildValue(); | |||
| if (!arg_value->isa<Int32Imm>()) { | |||
| MS_LOG(EXCEPTION) << "Only supported input an int32 number."; | |||
| } | |||
| std::string key_string = GetValue<std::string>(keyPtr); | |||
| key_value.emplace_back(key_string, value_list[index]); | |||
| arg1 = GetValue<int>(arg_value); | |||
| } | |||
| return std::make_shared<AbstractDictionary>(key_value); | |||
| } | |||
| AbstractBasePtr InferImplMakeKwarg(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| // Inputs: a string and an object of a subclass of AbstractBase. | |||
| const std::string op_name = primitive->name(); | |||
| CheckArgsSize(op_name, args_spec_list, 2); | |||
| AbstractScalarPtr key = CheckArg<AbstractScalar>(op_name, args_spec_list, 0); | |||
| ValuePtr keyPtr = key->BuildValue(); | |||
| if (!keyPtr->isa<StringImm>()) { | |||
| MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << keyPtr->ToString(); | |||
| if (args_spec_list.size() >= 2) { | |||
| MS_EXCEPTION_IF_NULL(args_spec_list[1]); | |||
| auto arg_value = args_spec_list[1]->BuildValue(); | |||
| if (!arg_value->isa<Int32Imm>()) { | |||
| MS_LOG(EXCEPTION) << "Only supported input an int32 number."; | |||
| } | |||
| arg2 = GetValue<int>(arg_value); | |||
| } | |||
| std::string key_string = GetValue<std::string>(keyPtr); | |||
| return std::make_shared<AbstractKeywordArg>(key_string, args_spec_list[1]); | |||
| } | |||
| AbstractBasePtr InferImplExtractKwarg(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| // Inputs: a string and a keyword. | |||
| const std::string op_name = primitive->name(); | |||
| CheckArgsSize(op_name, args_spec_list, 2); | |||
| AbstractScalarPtr key = CheckArg<AbstractScalar>(op_name, args_spec_list, 0); | |||
| AbstractKeywordArgPtr kwarg = CheckArg<AbstractKeywordArg>(op_name, args_spec_list, 1); | |||
| if (args_spec_list.size() == 3) { | |||
| MS_EXCEPTION_IF_NULL(args_spec_list[2]); | |||
| auto arg_value = args_spec_list[2]->BuildValue(); | |||
| if (!arg_value->isa<Int32Imm>()) { | |||
| MS_LOG(EXCEPTION) << "Only supported input an int32 number."; | |||
| } | |||
| slide->step = GetValue<int>(arg_value); | |||
| slide->start = arg1; | |||
| slide->stop = arg2; | |||
| } | |||
| ValuePtr key_value = key->BuildValue(); | |||
| if (!key_value->isa<StringImm>()) { | |||
| MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << key_value->ToString(); | |||
| if (args_spec_list.size() == 2) { | |||
| slide->start = arg1; | |||
| slide->stop = arg2; | |||
| } | |||
| std::string key_input = GetValue<std::string>(key_value); | |||
| std::string key_actual = kwarg->get_key(); | |||
| if (key_actual != key_input) { | |||
| MS_LOG(EXCEPTION) << op_name << " evaluator input key should be same as AbstractKeywordArg' key, but input is " | |||
| << key_input << ", AbstractKeywordArg' key is " << key_actual; | |||
| if (args_spec_list.size() == 1) { | |||
| slide->stop = arg1; | |||
| } | |||
| return kwarg->get_arg(); | |||
| } | |||
| AbstractBasePtr InferImplMakeSlice(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| // Inputs: three scalars whose value is an int32 number. | |||
| CheckArgsSize(primitive->name(), args_spec_list, 3); | |||
| size_t args_size = args_spec_list.size(); | |||
| for (size_t index = 0; index < args_size; index++) { | |||
| MS_EXCEPTION_IF_NULL(args_spec_list[index]); | |||
| if (!args_spec_list[index]->isa<AbstractScalar>() && !args_spec_list[index]->isa<AbstractNone>()) { | |||
| MS_EXCEPTION(TypeError) << "MakeSlice eval " << index << " parameter is neither AbstractScalar nor AbstractNone."; | |||
| void ComputeReduceIndex(const std::vector<int> &reverse_x, const std::vector<int> &reverse_y, | |||
| std::vector<int> *grad_x_reduce_idx, std::vector<int> *grad_y_reduce_idy) { | |||
| const size_t n = reverse_x.size(); | |||
| for (size_t i = 0; i < n; ++i) { | |||
| State curr; | |||
| const int32_t x_i = reverse_x[i]; | |||
| const int32_t y_i = reverse_y[i]; | |||
| const int reduce_idx = SizeToInt(n - 1 - i); | |||
| if (x_i == y_i) { | |||
| curr = SAME; | |||
| } else if (x_i == 1) { | |||
| grad_x_reduce_idx->push_back(reduce_idx); | |||
| curr = X_ONE; | |||
| } else if (y_i == 1) { | |||
| grad_y_reduce_idy->push_back(reduce_idx); | |||
| curr = Y_ONE; | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "not compatible shape input for BroadcastGradientArgs"; | |||
| } | |||
| if (args_spec_list[index]->isa<AbstractScalar>() && | |||
| !dyn_cast<AbstractScalar>(args_spec_list[index])->BuildValue()->isa<Int32Imm>()) { | |||
| MS_EXCEPTION(TypeError) << "MakeSlice eval " << index | |||
| << " parameter is an AbstractScalar, but is not an int32 number."; | |||
| if (curr == SAME && x_i == 1) { | |||
| grad_x_reduce_idx->push_back(reduce_idx); | |||
| grad_y_reduce_idy->push_back(reduce_idx); | |||
| continue; | |||
| } | |||
| } | |||
| // Slice: start, end, step | |||
| return std::make_shared<AbstractSlice>(args_spec_list[0], args_spec_list[1], args_spec_list[2]); | |||
| std::reverse(grad_x_reduce_idx->begin(), grad_x_reduce_idx->end()); | |||
| std::reverse(grad_y_reduce_idy->begin(), grad_y_reduce_idy->end()); | |||
| } | |||
| // Eval the return type of make_record | |||
| AbstractBasePtr InferImplMakeRecord(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| // Inputs: at lease two objects of a subclass of AbstractBase. | |||
| if (args_spec_list.size() < 2) { | |||
| MS_LOG(EXCEPTION) << "Typeof evaluator requires more than 1 parameter, while the input size is " | |||
| << args_spec_list.size() << "."; | |||
| } | |||
| AbstractBasePtr BroadcastGradientArgsDiff(const std::vector<ValuePtr> &x_shape, const std::vector<ValuePtr> &y_shape) { | |||
| std::vector<int> reverse_x; | |||
| std::vector<int> reverse_y; | |||
| // args_spec_list[0] maybe AbstractScalarPtr or AbstractTypePtr | |||
| MS_EXCEPTION_IF_NULL(args_spec_list[0]); | |||
| TypePtr type = args_spec_list[0]->GetTypeTrack(); | |||
| MS_EXCEPTION_IF_NULL(type); | |||
| if (type->type_id() != kMetaTypeTypeType) { | |||
| MS_LOG(EXCEPTION) << "Can not make type(" << type->ToString() << ")not TypeType"; | |||
| } | |||
| (void)std::transform(x_shape.rbegin(), x_shape.rend(), std::back_inserter(reverse_x), | |||
| [](const ValuePtr &v) { return v->cast<Int32ImmPtr>()->value(); }); | |||
| (void)std::transform(y_shape.rbegin(), y_shape.rend(), std::back_inserter(reverse_y), | |||
| [](const ValuePtr &v) { return v->cast<Int32ImmPtr>()->value(); }); | |||
| ValuePtr value_track = args_spec_list[0]->GetValueTrack(); | |||
| MS_EXCEPTION_IF_NULL(value_track); | |||
| TypePtr type_ptr = value_track->cast<TypePtr>(); | |||
| if (type_ptr == nullptr) { | |||
| MS_LOG(EXCEPTION) << "Value type error, not Me type:" << value_track->ToString(); | |||
| if (reverse_x.size() > reverse_y.size()) { | |||
| reverse_y.resize(reverse_x.size(), 1); | |||
| } else { | |||
| reverse_x.resize(reverse_y.size(), 1); | |||
| } | |||
| auto cls = dyn_cast<Class>(type_ptr); | |||
| MS_EXCEPTION_IF_NULL(cls); | |||
| ClassAttrVector attributes = cls->GetAttributes(); | |||
| CheckArgsSize(primitive->name(), args_spec_list, attributes.size() + 1); | |||
| std::vector<int> grad_x_reduce_idx; | |||
| std::vector<int> grad_y_reduce_idy; | |||
| ComputeReduceIndex(reverse_x, reverse_y, &grad_x_reduce_idx, &grad_y_reduce_idy); | |||
| std::vector<AbstractAttribute> abs_attributes; | |||
| for (size_t i = 0; i < attributes.size(); i++) { | |||
| AbstractAttribute elem(attributes[i].first, args_spec_list[i + 1]); | |||
| abs_attributes.push_back(elem); | |||
| } | |||
| AbstractBasePtrList abs_list_x; | |||
| AbstractBasePtrList abs_list_y; | |||
| (void)std::transform(grad_x_reduce_idx.begin(), grad_x_reduce_idx.end(), std::back_inserter(abs_list_x), | |||
| [](int v) { return abstract::FromValue(v); }); | |||
| (void)std::transform(grad_y_reduce_idy.begin(), grad_y_reduce_idy.end(), std::back_inserter(abs_list_y), | |||
| [](int v) { return abstract::FromValue(v); }); | |||
| auto x_reduce_idx = std::make_shared<AbstractTuple>(abs_list_x); | |||
| auto y_reduce_idx = std::make_shared<AbstractTuple>(abs_list_y); | |||
| AbstractBasePtrList elem_list; | |||
| elem_list.push_back(x_reduce_idx); | |||
| elem_list.push_back(y_reduce_idx); | |||
| return std::make_shared<AbstractClass>(cls->tag(), abs_attributes, cls->methods()); | |||
| return std::make_shared<AbstractTuple>(elem_list); | |||
| } | |||
| template <typename T> | |||
| AbstractBasePtr InferTupleOrListGetItem(const std::string &op_name, const AbstractBasePtrList &args_spec_list) { | |||
| // Inputs: a tuple or list and a scalar whose value is an int32 number. | |||
| CheckArgsSize(op_name, args_spec_list, 2); | |||
| auto queue = CheckArg<T>(op_name, args_spec_list, 0); | |||
| AbstractScalarPtr index = CheckArg<AbstractScalar>(op_name, args_spec_list, 1); | |||
| ValuePtr index_value = index->BuildValue(); | |||
| if (!index_value->isa<Int32Imm>()) { | |||
| // when index_value is an AnyValue and args_spec_list[0] is a scalar, try to return the type of the first element | |||
| // and continue | |||
| if (dyn_cast<AbstractScalar>(queue->elements()[0]) != nullptr) { | |||
| return std::make_shared<AbstractScalar>(queue->elements()[0]->BuildType()); | |||
| } | |||
| MS_EXCEPTION(IndexError) << op_name << " evaluator index should be an int32 number, but got " | |||
| << index_value->ToString(); | |||
| } | |||
| int idx_v = GetValue<int>(index_value); | |||
| std::size_t nelems = queue->elements().size(); | |||
| if (idx_v >= SizeToInt(nelems) || idx_v < -SizeToInt(nelems)) { | |||
| MS_EXCEPTION(IndexError) << op_name << " evaluator index should be in range[-" << SizeToInt(nelems) << ", " | |||
| << SizeToInt(nelems) << "), but got " << idx_v << "."; | |||
| } | |||
| std::size_t uidx_v = 0; | |||
| if (idx_v >= 0) { | |||
| uidx_v = IntToSize(idx_v); | |||
| } else { | |||
| uidx_v = IntToSize(idx_v + SizeToInt(nelems)); | |||
| AbstractBasePtr InferImplTypeof(const AnalysisEnginePtr &, const PrimitivePtr &, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| // Inputs: a pointer to an AbstractBase object | |||
| if (args_spec_list.size() != 1) { | |||
| MS_LOG(EXCEPTION) << "Typeof evaluator requires 1 parameter, while the input size is " << args_spec_list.size() | |||
| << "."; | |||
| } | |||
| return queue->elements()[uidx_v]; | |||
| AbstractBasePtr abs_base = args_spec_list[0]; | |||
| MS_EXCEPTION_IF_NULL(abs_base); | |||
| TypePtr type = abs_base->BuildType(); | |||
| return std::make_shared<AbstractType>(type); | |||
| } | |||
| template <typename T> | |||
| AbstractBasePtr InferTupleOrListSetItem(const std::string &op_name, const AbstractBasePtrList &args_spec_list) { | |||
| // Inputs: a tuple or list, a scalar whose value is an int32 number and an object of a subclass of AbstractBase. | |||
| CheckArgsSize(op_name, args_spec_list, 3); | |||
| auto queue = CheckArg<T>(op_name, args_spec_list, 0); | |||
| AbstractScalarPtr index = CheckArg<AbstractScalar>(op_name, args_spec_list, 1); | |||
| ValuePtr index_value = index->BuildValue(); | |||
| if (!index_value->isa<Int32Imm>()) { | |||
| MS_EXCEPTION(IndexError) << op_name << " evaluator index should be an int32 number, but got " | |||
| << index_value->ToString(); | |||
| } | |||
| int idx_v = GetValue<int>(index_value); | |||
| if (idx_v < 0) { | |||
| MS_EXCEPTION(IndexError) << "The index of " << typeid(T).name() << " should be positive number, but got " << idx_v | |||
| << "."; | |||
| } | |||
| AbstractBasePtr InferImplHasType(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| // Inputs: a pointer to an AbstractBase object and a pointer to a Type | |||
| const std::string op_name = primitive->name(); | |||
| CheckArgsSize(op_name, args_spec_list, 2); | |||
| AbstractTypePtr abs_type = CheckArg<AbstractType>(op_name, args_spec_list, 1); | |||
| size_t uidx_v = IntToSize(idx_v); | |||
| AbstractBasePtrList elements = queue->elements(); | |||
| std::size_t nelems = elements.size(); | |||
| if (uidx_v >= nelems) { | |||
| MS_EXCEPTION(IndexError) << op_name << " evaluator the index: " << uidx_v << " to set out of range: " << nelems - 1 | |||
| << "."; | |||
| auto mode_v = abs_type->GetValueTrack(); | |||
| MS_EXCEPTION_IF_NULL(mode_v); | |||
| if (!mode_v->isa<Type>()) { | |||
| MS_LOG(EXCEPTION) << "Get the type from AbstractType value failed."; | |||
| } | |||
| elements[uidx_v] = args_spec_list[2]; | |||
| return std::make_shared<T>(elements); | |||
| } | |||
| AbstractBasePtr InferImplTupleGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| return InferTupleOrListGetItem<AbstractTuple>(primitive->name(), args_spec_list); | |||
| TypePtr mode_t = mode_v->cast<TypePtr>(); | |||
| MS_EXCEPTION_IF_NULL(args_spec_list[0]); | |||
| bool v = IsSubtype(args_spec_list[0], mode_t); | |||
| return std::make_shared<AbstractScalar>(std::make_shared<BoolImm>(v), kBool); | |||
| } | |||
| AbstractBasePtr InferImplListGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| return InferTupleOrListGetItem<AbstractList>(primitive->name(), args_spec_list); | |||
| } | |||
| bool CompareShape(const std::vector<ValuePtr> &x_shape, const std::vector<ValuePtr> &y_shape) { | |||
| if (x_shape.size() != y_shape.size()) { | |||
| return false; | |||
| } | |||
| AbstractBasePtr InferImplTupleSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| return InferTupleOrListSetItem<AbstractTuple>(primitive->name(), args_spec_list); | |||
| } | |||
| for (size_t i = 0; i < x_shape.size(); ++i) { | |||
| if (GetValue<int>(x_shape[i]) != GetValue<int>(y_shape[i])) { | |||
| return false; | |||
| } | |||
| } | |||
| AbstractBasePtr InferImplListSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| return InferTupleOrListSetItem<AbstractList>(primitive->name(), args_spec_list); | |||
| return true; | |||
| } | |||
| AbstractBasePtr InferImplDictGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| // Inputs: a dict and a scalar whose value is a string. | |||
| const std::string op_name = primitive->name(); | |||
| CheckArgsSize(op_name, args_spec_list, 2); | |||
| AbstractDictionaryPtr dict = CheckArg<AbstractDictionary>(op_name, args_spec_list, 0); | |||
| AbstractScalarPtr key = CheckArg<AbstractScalar>(op_name, args_spec_list, 1); | |||
| ValuePtr key_value = key->BuildValue(); | |||
| if (!key_value->isa<StringImm>()) { | |||
| MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << key_value->ToString(); | |||
| AbstractBasePtr DoInferReduceShape(const AbstractTuplePtr &x_shape, const ValuePtr &x_shp_value, | |||
| const ValueTuplePtr &axis_value_ptr, const PrimitivePtr &primitive) { | |||
| size_t x_rank = x_shape->size(); | |||
| std::set<int> axis_set; | |||
| auto axis_data = axis_value_ptr->value(); | |||
| if (axis_data.empty()) { | |||
| int size = 1; | |||
| AbstractBasePtrList values(x_rank, std::make_shared<AbstractScalar>(size)); | |||
| return std::make_shared<AbstractTuple>(values); | |||
| } | |||
| auto key_str = GetValue<std::string>(key_value); | |||
| std::vector<AbstractAttribute> dict_elems = dict->elements(); | |||
| auto it = std::find_if(dict_elems.begin(), dict_elems.end(), | |||
| [key_str](const AbstractAttribute &item) { return item.first == key_str; }); | |||
| if (it == dict_elems.end()) { | |||
| MS_LOG(EXCEPTION) << "The key " << key_str << " does not exist in the dict:" << args_spec_list[0]->ToString(); | |||
| for (auto &elem : axis_data) { | |||
| int e_value = CheckAxis(primitive->name(), elem, -SizeToInt(x_rank), SizeToInt(x_rank) - 1); | |||
| (void)axis_set.insert(e_value); | |||
| } | |||
| return it->second; | |||
| } | |||
| AbstractBasePtr InferImplDictSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| // Inputs: a dict and a scalar whose value is a string and an object of a subclass of AbstractBase. | |||
| const std::string op_name = primitive->name(); | |||
| CheckArgsSize(op_name, args_spec_list, 3); | |||
| AbstractDictionaryPtr dict = CheckArg<AbstractDictionary>(op_name, args_spec_list, 0); | |||
| AbstractScalarPtr key = CheckArg<AbstractScalar>(op_name, args_spec_list, 1); | |||
| ValuePtr key_value = key->BuildValue(); | |||
| if (!key_value->isa<StringImm>()) { | |||
| MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << key_value->ToString(); | |||
| } | |||
| std::string key_str = GetValue<std::string>(key_value); | |||
| std::vector<AbstractAttribute> dict_elems = dict->elements(); | |||
| auto it = std::find_if(dict_elems.begin(), dict_elems.end(), | |||
| [key_str](AbstractAttribute &item) { return item.first == key_str; }); | |||
| MS_EXCEPTION_IF_NULL(args_spec_list[2]); | |||
| auto new_ele = std::make_pair(key_str, args_spec_list[2]); | |||
| if (it != dict_elems.end()) { | |||
| int index = it - dict_elems.begin(); | |||
| dict_elems[IntToSize(index)] = new_ele; | |||
| } else { | |||
| dict_elems.push_back(new_ele); | |||
| auto x_shp_data = x_shp_value->cast<ValueTuplePtr>()->value(); | |||
| if (x_shp_data.size() < x_rank) { | |||
| MS_LOG(EXCEPTION) << "x_shape_data.size() " << x_shp_data.size() << " less than x_shape.size() " << x_rank; | |||
| } | |||
| AbstractBasePtrList values; | |||
| for (size_t i = 0; i < x_rank; i++) { | |||
| if (axis_set.count(SizeToInt(i)) || axis_set.count(SizeToInt(i) - SizeToInt(x_rank))) { | |||
| auto axis_v = MakeValue(1); | |||
| values.push_back(std::make_shared<AbstractScalar>(axis_v, axis_v->type())); | |||
| } else { | |||
| int dim_value = x_shp_data[i]->cast<Int32ImmPtr>()->value(); | |||
| auto dim = MakeValue(dim_value); | |||
| values.push_back(std::make_shared<AbstractScalar>(dim, dim->type())); | |||
| } | |||
| } | |||
| return std::make_shared<AbstractDictionary>(dict_elems); | |||
| return std::make_shared<AbstractTuple>(values); | |||
| } | |||
| AbstractBasePtr InferImplListAppend(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| // Inputs: a list and an object of a subclass of AbstractBase. | |||
| AbstractBasePtr InferImplBroadcastGradientArgs(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| // this primitive get the index that need to reduce | |||
| // input: x's shape and y's shape, inputs should be tuple | |||
| // output: tuple of x and y 's reduce index, reduce index should be a tuple | |||
| const std::string op_name = primitive->name(); | |||
| CheckArgsSize(op_name, args_spec_list, 2); | |||
| AbstractListPtr list = CheckArg<AbstractList>(op_name, args_spec_list, 0); | |||
| (void)AbstractJoin(list->elements()); | |||
| return list; | |||
| } | |||
| auto arg_x = CheckArg<AbstractTuple>(op_name, args_spec_list, 0); | |||
| auto arg_y = CheckArg<AbstractTuple>(op_name, args_spec_list, 1); | |||
| template <typename T> | |||
| AbstractBasePtr InferTupleOrListOrDictLen(const std::string &op_name, const AbstractBasePtrList &args_spec_list) { | |||
| // Inputs: a tuple or list or dict. | |||
| CheckArgsSize(op_name, args_spec_list, 1); | |||
| auto arg = CheckArg<T>(op_name, args_spec_list, 0); | |||
| return std::make_shared<AbstractScalar>(SizeToInt(arg->size())); | |||
| } | |||
| ValueTuplePtr arg_x_value = arg_x->BuildValue()->cast<ValueTuplePtr>(); | |||
| MS_EXCEPTION_IF_NULL(arg_x_value); | |||
| AbstractBasePtr InferImplTupleLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| return InferTupleOrListOrDictLen<AbstractTuple>(primitive->name(), args_spec_list); | |||
| } | |||
| ValueTuplePtr arg_y_value = arg_y->BuildValue()->cast<ValueTuplePtr>(); | |||
| MS_EXCEPTION_IF_NULL(arg_y_value); | |||
| AbstractBasePtr InferImplListLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| return InferTupleOrListOrDictLen<AbstractList>(primitive->name(), args_spec_list); | |||
| } | |||
| const std::vector<ValuePtr> x_shape = arg_x_value->value(); | |||
| const std::vector<ValuePtr> y_shape = arg_y_value->value(); | |||
| bool is_same_shape = CompareShape(x_shape, y_shape); | |||
| // if it is the same shape , do not need reduce , return empty tuple | |||
| if (is_same_shape) { | |||
| AbstractBasePtrList empty_list; | |||
| auto x_reduce_idx = std::make_shared<AbstractTuple>(empty_list); | |||
| auto y_reduce_idx = std::make_shared<AbstractTuple>(empty_list); | |||
| AbstractBasePtr InferImplDictLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| return InferTupleOrListOrDictLen<AbstractDictionary>(primitive->name(), args_spec_list); | |||
| } | |||
| AbstractBasePtrList elem_list; | |||
| elem_list.push_back(x_reduce_idx); | |||
| elem_list.push_back(y_reduce_idx); | |||
| AbstractBasePtr InferImplArrayLen(const AnalysisEnginePtr &, const PrimitivePtr &, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| return std::make_shared<AbstractScalar>(kAnyValue, kInt32); | |||
| return std::make_shared<AbstractTuple>(elem_list); | |||
| } | |||
| return BroadcastGradientArgsDiff(x_shape, y_shape); | |||
| } | |||
| AbstractBasePtr InferImplListMap(const AnalysisEnginePtr &engine, const PrimitivePtr &primitive, | |||
| @@ -430,41 +334,6 @@ AbstractBasePtr InferImplTupleReversed(const AnalysisEnginePtr &, const Primitiv | |||
| return std::make_shared<AbstractTuple>(elem_list); | |||
| } | |||
| AbstractBasePtr DoInferReduceShape(const AbstractTuplePtr &x_shape, const ValuePtr &x_shp_value, | |||
| const ValueTuplePtr &axis_value_ptr, const PrimitivePtr &primitive) { | |||
| size_t x_rank = x_shape->size(); | |||
| std::set<int> axis_set; | |||
| auto axis_data = axis_value_ptr->value(); | |||
| if (axis_data.empty()) { | |||
| int size = 1; | |||
| AbstractBasePtrList values(x_rank, std::make_shared<AbstractScalar>(size)); | |||
| return std::make_shared<AbstractTuple>(values); | |||
| } | |||
| for (auto &elem : axis_data) { | |||
| int e_value = CheckAxis(primitive->name(), elem, -SizeToInt(x_rank), SizeToInt(x_rank) - 1); | |||
| (void)axis_set.insert(e_value); | |||
| } | |||
| auto x_shp_data = x_shp_value->cast<ValueTuplePtr>()->value(); | |||
| if (x_shp_data.size() < x_rank) { | |||
| MS_LOG(EXCEPTION) << "x_shape_data.size() " << x_shp_data.size() << " less than x_shape.size() " << x_rank; | |||
| } | |||
| AbstractBasePtrList values; | |||
| for (size_t i = 0; i < x_rank; i++) { | |||
| if (axis_set.count(SizeToInt(i)) || axis_set.count(SizeToInt(i) - SizeToInt(x_rank))) { | |||
| auto axis_v = MakeValue(1); | |||
| values.push_back(std::make_shared<AbstractScalar>(axis_v, axis_v->type())); | |||
| } else { | |||
| int dim_value = x_shp_data[i]->cast<Int32ImmPtr>()->value(); | |||
| auto dim = MakeValue(dim_value); | |||
| values.push_back(std::make_shared<AbstractScalar>(dim, dim->type())); | |||
| } | |||
| } | |||
| return std::make_shared<AbstractTuple>(values); | |||
| } | |||
| AbstractBasePtr InferImplReduceShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| // Inputs: x_shape, axis | |||
| @@ -563,7 +432,7 @@ AbstractBasePtr InferImplTuple2Array(const AnalysisEnginePtr &, const PrimitiveP | |||
| py::tuple data_tuple = ValuePtrToPyData(input->BuildValue()); | |||
| py::array data = py::array(data_tuple); | |||
| auto tensor = TensorPy::MakeTensor(data); | |||
| auto tensor = tensor::TensorPy::MakeTensor(data); | |||
| auto ret = tensor->ToAbstract(); | |||
| ret->set_value(tensor); | |||
| MS_LOG(DEBUG) << "Tuple2arry result AbstractTensor: " << ret->ToString(); | |||
| @@ -596,76 +465,6 @@ AbstractBasePtr InferImplShapeMul(const AnalysisEnginePtr &, const PrimitivePtr | |||
| return std::make_shared<AbstractScalar>(result_v, result_v->type()); | |||
| } | |||
| template <typename T> | |||
| AbstractBasePtr InferImplTupleOrListEqual(const std::string &op_name, const AbstractBasePtrList &args_spec_list) { | |||
| // Inputs: two tuples or two lists. | |||
| CheckArgsSize(op_name, args_spec_list, 2); | |||
| auto input_x = CheckArg<T>(op_name, args_spec_list, 0); | |||
| auto input_y = CheckArg<T>(op_name, args_spec_list, 1); | |||
| ValuePtr x_value = input_x->BuildValue(); | |||
| ValuePtr y_value = input_y->BuildValue(); | |||
| return std::make_shared<AbstractScalar>(*x_value == *y_value); | |||
| } | |||
| AbstractBasePtr InferImplTupleEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| return InferImplTupleOrListEqual<AbstractTuple>(primitive->name(), args_spec_list); | |||
| } | |||
| AbstractBasePtr InferImplListEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| return InferImplTupleOrListEqual<AbstractList>(primitive->name(), args_spec_list); | |||
| } | |||
| struct SlideInfo { | |||
| int start; | |||
| int step; | |||
| int stop; | |||
| }; | |||
| void CalcSlidePara(const AbstractBasePtrList &args_spec_list, SlideInfo *slide) { | |||
| int arg1 = 0; | |||
| int arg2 = 0; | |||
| if (!args_spec_list.empty()) { | |||
| MS_EXCEPTION_IF_NULL(args_spec_list[0]); | |||
| auto arg_value = args_spec_list[0]->BuildValue(); | |||
| if (!arg_value->isa<Int32Imm>()) { | |||
| MS_LOG(EXCEPTION) << "Only supported input an int32 number."; | |||
| } | |||
| arg1 = GetValue<int>(arg_value); | |||
| } | |||
| if (args_spec_list.size() >= 2) { | |||
| MS_EXCEPTION_IF_NULL(args_spec_list[1]); | |||
| auto arg_value = args_spec_list[1]->BuildValue(); | |||
| if (!arg_value->isa<Int32Imm>()) { | |||
| MS_LOG(EXCEPTION) << "Only supported input an int32 number."; | |||
| } | |||
| arg2 = GetValue<int>(arg_value); | |||
| } | |||
| if (args_spec_list.size() == 3) { | |||
| MS_EXCEPTION_IF_NULL(args_spec_list[2]); | |||
| auto arg_value = args_spec_list[2]->BuildValue(); | |||
| if (!arg_value->isa<Int32Imm>()) { | |||
| MS_LOG(EXCEPTION) << "Only supported input an int32 number."; | |||
| } | |||
| slide->step = GetValue<int>(arg_value); | |||
| slide->start = arg1; | |||
| slide->stop = arg2; | |||
| } | |||
| if (args_spec_list.size() == 2) { | |||
| slide->start = arg1; | |||
| slide->stop = arg2; | |||
| } | |||
| if (args_spec_list.size() == 1) { | |||
| slide->stop = arg1; | |||
| } | |||
| } | |||
| AbstractBasePtr InferImplMakeRange(const AnalysisEnginePtr &, const PrimitivePtr &, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| if (args_spec_list.empty()) { | |||
| @@ -709,5 +508,145 @@ AbstractBasePtr InferImplStopGradient(const AnalysisEnginePtr &, const Primitive | |||
| CheckArgsSize(primitive->name(), args_spec_list, 1); | |||
| return args_spec_list[0]->Clone(); | |||
| } | |||
| AbstractBasePtr InferImplTupleEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| return InferImplTupleOrListEqual<AbstractTuple>(primitive->name(), args_spec_list); | |||
| } | |||
| AbstractBasePtr InferImplListEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| return InferImplTupleOrListEqual<AbstractList>(primitive->name(), args_spec_list); | |||
| } | |||
| AbstractBasePtr InferImplStringEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| // Inputs: two scalars whose value is a string. | |||
| const std::string op_name = primitive->name(); | |||
| CheckArgsSize(op_name, args_spec_list, 2); | |||
| AbstractScalarPtr scalar_x = CheckArg<AbstractScalar>(op_name, args_spec_list, 0); | |||
| AbstractScalarPtr scalar_y = CheckArg<AbstractScalar>(op_name, args_spec_list, 1); | |||
| ValuePtr value_x = scalar_x->BuildValue(); | |||
| ValuePtr value_y = scalar_y->BuildValue(); | |||
| if (!value_x->isa<StringImm>() || !value_y->isa<StringImm>()) { | |||
| MS_LOG(EXCEPTION) << op_name << " requires 2 parameters are string, but got param0: " << value_x->ToString() | |||
| << ", param1: " << value_y->ToString(); | |||
| } | |||
| bool ret = (value_x->cast<StringImmPtr>()->value() == value_y->cast<StringImmPtr>()->value()); | |||
| return std::make_shared<AbstractScalar>(ret); | |||
| } | |||
| AbstractBasePtr InferImplStringConcat(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| // Inputs: two scalars whose value is a string. | |||
| const std::string op_name = primitive->name(); | |||
| CheckArgsSize(op_name, args_spec_list, 2); | |||
| AbstractScalarPtr scalar_x = CheckArg<AbstractScalar>(op_name, args_spec_list, 0); | |||
| AbstractScalarPtr scalar_y = CheckArg<AbstractScalar>(op_name, args_spec_list, 1); | |||
| ValuePtr value_x = scalar_x->BuildValue(); | |||
| ValuePtr value_y = scalar_y->BuildValue(); | |||
| if (!value_x->isa<StringImm>() || !value_y->isa<StringImm>()) { | |||
| MS_LOG(EXCEPTION) << op_name << " requires 2 parameters are string, but got param0: " << value_x->ToString() | |||
| << ", param1: " << value_y->ToString(); | |||
| } | |||
| std::string ret = (value_x->cast<StringImmPtr>()->value() + value_y->cast<StringImmPtr>()->value()); | |||
| return std::make_shared<AbstractScalar>(ret); | |||
| } | |||
| AbstractBasePtr InferImplDictLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| return InferTupleOrListOrDictLen<AbstractDictionary>(primitive->name(), args_spec_list); | |||
| } | |||
| AbstractBasePtr InferImplJ(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| // args: An object of AbstractFunction. | |||
| CheckArgsSize(primitive->name(), args_spec_list, 1); | |||
| MS_LOG(DEBUG) << "evaluate J: " << args_spec_list[0]->ToString(); | |||
| AbstractFunctionPtr x = dyn_cast<AbstractFunction>(args_spec_list[0]); | |||
| if (x == nullptr) { | |||
| return std::make_shared<AbstractJTagged>(args_spec_list[0]); | |||
| } | |||
| AbstractFuncAtomPtrList jv; | |||
| auto build_jv = [&jv](const AbstractFuncAtomPtr &func) { | |||
| auto j_closure = std::make_shared<JTransformedAbstractClosure>(func); | |||
| jv.push_back(j_closure); | |||
| }; | |||
| x->Visit(build_jv); | |||
| return AbstractFunction::MakeAbstractFunction(jv); | |||
| } | |||
| AbstractBasePtr InferImplFakeBprop(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| // Inputs: a tensor. | |||
| CheckArgsSize(primitive->name(), args_spec_list, 1); | |||
| return args_spec_list[0]->Broaden(); | |||
| } | |||
| // Eval the return type of make_record | |||
| AbstractBasePtr InferImplMakeRecord(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| // Inputs: at lease two objects of a subclass of AbstractBase. | |||
| if (args_spec_list.size() < 2) { | |||
| MS_LOG(EXCEPTION) << "Typeof evaluator requires more than 1 parameter, while the input size is " | |||
| << args_spec_list.size() << "."; | |||
| } | |||
| // args_spec_list[0] maybe AbstractScalarPtr or AbstractTypePtr | |||
| MS_EXCEPTION_IF_NULL(args_spec_list[0]); | |||
| TypePtr type = args_spec_list[0]->GetTypeTrack(); | |||
| MS_EXCEPTION_IF_NULL(type); | |||
| if (type->type_id() != kMetaTypeTypeType) { | |||
| MS_LOG(EXCEPTION) << "Can not make type(" << type->ToString() << ")not TypeType"; | |||
| } | |||
| ValuePtr value_track = args_spec_list[0]->GetValueTrack(); | |||
| MS_EXCEPTION_IF_NULL(value_track); | |||
| TypePtr type_ptr = value_track->cast<TypePtr>(); | |||
| if (type_ptr == nullptr) { | |||
| MS_LOG(EXCEPTION) << "Value type error, not Me type:" << value_track->ToString(); | |||
| } | |||
| auto cls = dyn_cast<Class>(type_ptr); | |||
| MS_EXCEPTION_IF_NULL(cls); | |||
| ClassAttrVector attributes = cls->GetAttributes(); | |||
| CheckArgsSize(primitive->name(), args_spec_list, attributes.size() + 1); | |||
| std::vector<AbstractAttribute> abs_attributes; | |||
| for (size_t i = 0; i < attributes.size(); i++) { | |||
| AbstractAttribute elem(attributes[i].first, args_spec_list[i + 1]); | |||
| abs_attributes.push_back(elem); | |||
| } | |||
| return std::make_shared<AbstractClass>(cls->tag(), abs_attributes, cls->methods()); | |||
| } | |||
| REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(TypeOf, prim::kPrimTypeOf, InferImplTypeof); | |||
| REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(HasType, prim::kPrimHasType, InferImplHasType); | |||
| REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(MakeRecord, prim::kPrimMakeRecord, InferImplMakeRecord); | |||
| REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(ListMap, prim::kPrimListMap, InferImplListMap); | |||
| REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(ListReduce, prim::kPrimListReduce, InferImplListReduce); | |||
| REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(TupleReversed, prim::kPrimTupleReversed, InferImplTupleReversed); | |||
| REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(ReducedShape, prim::kPrimReducedShape, InferImplReduceShape); | |||
| REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(TupleDiv, prim::kPrimTupleDiv, InferImplTupleDiv); | |||
| REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(TupleToArray, prim::kPrimTupleToArray, InferImplTuple2Array); | |||
| REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(ShapeMul, prim::kPrimShapeMul, InferImplShapeMul); | |||
| REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(TupleEqual, prim::kPrimTupleEqual, InferImplTupleEqual); | |||
| REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(ListEqual, prim::kPrimListEqual, InferImplListEqual); | |||
| REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(MakeRange, prim::kPrimMakeRange, InferImplMakeRange); | |||
| REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(StopGradient, prim::kPrimStopGradient, InferImplStopGradient); | |||
| REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(StringEqual, prim::kPrimStringEqual, InferImplStringEqual); | |||
| REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(StringConcat, prim::kPrimStringConcat, InferImplStringConcat); | |||
| REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(DictLen, prim::kPrimDictLen, InferImplDictLen); | |||
| REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(FakeBprop, prim::kPrimFakeBprop, InferImplFakeBprop); | |||
| REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(J, prim::kPrimJ, InferImplJ); | |||
| REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(BroadcastGradientArgs, prim::kPrimBroadcastGradientArgs, | |||
| InferImplBroadcastGradientArgs); | |||
| } // namespace abstract | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,77 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_FRONTEND_OPERATE_OPS_FRONT_INFER_FUNCTION_H_ | |||
| #define MINDSPORE_CCSRC_FRONTEND_OPERATE_OPS_FRONT_INFER_FUNCTION_H_ | |||
| #include "abstract/abstract_value.h" | |||
| #include "abstract/primitive_infer_map.h" | |||
| namespace mindspore { | |||
| namespace abstract { | |||
| AbstractBasePtr InferImplTypeof(const AnalysisEnginePtr &, const PrimitivePtr &, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplHasType(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplBroadcastGradientArgs(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplListMap(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplListReduce(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplTupleReversed(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplReduceShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplTupleDiv(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplTuple2Array(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplShapeMul(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplTupleEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplListEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplMakeRange(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplStopGradient(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplStringEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplStringConcat(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplDictLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplJ(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplFakeBprop(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplMakeRecord(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| class RegisterFrontendPrimitiveEvalHelper { | |||
| public: | |||
| RegisterFrontendPrimitiveEvalHelper(const PrimitivePtr &primitive, const StandardPrimitiveEvalImpl &impl) { | |||
| const StandardPrimitiveImplReg impl_reg{impl, false}; | |||
| RegisterStandardPrimitiveImpl(primitive, impl_reg); | |||
| } | |||
| ~RegisterFrontendPrimitiveEvalHelper() = default; | |||
| }; | |||
| #define REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(name, primitive, impl) \ | |||
| static auto helper_##name = RegisterFrontendPrimitiveEvalHelper(primitive, impl) | |||
| } // namespace abstract | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_FRONTEND_OPERATE_OPS_FRONT_INFER_FUNCTION_H_ | |||
| @@ -36,115 +36,12 @@ | |||
| #include "utils/convert_utils.h" | |||
| #include "utils/ms_context.h" | |||
| #include "pipeline/jit/parse/data_converter.h" | |||
| #include "abstract/primitive_infer_map.h" | |||
| #include "abstract/param_validator.h" | |||
| #include "utils/ms_utils.h" | |||
| namespace mindspore { | |||
| namespace abstract { | |||
| PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { | |||
| static PrimitiveEvalImplMap prim_eval_implement_map = { | |||
| // Statements | |||
| {prim::kPrimReturn, {InferImplReturn, true}}, | |||
| {prim::kPrimTypeOf, {InferImplTypeof, false}}, | |||
| {prim::kPrimHasType, {InferImplHasType, false}}, | |||
| {prim::kPrimDot, {InferImplDot, true}}, | |||
| {prim::kPrimSwitch, {InferImplSwitch, true}}, | |||
| {prim::kPrimSwitchLayer, {InferImplSwitchLayer, true}}, | |||
| {prim::kPrimIs_, {InferImplIs_, true}}, | |||
| {prim::kPrimIsNot, {InferImplIsNot, true}}, | |||
| {prim::kPrimInDict, {InferImplInDict, true}}, | |||
| {prim::kPrimNotInDict, {InferImplNotInDict, true}}, | |||
| {prim::kPrimIsConsant, {InferImplIsConstant, true}}, | |||
| // Maths | |||
| {prim::kPrimMaximumGrad, {InferImplMinOrMaxGrad, true}}, | |||
| {prim::kPrimMinimumGrad, {InferImplMinOrMaxGrad, true}}, | |||
| // Array | |||
| {prim::kPrimScalarToArray, {InferImplScalarToArray, true}}, | |||
| {prim::kPrimArrayToScalar, {InferImplArrayToScalar, true}}, | |||
| {prim::kPrimBroadcastShape, {InferImplBroadCastShape, true}}, | |||
| {prim::kPrimPack, {InferImplPack, true}}, | |||
| {prim::kPrimUnique, {InferImplUnique, true}}, | |||
| {prim::kPrimUniqueGrad, {InferImplUniqueGrad, true}}, | |||
| // Structure | |||
| {prim::kPrimMakeTuple, {InferImplMakeTuple, true}}, | |||
| {prim::kPrimMakeList, {InferImplMakeList, true}}, | |||
| {prim::kPrimMakeDict, {InferImplMakeDict, true}}, | |||
| {prim::kPrimMakeSlice, {InferImplMakeSlice, true}}, | |||
| {prim::kPrimMakeKeywordArg, {InferImplMakeKwarg, true}}, | |||
| {prim::kPrimExtractKeywordArg, {InferImplExtractKwarg, true}}, | |||
| {prim::kPrimMakeRecord, {InferImplMakeRecord, false}}, | |||
| {prim::kPrimTupleGetItem, {InferImplTupleGetItem, true}}, | |||
| {prim::kPrimListGetItem, {InferImplListGetItem, true}}, | |||
| {prim::kPrimTupleSetItem, {InferImplTupleSetItem, true}}, | |||
| {prim::kPrimListSetItem, {InferImplListSetItem, true}}, | |||
| {prim::kPrimDictGetItem, {InferImplDictGetItem, true}}, | |||
| {prim::kPrimDictSetItem, {InferImplDictSetItem, true}}, | |||
| {prim::kPrimListAppend, {InferImplListAppend, true}}, | |||
| {prim::kPrimTupleLen, {InferImplTupleLen, true}}, | |||
| {prim::kPrimListLen, {InferImplListLen, true}}, | |||
| {prim::kPrimArrayLen, {InferImplArrayLen, true}}, | |||
| {prim::kPrimListMap, {InferImplListMap, false}}, | |||
| {prim::kPrimListReduce, {InferImplListReduce, false}}, | |||
| {prim::kPrimTupleReversed, {InferImplTupleReversed, false}}, | |||
| {prim::kPrimReducedShape, {InferImplReduceShape, false}}, | |||
| {prim::kPrimTupleDiv, {InferImplTupleDiv, false}}, | |||
| {prim::kPrimTupleToArray, {InferImplTuple2Array, false}}, | |||
| {prim::kPrimShapeMul, {InferImplShapeMul, false}}, | |||
| {prim::kPrimTupleEqual, {InferImplTupleEqual, false}}, | |||
| {prim::kPrimListEqual, {InferImplListEqual, false}}, | |||
| {prim::kPrimMakeRange, {InferImplMakeRange, false}}, | |||
| {prim::kPrimStopGradient, {InferImplStopGradient, false}}, | |||
| {prim::kPrimStringEqual, {InferImplStringEqual, false}}, | |||
| {prim::kPrimStringConcat, {InferImplStringConcat, false}}, | |||
| {prim::kPrimDictLen, {InferImplDictLen, false}}, | |||
| // NN | |||
| {prim::kPrimPooling, {InferImplPooling, true}}, | |||
| {prim::kPrimPoolingGrad, {InferImplPoolingGrad, true}}, | |||
| {prim::kPrimFusedBatchNorm, {InferImplFusedBatchNorm, true}}, | |||
| {prim::kPrimFusedBatchNormGrad, {InferImplFusedBatchNormGrad, true}}, | |||
| {prim::kPrimReluGrad, {InferImplReluGrad, true}}, | |||
| {prim::kPrimConv2DBackpropInput, {InferImplConv2DBackpropInput, true}}, | |||
| {prim::kPrimConv2DBackpropFilter, {InferImplConv2DBackpropFilter, true}}, | |||
| {prim::kPrimBiasAddGrad, {InferImplBiasAddGrad, true}}, | |||
| {prim::kPrimRelu, {InferImplRelu, true}}, | |||
| {prim::kPrimFakeBprop, {InferImplFakeBprop, false}}, | |||
| {prim::kPrimZerosLike, {InferImplZerosLike, true}}, | |||
| {prim::kPrimBpropCut, {InferImplBpropCut, true}}, | |||
| {prim::kPrimLayerNorm, {InferImplLayerNorm, true}}, | |||
| {prim::kPrimLayerNormGrad, {InferImplLayerNormGrad, true}}, | |||
| {prim::kPrimDropoutGenMask, {InferImplDropoutGenMask, true}}, | |||
| // Others | |||
| {prim::kPrimIdentity, {InferImplIdentity, true}}, | |||
| // Set impl to null as it will use PartialEvaluator; | |||
| {prim::kPrimPartial, {nullptr, true}}, | |||
| {prim::kPrimJ, {InferImplJ, false}}, | |||
| {prim::kPrimEnvGetItem, {InferImplEnvGetItem, true}}, | |||
| {prim::kPrimEnvSetItem, {InferImplEnvSetItem, true}}, | |||
| {prim::kPrimEnvAdd, {InferImplEnvAdd, true}}, | |||
| {prim::kPrimMakeRefKey, {InferImplMakeRefKey, true}}, | |||
| {prim::kPrimMakeRef, {InferImplMakeRef, true}}, | |||
| {prim::kPrimGetRefKey, {InferImplGetRefKey, true}}, | |||
| {prim::kPrimGetRefValue, {InferImplGetRefValue, true}}, | |||
| {prim::kPrimStateSetItem, {InferImplStateSetItem, true}}, | |||
| {prim::kPrimDepend, {InferImplDepend, true}}, | |||
| {prim::kPrimBroadcastGradientArgs, {InferImplBroadcastGradientArgs, false}}, | |||
| {prim::kPrimControlDepend, {InferImplControlDepend, true}}, | |||
| // Debug | |||
| {prim::kPrimDebug, {InferImplDebug, true}}, | |||
| // RowTensor | |||
| {prim::kPrimMakeRowTensor, {InferImplMakeRowTensor, true}}, | |||
| {prim::kPrimRowTensorGetValues, {InferImplRowTensorGetValues, true}}, | |||
| {prim::kPrimRowTensorGetIndices, {InferImplRowTensorGetIndices, true}}, | |||
| {prim::kPrimRowTensorGetDenseShape, {InferImplRowTensorGetDenseShape, true}}, | |||
| // SparseTensor | |||
| {prim::kPrimMakeSparseTensor, {InferImplMakeSparseTensor, true}}, | |||
| {prim::kPrimSparseTensorGetValues, {InferImplSparseTensorGetValues, true}}, | |||
| {prim::kPrimSparseTensorGetIndices, {InferImplSparseTensorGetIndices, true}}, | |||
| {prim::kPrimSparseTensorGetDenseShape, {InferImplSparseTensorGetDenseShape, true}}, | |||
| }; | |||
| return prim_eval_implement_map; | |||
| } | |||
| using mindspore::parse::PyObjectWrapper; | |||
| std::unordered_set<std::string> prims_to_skip_undetermined_infer{"make_tuple", "make_list", "switch", "env_setitem", | |||
| @@ -26,19 +26,10 @@ | |||
| #include <vector> | |||
| #include "pipeline/jit/static_analysis/evaluator.h" | |||
| #include "abstract/primitive_infer_map.h" | |||
| namespace mindspore { | |||
| namespace abstract { | |||
| using StandardPrimitiveEvalImpl = AbstractBasePtr (*)(const AnalysisEnginePtr &, const PrimitivePtr &, | |||
| const AbstractBasePtrList &); | |||
| struct StandartPrimitiveImplReg { | |||
| StandardPrimitiveEvalImpl impl_; // Implement function of Primitive. | |||
| bool in_white_list_; // true if this Primitive in white list, else false. | |||
| }; | |||
| using PrimitiveEvalImplMap = | |||
| std::unordered_map<PrimitivePtr, StandartPrimitiveImplReg, PrimitiveHasher, PrimitiveEqual>; | |||
| class StandardPrimEvaluator : public TrivialPrimEvaluator { | |||
| public: | |||
| StandardPrimEvaluator(const PrimitivePtr primitive, StandardPrimitiveEvalImpl eval_impl) | |||
| @@ -179,191 +170,6 @@ bool IsSubtype(const AbstractBasePtr x, const TypePtr model); | |||
| void ClearPrimEvaluatorMap(); | |||
| py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base); | |||
| AbstractBasePtr InferImplReturn(const AnalysisEnginePtr &, const PrimitivePtr &, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplTypeof(const AnalysisEnginePtr &, const PrimitivePtr &, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplHasType(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplDot(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplSwitch(const AnalysisEnginePtr &, const PrimitivePtr &, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplSwitchLayer(const AnalysisEnginePtr &, const PrimitivePtr &, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplIs_(const AnalysisEnginePtr &, const PrimitivePtr &, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplIsNot(const AnalysisEnginePtr &, const PrimitivePtr &, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplInDict(const AnalysisEnginePtr &, const PrimitivePtr &, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplNotInDict(const AnalysisEnginePtr &, const PrimitivePtr &, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplIsConstant(const AnalysisEnginePtr &, const PrimitivePtr &, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplPooling(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplPoolingGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplFusedBatchNorm(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplFusedBatchNormGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplReluGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplConv2DBackpropInput(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplConv2DBackpropFilter(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplBiasAddGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplRelu(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplFakeBprop(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplZerosLike(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplBpropCut(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplLayerNorm(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplLayerNormGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplDropoutGenMask(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplMinOrMaxGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplScalarToArray(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplArrayToScalar(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplBroadCastShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplPack(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplUnique(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplUniqueGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplMakeTuple(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplMakeList(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplMakeDict(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplMakeSlice(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplMakeKwarg(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplExtractKwarg(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplMakeRecord(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplTupleGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplListGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplTupleSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplListSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplDictGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplDictSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplListAppend(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplTupleLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplListLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplArrayLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplListMap(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplListReduce(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplTupleReversed(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplReduceShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplTupleDiv(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplTuple2Array(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplShapeMul(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplGenShapeIndex(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplGenInverseIndex(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplTupleEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplListEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplMakeRange(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplStopGradient(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplStringEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplStringConcat(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplDictLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplIdentity(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplJ(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplEnvGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplEnvSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplEnvAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplMakeRefKey(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplMakeRef(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplGetRefKey(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplGetRefValue(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplGetRefOrigin(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplStateSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplDepend(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplBroadcastGradientArgs(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplControlDepend(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplDebug(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplMakeRowTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplRowTensorGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplRowTensorGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplRowTensorGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplMakeSparseTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplSparseTensorGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplSparseTensorGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplSparseTensorGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| } // namespace abstract | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,187 @@ | |||
| /** | |||
| * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). | |||
| * | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CORE_ABSTRACT_INFER_FUNCTIONS_H_ | |||
| #define MINDSPORE_CORE_ABSTRACT_INFER_FUNCTIONS_H_ | |||
| #include <string> | |||
| #include <memory> | |||
| #include "abstract/abstract_value.h" | |||
| #include "abstract/param_validator.h" | |||
| #include "base/core_ops.h" | |||
| namespace mindspore { | |||
| namespace abstract { | |||
| AbstractBasePtr InferImplReturn(const AnalysisEnginePtr &, const PrimitivePtr &, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplDot(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplSwitch(const AnalysisEnginePtr &, const PrimitivePtr &, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplSwitchLayer(const AnalysisEnginePtr &, const PrimitivePtr &, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplIs_(const AnalysisEnginePtr &, const PrimitivePtr &, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplIsNot(const AnalysisEnginePtr &, const PrimitivePtr &, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplInDict(const AnalysisEnginePtr &, const PrimitivePtr &, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplNotInDict(const AnalysisEnginePtr &, const PrimitivePtr &, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplIsConstant(const AnalysisEnginePtr &, const PrimitivePtr &, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplPooling(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplPoolingGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplFusedBatchNorm(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplFusedBatchNormGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplReluGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplConv2DBackpropInput(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplConv2DBackpropFilter(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplBiasAddGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplGelu(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplGeluGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplRelu(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplZerosLike(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplBpropCut(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplLayerNorm(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplLayerNormGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplDropoutGenMask(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplMinOrMaxGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplScalarToArray(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplArrayToScalar(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplBroadCastShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplPack(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplMakeTuple(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplMakeList(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplMakeDict(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplMakeSlice(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplMakeKwarg(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplExtractKwarg(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplMakeRecord(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplTupleGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplListGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplTupleSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplListSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplDictGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplDictSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplListAppend(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplTupleLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplListLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplArrayLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplGenShapeIndex(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplGenInverseIndex(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplIdentity(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplEnvGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplEnvSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplEnvAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplMakeRefKey(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplMakeRef(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplGetRefKey(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplGetRefValue(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplGetRefOrigin(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplStateSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplDepend(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplControlDepend(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplDebug(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplMakeSparseTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplSparseTensorGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplSparseTensorGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplSparseTensorGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplMakeRowTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplRowTensorGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplRowTensorGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplRowTensorGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplUniqueGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplUnique(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| template <typename T> | |||
| AbstractBasePtr InferTupleOrListOrDictLen(const std::string &op_name, const AbstractBasePtrList &args_spec_list) { | |||
| // Inputs: a tuple or list or dict. | |||
| CheckArgsSize(op_name, args_spec_list, 1); | |||
| auto arg = CheckArg<T>(op_name, args_spec_list, 0); | |||
| return std::make_shared<AbstractScalar>(SizeToInt(arg->size())); | |||
| } | |||
| } // namespace abstract | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CORE_ABSTRACT_INFER_FUNCTIONS_H_ | |||
| @@ -14,13 +14,48 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "pipeline/jit/static_analysis/prim.h" | |||
| #include "abstract/infer_functions.h" | |||
| #include "abstract/utils.h" | |||
| #include "frontend/operator/cc_implementations.h" | |||
| #include "abstract/param_validator.h" | |||
| namespace mindspore { | |||
| namespace abstract { | |||
| namespace { | |||
| std::vector<int> BroadcastShape(std::vector<int> shpx, std::vector<int> shpy) { | |||
| int dlen = SizeToInt(shpx.size()) - SizeToInt(shpy.size()); | |||
| if (dlen < 0) { | |||
| for (int i = 0; i < -dlen; ++i) { | |||
| (void)shpx.insert(shpx.begin(), 1); | |||
| } | |||
| } else if (dlen > 0) { | |||
| for (int i = 0; i < dlen; i++) { | |||
| (void)shpy.insert(shpy.begin(), 1); | |||
| } | |||
| } | |||
| if (shpx.size() != shpy.size()) { | |||
| MS_LOG(EXCEPTION) << "Failure: shpx.size() != shpy.size()."; | |||
| } | |||
| std::vector<int> shp; | |||
| for (size_t i = 0; i < shpx.size(); i++) { | |||
| auto a = shpx[i]; | |||
| auto b = shpy[i]; | |||
| if (a == 1) { | |||
| shp.push_back(b); | |||
| } else if (b == 1) { | |||
| shp.push_back(a); | |||
| } else if (a == -1) { | |||
| shp.push_back(b); | |||
| } else if (b == -1) { | |||
| shp.push_back(a); | |||
| } else if (a == b) { | |||
| shp.push_back(a); | |||
| } else { | |||
| return std::vector<int>(); | |||
| } | |||
| } | |||
| return shp; | |||
| } | |||
| } // namespace | |||
| AbstractBasePtr InferImplScalarToArray(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| // Inputs: a scalar. | |||
| @@ -65,7 +100,7 @@ AbstractBasePtr InferImplBroadCastShape(const AnalysisEnginePtr &, const Primiti | |||
| (void)std::transform(std::begin(shp_tuple_y), std::end(shp_tuple_y), std::back_inserter(shp_y), | |||
| [](const ValuePtr &e) -> int { return GetValue<int>(e); }); | |||
| std::vector<int> res = prim::BroadcastShape_(shp_x, shp_y); | |||
| std::vector<int> res = BroadcastShape(shp_x, shp_y); | |||
| if (res.empty()) { | |||
| MS_LOG(EXCEPTION) << "BroadcastShape fail: " << args_spec_list[0]->ToString() << "," | |||
| << args_spec_list[1]->ToString(); | |||
| @@ -15,8 +15,7 @@ | |||
| */ | |||
| #include "abstract/param_validator.h" | |||
| #include "pipeline/jit/static_analysis/prim.h" | |||
| #include "frontend/operator/ops.h" | |||
| #include "abstract/infer_functions.h" | |||
| #include "abstract/utils.h" | |||
| #include "utils/symbolic.h" | |||
| @@ -14,8 +14,7 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "pipeline/jit/static_analysis/prim.h" | |||
| #include "frontend/operator/ops.h" | |||
| #include "abstract/infer_functions.h" | |||
| #include "abstract/utils.h" | |||
| #include "abstract/param_validator.h" | |||
| #include "utils/ms_utils.h" | |||
| @@ -14,10 +14,12 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "pipeline/jit/static_analysis/prim.h" | |||
| #include "frontend/operator/ops.h" | |||
| #include "abstract/infer_functions.h" | |||
| #include "abstract/utils.h" | |||
| #include "abstract/param_validator.h" | |||
| #include "utils/check_convert_utils.h" | |||
| #include "c_ops/conv2d.h" | |||
| #include "abstract/primitive_infer_map.h" | |||
| namespace mindspore { | |||
| namespace abstract { | |||
| @@ -278,13 +280,6 @@ AbstractBasePtr InferImplZerosLike(const AnalysisEnginePtr &, const PrimitivePtr | |||
| return args_spec_list[0]->Broaden(); | |||
| } | |||
| AbstractBasePtr InferImplFakeBprop(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| // Inputs: a tensor. | |||
| CheckArgsSize(primitive->name(), args_spec_list, 1); | |||
| return args_spec_list[0]->Broaden(); | |||
| } | |||
| AbstractBasePtr InferImplBpropCut(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| // Inputs: a tensor. | |||
| @@ -433,5 +428,91 @@ AbstractBasePtr InferImplDropoutGenMask(const AnalysisEnginePtr &, const Primiti | |||
| return std::make_shared<AbstractTensor>(std::make_shared<AbstractScalar>(kAnyValue, kUInt8), | |||
| std::make_shared<Shape>(std::vector<int64_t>{shape_y})); | |||
| } | |||
| abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| auto conv_prim = primitive->cast<PrimConv2dPtr>(); | |||
| MS_EXCEPTION_IF_NULL(conv_prim); | |||
| auto prim_name = conv_prim->name(); | |||
| CheckAndConvertUtils::CheckInRange("Conv2d Infer", input_args.size(), kIncludeLeft, {2, 3}, prim_name); | |||
| auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShape("w_shape", input_args[0]->GetShapeTrack(), prim_name); | |||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[1]->GetShapeTrack(), prim_name); | |||
| CheckAndConvertUtils::CheckInteger("weight rank", w_shape.size(), kEqual, 4, prim_name); | |||
| CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kEqual, 4, prim_name); | |||
| CheckAndConvertUtils::Check("x_shape[1] / group", x_shape[1] / conv_prim->GetGroup(), kEqual, "w_shape[1]", | |||
| w_shape[1], conv_prim->name()); | |||
| auto out_channel = conv_prim->GetOutputChannel(); | |||
| CheckAndConvertUtils::Check("out_channel", out_channel, kEqual, "w_shape[0]", w_shape[0], conv_prim->name()); | |||
| std::vector<int> temp_w; | |||
| std::copy(w_shape.begin() + 2, w_shape.end(), std::back_inserter(temp_w)); | |||
| CheckAndConvertUtils::Check("kernel_size", conv_prim->GetKernelSize(), kEqual, "w_shape[2:4]", temp_w, | |||
| conv_prim->name()); | |||
| auto kernel_size_h = w_shape[2]; | |||
| auto kernel_size_w = w_shape[3]; | |||
| auto stride = conv_prim->GetStride(); | |||
| auto dilation = conv_prim->GetDilation(); | |||
| auto stride_h = stride[2]; | |||
| auto stride_w = stride[3]; | |||
| auto dilation_h = dilation[2]; | |||
| auto dilation_w = dilation[3]; | |||
| int h_out = -1; | |||
| int w_out = -1; | |||
| std::vector<int> pad_list(4, 0); | |||
| auto pad_mode = conv_prim->GetPadMode(); | |||
| if (pad_mode == "valid") { | |||
| h_out = ceil((x_shape[2] - dilation_h * (kernel_size_h - 1)) / stride_h); | |||
| w_out = ceil((x_shape[3] - dilation_w * (kernel_size_w - 1)) / stride_w); | |||
| } else if (pad_mode == "same") { | |||
| h_out = ceil(x_shape[2] / stride_h); | |||
| w_out = ceil(x_shape[3] / stride_w); | |||
| auto pad_needed_h = std::max(0, (h_out - 1) * stride_h + dilation_h * (kernel_size_h - 1) + 1 - x_shape[2]); | |||
| pad_list.emplace_back(floor(pad_needed_h / 2)); | |||
| pad_list.emplace_back(pad_needed_h / 2); | |||
| auto pad_needed_w = std::max(0, (w_out - 1) * stride_w + dilation_w * (kernel_size_w - 1) + 1 - x_shape[3]); | |||
| auto pad_left = floor(pad_needed_w / 2); | |||
| pad_list.emplace_back(pad_left); | |||
| pad_list.emplace_back(pad_needed_h - pad_left); | |||
| } else if (pad_mode == "pad") { | |||
| std::copy(conv_prim->GetPad().begin(), conv_prim->GetPad().end(), std::back_inserter(pad_list)); | |||
| auto pad_top = conv_prim->GetPad()[0]; | |||
| auto pad_bottom = conv_prim->GetPad()[1]; | |||
| auto pad_right = conv_prim->GetPad()[2]; | |||
| auto pad_left = conv_prim->GetPad()[3]; | |||
| h_out = 1 + (x_shape[2] + pad_top + pad_bottom - kernel_size_h - (kernel_size_h - 1) * (dilation_h - 1)) / stride_h; | |||
| w_out = 1 + (x_shape[3] + pad_left + pad_right - kernel_size_w - (kernel_size_w - 1) * (dilation_w - 1)) / stride_w; | |||
| h_out = floor(h_out); | |||
| w_out = floor(w_out); | |||
| } | |||
| conv_prim->SetPadList(pad_list); | |||
| std::vector<int> out_shape = {x_shape[0], out_channel, h_out, w_out}; | |||
| return std::make_shared<abstract::Shape>(out_shape); | |||
| } | |||
| TypePtr Conv2dInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { | |||
| CheckAndConvertUtils::CheckInRange("", input_args.size(), kIncludeLeft, {2, 3}, prim->name()); | |||
| for (const auto &item : input_args) { | |||
| MS_EXCEPTION_IF_NULL(item); | |||
| } | |||
| auto x_type = CheckAndConvertUtils::ConvertTypePtrToTypeId("x_dtype", input_args[0]->GetTypeTrack(), prim->name()); | |||
| const std::set<TypeId> valid_types = {kNumberTypeInt8, kNumberTypeInt32, kNumberTypeFloat16, kNumberTypeFloat32}; | |||
| std::map<std::string, TypePtr> types; | |||
| types.emplace("x", input_args[0]->GetTypeTrack()); | |||
| types.emplace("w", input_args[1]->GetTypeTrack()); | |||
| CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name()); | |||
| if (x_type == kNumberTypeInt8) { | |||
| return std::make_shared<TensorType>(TypeIdToType(kNumberTypeInt32)); | |||
| } | |||
| return std::make_shared<TensorType>(TypeIdToType(x_type)); | |||
| } | |||
| AbstractBasePtr Conv2dInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args) { | |||
| return std::make_shared<abstract::AbstractTensor>(Conv2dInferType(primitive, input_args), | |||
| Conv2dInferShape(primitive, input_args)->shape()); | |||
| } | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(Conv2D, prim::kPrimConv2D, Conv2dInfer); | |||
| } // namespace abstract | |||
| } // namespace mindspore | |||
| @@ -19,9 +19,9 @@ | |||
| #include "ir/dtype.h" | |||
| #include "utils/ms_utils.h" | |||
| #include "frontend/operator/ops.h" | |||
| #include "base/core_ops.h" | |||
| #include "abstract/param_validator.h" | |||
| #include "pipeline/jit/static_analysis/prim.h" | |||
| #include "abstract/infer_functions.h" | |||
| #include "abstract/utils.h" | |||
| #include "utils/ms_context.h" | |||
| #include "utils/symbolic.h" | |||
| @@ -35,27 +35,6 @@ AbstractBasePtr InferImplIdentity(const AnalysisEnginePtr &, const PrimitivePtr | |||
| return args_spec_list[0]; | |||
| } | |||
| AbstractBasePtr InferImplJ(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| // args: An object of AbstractFunction. | |||
| CheckArgsSize(primitive->name(), args_spec_list, 1); | |||
| MS_LOG(DEBUG) << "evaluate J: " << args_spec_list[0]->ToString(); | |||
| AbstractFunctionPtr x = dyn_cast<AbstractFunction>(args_spec_list[0]); | |||
| if (x == nullptr) { | |||
| return std::make_shared<AbstractJTagged>(args_spec_list[0]); | |||
| } | |||
| AbstractFuncAtomPtrList jv; | |||
| auto build_jv = [&jv](const AbstractFuncAtomPtr &func) { | |||
| auto j_closure = std::make_shared<JTransformedAbstractClosure>(func); | |||
| jv.push_back(j_closure); | |||
| }; | |||
| x->Visit(build_jv); | |||
| return AbstractFunction::MakeAbstractFunction(jv); | |||
| } | |||
| AbstractBasePtr InferImplEnvGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| @@ -196,125 +175,6 @@ AbstractBasePtr InferImplDepend(const AnalysisEnginePtr &, const PrimitivePtr &p | |||
| return depends; | |||
| } | |||
| bool CompareShape(const std::vector<ValuePtr> &x_shape, const std::vector<ValuePtr> &y_shape) { | |||
| if (x_shape.size() != y_shape.size()) { | |||
| return false; | |||
| } | |||
| for (size_t i = 0; i < x_shape.size(); ++i) { | |||
| if (GetValue<int>(x_shape[i]) != GetValue<int>(y_shape[i])) { | |||
| return false; | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| enum State { | |||
| SAME, | |||
| X_ONE, | |||
| Y_ONE, | |||
| }; | |||
| void ComputeReduceIndex(const std::vector<int> &reverse_x, const std::vector<int> &reverse_y, | |||
| std::vector<int> *grad_x_reduce_idx, std::vector<int> *grad_y_reduce_idy) { | |||
| const size_t n = reverse_x.size(); | |||
| for (size_t i = 0; i < n; ++i) { | |||
| State curr; | |||
| const int32_t x_i = reverse_x[i]; | |||
| const int32_t y_i = reverse_y[i]; | |||
| const int reduce_idx = SizeToInt(n - 1 - i); | |||
| if (x_i == y_i) { | |||
| curr = SAME; | |||
| } else if (x_i == 1) { | |||
| grad_x_reduce_idx->push_back(reduce_idx); | |||
| curr = X_ONE; | |||
| } else if (y_i == 1) { | |||
| grad_y_reduce_idy->push_back(reduce_idx); | |||
| curr = Y_ONE; | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "not compatible shape input for BroadcastGradientArgs"; | |||
| } | |||
| if (curr == SAME && x_i == 1) { | |||
| grad_x_reduce_idx->push_back(reduce_idx); | |||
| grad_y_reduce_idy->push_back(reduce_idx); | |||
| continue; | |||
| } | |||
| } | |||
| std::reverse(grad_x_reduce_idx->begin(), grad_x_reduce_idx->end()); | |||
| std::reverse(grad_y_reduce_idy->begin(), grad_y_reduce_idy->end()); | |||
| } | |||
| AbstractBasePtr BroadcastGradientArgsDiff(const std::vector<ValuePtr> &x_shape, const std::vector<ValuePtr> &y_shape) { | |||
| std::vector<int> reverse_x; | |||
| std::vector<int> reverse_y; | |||
| (void)std::transform(x_shape.rbegin(), x_shape.rend(), std::back_inserter(reverse_x), | |||
| [](const ValuePtr &v) { return v->cast<Int32ImmPtr>()->value(); }); | |||
| (void)std::transform(y_shape.rbegin(), y_shape.rend(), std::back_inserter(reverse_y), | |||
| [](const ValuePtr &v) { return v->cast<Int32ImmPtr>()->value(); }); | |||
| if (reverse_x.size() > reverse_y.size()) { | |||
| reverse_y.resize(reverse_x.size(), 1); | |||
| } else { | |||
| reverse_x.resize(reverse_y.size(), 1); | |||
| } | |||
| std::vector<int> grad_x_reduce_idx; | |||
| std::vector<int> grad_y_reduce_idy; | |||
| ComputeReduceIndex(reverse_x, reverse_y, &grad_x_reduce_idx, &grad_y_reduce_idy); | |||
| AbstractBasePtrList abs_list_x; | |||
| AbstractBasePtrList abs_list_y; | |||
| (void)std::transform(grad_x_reduce_idx.begin(), grad_x_reduce_idx.end(), std::back_inserter(abs_list_x), | |||
| [](int v) { return abstract::FromValue(v); }); | |||
| (void)std::transform(grad_y_reduce_idy.begin(), grad_y_reduce_idy.end(), std::back_inserter(abs_list_y), | |||
| [](int v) { return abstract::FromValue(v); }); | |||
| auto x_reduce_idx = std::make_shared<AbstractTuple>(abs_list_x); | |||
| auto y_reduce_idx = std::make_shared<AbstractTuple>(abs_list_y); | |||
| AbstractBasePtrList elem_list; | |||
| elem_list.push_back(x_reduce_idx); | |||
| elem_list.push_back(y_reduce_idx); | |||
| return std::make_shared<AbstractTuple>(elem_list); | |||
| } | |||
| AbstractBasePtr InferImplBroadcastGradientArgs(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| // this primitive get the index that need to reduce | |||
| // input: x's shape and y's shape, inputs should be tuple | |||
| // output: tuple of x and y 's reduce index, reduce index should be a tuple | |||
| const std::string op_name = primitive->name(); | |||
| CheckArgsSize(op_name, args_spec_list, 2); | |||
| auto arg_x = CheckArg<AbstractTuple>(op_name, args_spec_list, 0); | |||
| auto arg_y = CheckArg<AbstractTuple>(op_name, args_spec_list, 1); | |||
| ValueTuplePtr arg_x_value = arg_x->BuildValue()->cast<ValueTuplePtr>(); | |||
| MS_EXCEPTION_IF_NULL(arg_x_value); | |||
| ValueTuplePtr arg_y_value = arg_y->BuildValue()->cast<ValueTuplePtr>(); | |||
| MS_EXCEPTION_IF_NULL(arg_y_value); | |||
| const std::vector<ValuePtr> x_shape = arg_x_value->value(); | |||
| const std::vector<ValuePtr> y_shape = arg_y_value->value(); | |||
| bool is_same_shape = CompareShape(x_shape, y_shape); | |||
| // if it is the same shape , do not need reduce , return empty tuple | |||
| if (is_same_shape) { | |||
| AbstractBasePtrList empty_list; | |||
| auto x_reduce_idx = std::make_shared<AbstractTuple>(empty_list); | |||
| auto y_reduce_idx = std::make_shared<AbstractTuple>(empty_list); | |||
| AbstractBasePtrList elem_list; | |||
| elem_list.push_back(x_reduce_idx); | |||
| elem_list.push_back(y_reduce_idx); | |||
| return std::make_shared<AbstractTuple>(elem_list); | |||
| } | |||
| return BroadcastGradientArgsDiff(x_shape, y_shape); | |||
| } | |||
| AbstractBasePtr InferImplControlDepend(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| // args: Two objects of a subclass of AbstractBase | |||
| @@ -15,8 +15,7 @@ | |||
| */ | |||
| #include "abstract/param_validator.h" | |||
| #include "pipeline/jit/static_analysis/prim.h" | |||
| #include "frontend/operator/ops.h" | |||
| #include "abstract/infer_functions.h" | |||
| #include "abstract/utils.h" | |||
| #include "utils/symbolic.h" | |||
| @@ -34,38 +33,6 @@ AbstractBasePtr InferImplReturn(const AnalysisEnginePtr &, const PrimitivePtr &, | |||
| return abs_base; | |||
| } | |||
| AbstractBasePtr InferImplTypeof(const AnalysisEnginePtr &, const PrimitivePtr &, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| // Inputs: a pointer to an AbstractBase object | |||
| if (args_spec_list.size() != 1) { | |||
| MS_LOG(EXCEPTION) << "Typeof evaluator requires 1 parameter, while the input size is " << args_spec_list.size() | |||
| << "."; | |||
| } | |||
| AbstractBasePtr abs_base = args_spec_list[0]; | |||
| MS_EXCEPTION_IF_NULL(abs_base); | |||
| TypePtr type = abs_base->BuildType(); | |||
| return std::make_shared<AbstractType>(type); | |||
| } | |||
| AbstractBasePtr InferImplHasType(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| // Inputs: a pointer to an AbstractBase object and a pointer to a Type | |||
| const std::string op_name = primitive->name(); | |||
| CheckArgsSize(op_name, args_spec_list, 2); | |||
| AbstractTypePtr abs_type = CheckArg<AbstractType>(op_name, args_spec_list, 1); | |||
| auto mode_v = abs_type->GetValueTrack(); | |||
| MS_EXCEPTION_IF_NULL(mode_v); | |||
| if (!mode_v->isa<Type>()) { | |||
| MS_LOG(EXCEPTION) << "Get the type from AbstractType value failed."; | |||
| } | |||
| TypePtr mode_t = mode_v->cast<TypePtr>(); | |||
| MS_EXCEPTION_IF_NULL(args_spec_list[0]); | |||
| bool v = IsSubtype(args_spec_list[0], mode_t); | |||
| return std::make_shared<AbstractScalar>(std::make_shared<BoolImm>(v), kBool); | |||
| } | |||
| AbstractBasePtr InferImplDot(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| // Inputs: two tensors. | |||
| @@ -0,0 +1,278 @@ | |||
| /** | |||
| * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). | |||
| * | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "abstract/infer_functions.h" | |||
| #include "abstract/utils.h" | |||
| #include "abstract/param_validator.h" | |||
| namespace mindspore { | |||
| namespace abstract { | |||
| AbstractBasePtr InferImplMakeTuple(const AnalysisEnginePtr &, const PrimitivePtr &, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| return std::make_shared<AbstractTuple>(args_spec_list); | |||
| } | |||
| AbstractBasePtr InferImplMakeList(const AnalysisEnginePtr &, const PrimitivePtr &, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| return std::make_shared<AbstractList>(args_spec_list); | |||
| } | |||
| AbstractBasePtr InferImplMakeDict(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| // Inputs: two tuples. | |||
| const std::string op_name = primitive->name(); | |||
| CheckArgsSize(op_name, args_spec_list, 2); | |||
| AbstractTuplePtr keys = CheckArg<AbstractTuple>(op_name, args_spec_list, 0); | |||
| AbstractTuplePtr values = CheckArg<AbstractTuple>(op_name, args_spec_list, 1); | |||
| size_t keys_size = keys->size(); | |||
| if (values->size() != keys_size) { | |||
| MS_LOG(EXCEPTION) << op_name << " evaluator keys' size is not equal with values' size"; | |||
| } | |||
| std::vector<AbstractAttribute> key_value; | |||
| AbstractScalarPtr key; | |||
| AbstractBasePtrList key_list = keys->elements(); | |||
| AbstractBasePtrList value_list = values->elements(); | |||
| for (size_t index = 0; index < keys_size; index++) { | |||
| key = CheckArg<AbstractScalar>(op_name + "key", key_list, index); | |||
| ValuePtr keyPtr = key->BuildValue(); | |||
| MS_EXCEPTION_IF_NULL(keyPtr); | |||
| if (!keyPtr->isa<StringImm>()) { | |||
| MS_LOG(EXCEPTION) << op_name << " evaluator keys should be string, but got " << keyPtr->ToString(); | |||
| } | |||
| std::string key_string = GetValue<std::string>(keyPtr); | |||
| key_value.emplace_back(key_string, value_list[index]); | |||
| } | |||
| return std::make_shared<AbstractDictionary>(key_value); | |||
| } | |||
| AbstractBasePtr InferImplMakeKwarg(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| // Inputs: a string and an object of a subclass of AbstractBase. | |||
| const std::string op_name = primitive->name(); | |||
| CheckArgsSize(op_name, args_spec_list, 2); | |||
| AbstractScalarPtr key = CheckArg<AbstractScalar>(op_name, args_spec_list, 0); | |||
| ValuePtr keyPtr = key->BuildValue(); | |||
| if (!keyPtr->isa<StringImm>()) { | |||
| MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << keyPtr->ToString(); | |||
| } | |||
| std::string key_string = GetValue<std::string>(keyPtr); | |||
| return std::make_shared<AbstractKeywordArg>(key_string, args_spec_list[1]); | |||
| } | |||
| AbstractBasePtr InferImplExtractKwarg(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| // Inputs: a string and a keyword. | |||
| const std::string op_name = primitive->name(); | |||
| CheckArgsSize(op_name, args_spec_list, 2); | |||
| AbstractScalarPtr key = CheckArg<AbstractScalar>(op_name, args_spec_list, 0); | |||
| AbstractKeywordArgPtr kwarg = CheckArg<AbstractKeywordArg>(op_name, args_spec_list, 1); | |||
| ValuePtr key_value = key->BuildValue(); | |||
| if (!key_value->isa<StringImm>()) { | |||
| MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << key_value->ToString(); | |||
| } | |||
| std::string key_input = GetValue<std::string>(key_value); | |||
| std::string key_actual = kwarg->get_key(); | |||
| if (key_actual != key_input) { | |||
| MS_LOG(EXCEPTION) << op_name << " evaluator input key should be same as AbstractKeywordArg' key, but input is " | |||
| << key_input << ", AbstractKeywordArg' key is " << key_actual; | |||
| } | |||
| return kwarg->get_arg(); | |||
| } | |||
| AbstractBasePtr InferImplMakeSlice(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| // Inputs: three scalars whose value is an int32 number. | |||
| CheckArgsSize(primitive->name(), args_spec_list, 3); | |||
| size_t args_size = args_spec_list.size(); | |||
| for (size_t index = 0; index < args_size; index++) { | |||
| MS_EXCEPTION_IF_NULL(args_spec_list[index]); | |||
| if (!args_spec_list[index]->isa<AbstractScalar>() && !args_spec_list[index]->isa<AbstractNone>()) { | |||
| MS_EXCEPTION(TypeError) << "MakeSlice eval " << index << " parameter is neither AbstractScalar nor AbstractNone."; | |||
| } | |||
| if (args_spec_list[index]->isa<AbstractScalar>() && | |||
| !dyn_cast<AbstractScalar>(args_spec_list[index])->BuildValue()->isa<Int32Imm>()) { | |||
| MS_EXCEPTION(TypeError) << "MakeSlice eval " << index | |||
| << " parameter is an AbstractScalar, but is not an int32 number."; | |||
| } | |||
| } | |||
| // Slice: start, end, step | |||
| return std::make_shared<AbstractSlice>(args_spec_list[0], args_spec_list[1], args_spec_list[2]); | |||
| } | |||
| template <typename T> | |||
| AbstractBasePtr InferTupleOrListGetItem(const std::string &op_name, const AbstractBasePtrList &args_spec_list) { | |||
| // Inputs: a tuple or list and a scalar whose value is an int32 number. | |||
| CheckArgsSize(op_name, args_spec_list, 2); | |||
| auto queue = CheckArg<T>(op_name, args_spec_list, 0); | |||
| AbstractScalarPtr index = CheckArg<AbstractScalar>(op_name, args_spec_list, 1); | |||
| ValuePtr index_value = index->BuildValue(); | |||
| if (!index_value->isa<Int32Imm>()) { | |||
| // when index_value is an AnyValue and args_spec_list[0] is a scalar, try to return the type of the first element | |||
| // and continue | |||
| if (dyn_cast<AbstractScalar>(queue->elements()[0]) != nullptr) { | |||
| return std::make_shared<AbstractScalar>(queue->elements()[0]->BuildType()); | |||
| } | |||
| MS_EXCEPTION(IndexError) << op_name << " evaluator index should be an int32 number, but got " | |||
| << index_value->ToString(); | |||
| } | |||
| int idx_v = GetValue<int>(index_value); | |||
| std::size_t nelems = queue->elements().size(); | |||
| if (idx_v >= SizeToInt(nelems) || idx_v < -SizeToInt(nelems)) { | |||
| MS_EXCEPTION(IndexError) << op_name << " evaluator index should be in range[-" << SizeToInt(nelems) << ", " | |||
| << SizeToInt(nelems) << "), but got " << idx_v << "."; | |||
| } | |||
| std::size_t uidx_v = 0; | |||
| if (idx_v >= 0) { | |||
| uidx_v = IntToSize(idx_v); | |||
| } else { | |||
| uidx_v = IntToSize(idx_v + SizeToInt(nelems)); | |||
| } | |||
| return queue->elements()[uidx_v]; | |||
| } | |||
| template <typename T> | |||
| AbstractBasePtr InferTupleOrListSetItem(const std::string &op_name, const AbstractBasePtrList &args_spec_list) { | |||
| // Inputs: a tuple or list, a scalar whose value is an int32 number and an object of a subclass of AbstractBase. | |||
| CheckArgsSize(op_name, args_spec_list, 3); | |||
| auto queue = CheckArg<T>(op_name, args_spec_list, 0); | |||
| AbstractScalarPtr index = CheckArg<AbstractScalar>(op_name, args_spec_list, 1); | |||
| ValuePtr index_value = index->BuildValue(); | |||
| if (!index_value->isa<Int32Imm>()) { | |||
| MS_EXCEPTION(IndexError) << op_name << " evaluator index should be an int32 number, but got " | |||
| << index_value->ToString(); | |||
| } | |||
| int idx_v = GetValue<int>(index_value); | |||
| if (idx_v < 0) { | |||
| MS_EXCEPTION(IndexError) << "The index of " << typeid(T).name() << " should be positive number, but got " << idx_v | |||
| << "."; | |||
| } | |||
| size_t uidx_v = IntToSize(idx_v); | |||
| AbstractBasePtrList elements = queue->elements(); | |||
| std::size_t nelems = elements.size(); | |||
| if (uidx_v >= nelems) { | |||
| MS_EXCEPTION(IndexError) << op_name << " evaluator the index: " << uidx_v << " to set out of range: " << nelems - 1 | |||
| << "."; | |||
| } | |||
| elements[uidx_v] = args_spec_list[2]; | |||
| return std::make_shared<T>(elements); | |||
| } | |||
| AbstractBasePtr InferImplTupleGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| return InferTupleOrListGetItem<AbstractTuple>(primitive->name(), args_spec_list); | |||
| } | |||
| AbstractBasePtr InferImplListGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| return InferTupleOrListGetItem<AbstractList>(primitive->name(), args_spec_list); | |||
| } | |||
| AbstractBasePtr InferImplTupleSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| return InferTupleOrListSetItem<AbstractTuple>(primitive->name(), args_spec_list); | |||
| } | |||
| AbstractBasePtr InferImplListSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| return InferTupleOrListSetItem<AbstractList>(primitive->name(), args_spec_list); | |||
| } | |||
| AbstractBasePtr InferImplDictGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| // Inputs: a dict and a scalar whose value is a string. | |||
| const std::string op_name = primitive->name(); | |||
| CheckArgsSize(op_name, args_spec_list, 2); | |||
| AbstractDictionaryPtr dict = CheckArg<AbstractDictionary>(op_name, args_spec_list, 0); | |||
| AbstractScalarPtr key = CheckArg<AbstractScalar>(op_name, args_spec_list, 1); | |||
| ValuePtr key_value = key->BuildValue(); | |||
| if (!key_value->isa<StringImm>()) { | |||
| MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << key_value->ToString(); | |||
| } | |||
| auto key_str = GetValue<std::string>(key_value); | |||
| std::vector<AbstractAttribute> dict_elems = dict->elements(); | |||
| auto it = std::find_if(dict_elems.begin(), dict_elems.end(), | |||
| [key_str](const AbstractAttribute &item) { return item.first == key_str; }); | |||
| if (it == dict_elems.end()) { | |||
| MS_LOG(EXCEPTION) << "The key " << key_str << " does not exist in the dict:" << args_spec_list[0]->ToString(); | |||
| } | |||
| return it->second; | |||
| } | |||
| AbstractBasePtr InferImplDictSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| // Inputs: a dict and a scalar whose value is a string and an object of a subclass of AbstractBase. | |||
| const std::string op_name = primitive->name(); | |||
| CheckArgsSize(op_name, args_spec_list, 3); | |||
| AbstractDictionaryPtr dict = CheckArg<AbstractDictionary>(op_name, args_spec_list, 0); | |||
| AbstractScalarPtr key = CheckArg<AbstractScalar>(op_name, args_spec_list, 1); | |||
| ValuePtr key_value = key->BuildValue(); | |||
| if (!key_value->isa<StringImm>()) { | |||
| MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << key_value->ToString(); | |||
| } | |||
| std::string key_str = GetValue<std::string>(key_value); | |||
| std::vector<AbstractAttribute> dict_elems = dict->elements(); | |||
| auto it = std::find_if(dict_elems.begin(), dict_elems.end(), | |||
| [key_str](const AbstractAttribute &item) { return item.first == key_str; }); | |||
| MS_EXCEPTION_IF_NULL(args_spec_list[2]); | |||
| auto new_ele = std::make_pair(key_str, args_spec_list[2]); | |||
| if (it != dict_elems.end()) { | |||
| int index = it - dict_elems.begin(); | |||
| dict_elems[IntToSize(index)] = new_ele; | |||
| } else { | |||
| dict_elems.push_back(new_ele); | |||
| } | |||
| return std::make_shared<AbstractDictionary>(dict_elems); | |||
| } | |||
| AbstractBasePtr InferImplListAppend(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| // Inputs: a list and an object of a subclass of AbstractBase. | |||
| const std::string op_name = primitive->name(); | |||
| CheckArgsSize(op_name, args_spec_list, 2); | |||
| AbstractListPtr list = CheckArg<AbstractList>(op_name, args_spec_list, 0); | |||
| (void)AbstractJoin(list->elements()); | |||
| return list; | |||
| } | |||
| AbstractBasePtr InferImplTupleLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| return InferTupleOrListOrDictLen<AbstractTuple>(primitive->name(), args_spec_list); | |||
| } | |||
| AbstractBasePtr InferImplListLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| return InferTupleOrListOrDictLen<AbstractList>(primitive->name(), args_spec_list); | |||
| } | |||
| AbstractBasePtr InferImplArrayLen(const AnalysisEnginePtr &, const PrimitivePtr &, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| return std::make_shared<AbstractScalar>(kAnyValue, kInt32); | |||
| } | |||
| } // namespace abstract | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,114 @@ | |||
| /** | |||
| * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). | |||
| * | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "abstract/primitive_infer_map.h" | |||
| #include "abstract/abstract_function.h" | |||
| #include "abstract/infer_functions.h" | |||
| namespace mindspore { | |||
| namespace abstract { | |||
| PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { | |||
| static PrimitiveEvalImplMap prim_eval_implement_map = { | |||
| // Statements | |||
| {prim::kPrimReturn, {InferImplReturn, true}}, | |||
| {prim::kPrimDot, {InferImplDot, true}}, | |||
| {prim::kPrimSwitch, {InferImplSwitch, true}}, | |||
| {prim::kPrimSwitchLayer, {InferImplSwitchLayer, true}}, | |||
| {prim::kPrimIs_, {InferImplIs_, true}}, | |||
| {prim::kPrimIsNot, {InferImplIsNot, true}}, | |||
| {prim::kPrimInDict, {InferImplInDict, true}}, | |||
| {prim::kPrimNotInDict, {InferImplNotInDict, true}}, | |||
| {prim::kPrimIsConsant, {InferImplIsConstant, true}}, | |||
| // Maths | |||
| {prim::kPrimMaximumGrad, {InferImplMinOrMaxGrad, true}}, | |||
| {prim::kPrimMinimumGrad, {InferImplMinOrMaxGrad, true}}, | |||
| // Array | |||
| {prim::kPrimScalarToArray, {InferImplScalarToArray, true}}, | |||
| {prim::kPrimArrayToScalar, {InferImplArrayToScalar, true}}, | |||
| {prim::kPrimBroadcastShape, {InferImplBroadCastShape, true}}, | |||
| {prim::kPrimPack, {InferImplPack, true}}, | |||
| {prim::kPrimUnique, {InferImplUnique, true}}, | |||
| {prim::kPrimUniqueGrad, {InferImplUniqueGrad, true}}, | |||
| // Structure | |||
| {prim::kPrimMakeTuple, {InferImplMakeTuple, true}}, | |||
| {prim::kPrimMakeList, {InferImplMakeList, true}}, | |||
| {prim::kPrimMakeDict, {InferImplMakeDict, true}}, | |||
| {prim::kPrimMakeSlice, {InferImplMakeSlice, true}}, | |||
| {prim::kPrimMakeKeywordArg, {InferImplMakeKwarg, true}}, | |||
| {prim::kPrimExtractKeywordArg, {InferImplExtractKwarg, true}}, | |||
| {prim::kPrimTupleGetItem, {InferImplTupleGetItem, true}}, | |||
| {prim::kPrimListGetItem, {InferImplListGetItem, true}}, | |||
| {prim::kPrimTupleSetItem, {InferImplTupleSetItem, true}}, | |||
| {prim::kPrimListSetItem, {InferImplListSetItem, true}}, | |||
| {prim::kPrimDictGetItem, {InferImplDictGetItem, true}}, | |||
| {prim::kPrimDictSetItem, {InferImplDictSetItem, true}}, | |||
| {prim::kPrimListAppend, {InferImplListAppend, true}}, | |||
| {prim::kPrimTupleLen, {InferImplTupleLen, true}}, | |||
| {prim::kPrimListLen, {InferImplListLen, true}}, | |||
| {prim::kPrimArrayLen, {InferImplArrayLen, true}}, | |||
| // NN | |||
| {prim::kPrimPooling, {InferImplPooling, true}}, | |||
| {prim::kPrimPoolingGrad, {InferImplPoolingGrad, true}}, | |||
| {prim::kPrimFusedBatchNorm, {InferImplFusedBatchNorm, true}}, | |||
| {prim::kPrimFusedBatchNormGrad, {InferImplFusedBatchNormGrad, true}}, | |||
| {prim::kPrimReluGrad, {InferImplReluGrad, true}}, | |||
| {prim::kPrimConv2DBackpropInput, {InferImplConv2DBackpropInput, true}}, | |||
| {prim::kPrimConv2DBackpropFilter, {InferImplConv2DBackpropFilter, true}}, | |||
| {prim::kPrimBiasAddGrad, {InferImplBiasAddGrad, true}}, | |||
| {prim::kPrimRelu, {InferImplRelu, true}}, | |||
| {prim::kPrimZerosLike, {InferImplZerosLike, true}}, | |||
| {prim::kPrimBpropCut, {InferImplBpropCut, true}}, | |||
| {prim::kPrimLayerNorm, {InferImplLayerNorm, true}}, | |||
| {prim::kPrimLayerNormGrad, {InferImplLayerNormGrad, true}}, | |||
| {prim::kPrimDropoutGenMask, {InferImplDropoutGenMask, true}}, | |||
| // Others | |||
| {prim::kPrimIdentity, {InferImplIdentity, true}}, | |||
| // Set impl to null as it will use PartialEvaluator; | |||
| {prim::kPrimPartial, {nullptr, true}}, | |||
| {prim::kPrimEnvGetItem, {InferImplEnvGetItem, true}}, | |||
| {prim::kPrimEnvSetItem, {InferImplEnvSetItem, true}}, | |||
| {prim::kPrimEnvAdd, {InferImplEnvAdd, true}}, | |||
| {prim::kPrimMakeRefKey, {InferImplMakeRefKey, true}}, | |||
| {prim::kPrimMakeRef, {InferImplMakeRef, true}}, | |||
| {prim::kPrimGetRefKey, {InferImplGetRefKey, true}}, | |||
| {prim::kPrimGetRefValue, {InferImplGetRefValue, true}}, | |||
| {prim::kPrimStateSetItem, {InferImplStateSetItem, true}}, | |||
| {prim::kPrimDepend, {InferImplDepend, true}}, | |||
| {prim::kPrimControlDepend, {InferImplControlDepend, true}}, | |||
| // Debug | |||
| {prim::kPrimDebug, {InferImplDebug, true}}, | |||
| // SparseTensor | |||
| {prim::kPrimMakeSparseTensor, {InferImplMakeSparseTensor, true}}, | |||
| {prim::kPrimSparseTensorGetValues, {InferImplSparseTensorGetValues, true}}, | |||
| {prim::kPrimSparseTensorGetIndices, {InferImplSparseTensorGetIndices, true}}, | |||
| {prim::kPrimSparseTensorGetDenseShape, {InferImplSparseTensorGetDenseShape, true}}, | |||
| // RowTensor | |||
| {prim::kPrimMakeRowTensor, {InferImplMakeRowTensor, true}}, | |||
| {prim::kPrimRowTensorGetValues, {InferImplRowTensorGetValues, true}}, | |||
| {prim::kPrimRowTensorGetIndices, {InferImplRowTensorGetIndices, true}}, | |||
| {prim::kPrimRowTensorGetDenseShape, {InferImplRowTensorGetDenseShape, true}}, | |||
| }; | |||
| return prim_eval_implement_map; | |||
| } | |||
| void RegisterStandardPrimitiveImpl(const PrimitivePtr &primitive, const StandardPrimitiveImplReg &impl_reg) { | |||
| auto &prim_eval_map = GetPrimitiveToEvalImplMap(); | |||
| prim_eval_map[primitive] = impl_reg; | |||
| } | |||
| } // namespace abstract | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,53 @@ | |||
| /** | |||
| * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). | |||
| * | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CORE_ABSTRACT_PRIMITIVE_INFER_MAP_H_ | |||
| #define MINDSPORE_CORE_ABSTRACT_PRIMITIVE_INFER_MAP_H_ | |||
| #include <unordered_map> | |||
| #include "ir/primitive.h" | |||
| #include "base/core_ops.h" | |||
| #include "abstract/abstract_value.h" | |||
| namespace mindspore { | |||
| namespace abstract { | |||
| using StandardPrimitiveEvalImpl = AbstractBasePtr (*)(const abstract::AnalysisEnginePtr &, const PrimitivePtr &, | |||
| const AbstractBasePtrList &); | |||
| struct StandardPrimitiveImplReg { | |||
| StandardPrimitiveEvalImpl impl_; // Implement function of Primitive. | |||
| bool in_white_list_; // true if this Primitive in white list, else false. | |||
| }; | |||
| using PrimitiveEvalImplMap = | |||
| std::unordered_map<PrimitivePtr, StandardPrimitiveImplReg, PrimitiveHasher, PrimitiveEqual>; | |||
| PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap(); | |||
| void RegisterStandardPrimitiveImpl(const PrimitivePtr &primitive, const StandardPrimitiveImplReg &impl_reg); | |||
| class RegisterStandardPrimitiveEvalHelper { | |||
| public: | |||
| RegisterStandardPrimitiveEvalHelper(const PrimitivePtr &primitive, const StandardPrimitiveEvalImpl &impl) { | |||
| const StandardPrimitiveImplReg impl_reg{impl, true}; | |||
| RegisterStandardPrimitiveImpl(primitive, impl_reg); | |||
| } | |||
| ~RegisterStandardPrimitiveEvalHelper() = default; | |||
| }; | |||
| #define REGISTER_PRIMITIVE_EVAL_IMPL(name, primitive, impl) \ | |||
| static auto helper_##name = RegisterStandardPrimitiveEvalHelper(primitive, impl) | |||
| } // namespace abstract | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CORE_ABSTRACT_PRIMITIVE_INFER_MAP_H_ | |||
| @@ -246,6 +246,25 @@ inline const PrimitivePtr kPrimNotInDict = std::make_shared<Primitive>("not_in_d | |||
| inline const PrimitivePtr kPrimIsConsant = std::make_shared<Primitive>("is_constant"); | |||
| inline const PrimitivePtr kPrimEquivFormat = std::make_shared<Primitive>("EquivFormat"); | |||
| // Structures | |||
| inline const PrimitivePtr kPrimMakeList = std::make_shared<Primitive>("make_list"); | |||
| inline const PrimitivePtr kPrimMakeKeywordArg = std::make_shared<Primitive>("make_keyword_arg"); | |||
| inline const PrimitivePtr kPrimListGetItem = std::make_shared<Primitive>("list_getitem"); | |||
| inline const PrimitivePtr kPrimListSetItem = std::make_shared<Primitive>("list_setitem"); | |||
| inline const PrimitivePtr kPrimDictGetItem = std::make_shared<Primitive>("dict_getitem"); | |||
| inline const PrimitivePtr kPrimDictSetItem = std::make_shared<Primitive>("dict_setitem"); | |||
| inline const PrimitivePtr kPrimListAppend = std::make_shared<Primitive>("list_append"); | |||
| inline const PrimitivePtr kPrimListLen = std::make_shared<Primitive>("list_len"); | |||
| // Other miscellaneous | |||
| inline const PrimitivePtr kPrimEnvSetItem = std::make_shared<Primitive>("env_setitem"); | |||
| inline const PrimitivePtr kPrimEnvGetItem = std::make_shared<Primitive>("env_getitem"); | |||
| inline const PrimitivePtr kPrimEnvAdd = std::make_shared<Primitive>("env_add"); | |||
| inline const PrimitivePtr kPrimMakeRefKey = std::make_shared<Primitive>("MakeRefKey"); | |||
| inline const PrimitivePtr kPrimGetRefKey = std::make_shared<Primitive>("get_ref_key"); | |||
| inline const PrimitivePtr kPrimMakeRef = std::make_shared<Primitive>("make_ref"); | |||
| inline const PrimitivePtr kPrimGetRefValue = std::make_shared<Primitive>("get_ref_value"); | |||
| // Other primitve not used by backend but used in core; | |||
| inline const PrimitivePtr kPrimStateSetItem = std::make_shared<Primitive>("state_setitem"); | |||
| inline const PrimitivePtr kPrimJ = std::make_shared<Primitive>("J"); | |||
| @@ -26,87 +26,19 @@ | |||
| namespace mindspore { | |||
| namespace { | |||
| using PrimConv2dPtr = std::shared_ptr<Conv2d>; | |||
| abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| auto conv_prim = primitive->cast<PrimConv2dPtr>(); | |||
| MS_EXCEPTION_IF_NULL(conv_prim); | |||
| auto prim_name = conv_prim->name(); | |||
| CheckAndConvertUtils::CheckInRange("Conv2d Infer", input_args.size(), kIncludeLeft, {2, 3}, prim_name); | |||
| auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShape("w_shape", input_args[0]->GetShapeTrack(), prim_name); | |||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[1]->GetShapeTrack(), prim_name); | |||
| CheckAndConvertUtils::CheckInteger("weight rank", w_shape.size(), kEqual, 4, prim_name); | |||
| CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kEqual, 4, prim_name); | |||
| CheckAndConvertUtils::Check("x_shape[1] / group", x_shape[1] / conv_prim->GetGroup(), kEqual, "w_shape[1]", | |||
| w_shape[1], conv_prim->name()); | |||
| auto out_channel = conv_prim->GetOutputChannel(); | |||
| CheckAndConvertUtils::Check("out_channel", out_channel, kEqual, "w_shape[0]", w_shape[0], conv_prim->name()); | |||
| std::vector<int> temp_w; | |||
| std::copy(w_shape.begin() + 2, w_shape.end(), std::back_inserter(temp_w)); | |||
| CheckAndConvertUtils::Check("kernel_size", conv_prim->GetKernelSize(), kEqual, "w_shape[2:4]", temp_w, | |||
| conv_prim->name()); | |||
| auto kernel_size_h = w_shape[2]; | |||
| auto kernel_size_w = w_shape[3]; | |||
| auto stride = conv_prim->GetStride(); | |||
| auto dilation = conv_prim->GetDilation(); | |||
| auto stride_h = stride[2]; | |||
| auto stride_w = stride[3]; | |||
| auto dilation_h = dilation[2]; | |||
| auto dilation_w = dilation[3]; | |||
| int h_out = -1; | |||
| int w_out = -1; | |||
| std::vector<int> pad_list(4, 0); | |||
| auto pad_mode = conv_prim->GetPadMode(); | |||
| if (pad_mode == "valid") { | |||
| h_out = ceil((x_shape[2] - dilation_h * (kernel_size_h - 1)) / stride_h); | |||
| w_out = ceil((x_shape[3] - dilation_w * (kernel_size_w - 1)) / stride_w); | |||
| } else if (pad_mode == "same") { | |||
| h_out = ceil(x_shape[2] / stride_h); | |||
| w_out = ceil(x_shape[3] / stride_w); | |||
| auto pad_needed_h = std::max(0, (h_out - 1) * stride_h + dilation_h * (kernel_size_h - 1) + 1 - x_shape[2]); | |||
| pad_list.emplace_back(floor(pad_needed_h / 2)); | |||
| pad_list.emplace_back(pad_needed_h / 2); | |||
| auto pad_needed_w = std::max(0, (w_out - 1) * stride_w + dilation_w * (kernel_size_w - 1) + 1 - x_shape[3]); | |||
| auto pad_left = floor(pad_needed_w / 2); | |||
| pad_list.emplace_back(pad_left); | |||
| pad_list.emplace_back(pad_needed_h - pad_left); | |||
| } else if (pad_mode == "pad") { | |||
| std::copy(conv_prim->GetPad().begin(), conv_prim->GetPad().end(), std::back_inserter(pad_list)); | |||
| auto pad_top = conv_prim->GetPad()[0]; | |||
| auto pad_bottom = conv_prim->GetPad()[1]; | |||
| auto pad_right = conv_prim->GetPad()[2]; | |||
| auto pad_left = conv_prim->GetPad()[3]; | |||
| h_out = 1 + (x_shape[2] + pad_top + pad_bottom - kernel_size_h - (kernel_size_h - 1) * (dilation_h - 1)) / stride_h; | |||
| w_out = 1 + (x_shape[3] + pad_left + pad_right - kernel_size_w - (kernel_size_w - 1) * (dilation_w - 1)) / stride_w; | |||
| h_out = floor(h_out); | |||
| w_out = floor(w_out); | |||
| } | |||
| conv_prim->SetPadList(pad_list); | |||
| std::vector<int> out_shape = {x_shape[0], out_channel, h_out, w_out}; | |||
| return std::make_shared<abstract::Shape>(out_shape); | |||
| } | |||
| TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { | |||
| CheckAndConvertUtils::CheckInRange("", input_args.size(), kIncludeLeft, {2, 3}, prim->name()); | |||
| for (const auto &item : input_args) { | |||
| MS_EXCEPTION_IF_NULL(item); | |||
| } | |||
| auto x_type = CheckAndConvertUtils::ConvertTypePtrToTypeId("x_dtype", input_args[0]->GetTypeTrack(), prim->name()); | |||
| const std::set<TypeId> valid_types = {kNumberTypeInt8, kNumberTypeInt32, kNumberTypeFloat16, kNumberTypeFloat32}; | |||
| std::map<std::string, TypePtr> types; | |||
| types.emplace("x", input_args[0]->GetTypeTrack()); | |||
| types.emplace("w", input_args[1]->GetTypeTrack()); | |||
| CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name()); | |||
| if (x_type == kNumberTypeInt8) { | |||
| return std::make_shared<TensorType>(TypeIdToType(kNumberTypeInt32)); | |||
| } | |||
| return std::make_shared<TensorType>(TypeIdToType(x_type)); | |||
| } | |||
| constexpr auto kKernelSize = "kernel_size"; | |||
| constexpr auto kStride = "stride"; | |||
| constexpr auto kDilation = "dilation"; | |||
| constexpr auto kPadMode = "pad_mode"; | |||
| constexpr auto kPad = "pad"; | |||
| constexpr auto kMode = "mode"; | |||
| constexpr auto kGroup = "group"; | |||
| constexpr auto kOutputChannel = "output channel"; | |||
| constexpr auto kPadList = "pad_list"; | |||
| constexpr auto kConv2DName = "Conv2D"; | |||
| } // namespace | |||
| Conv2d::Conv2d() : PrimitiveC(kConv2DName) { InitIOName({"x", "w"}, {"output"}); } | |||
| void Conv2d::Init(int out_channel, const std::vector<int> &kernel_size, int mode, const std::string &pad_mode, | |||
| const std::vector<int> &pad, const std::vector<int> &stride, const std::vector<int> &dilation, | |||
| int group) { | |||
| @@ -130,10 +62,47 @@ void Conv2d::Init(int out_channel, const std::vector<int> &kernel_size, int mode | |||
| this->SetOutChannel(CheckAndConvertUtils::CheckInteger("out_channel", out_channel, kGreaterThan, 0, prim_name)); | |||
| this->SetGroup(CheckAndConvertUtils::CheckInteger("group", group, kGreaterThan, 0, prim_name)); | |||
| } | |||
| std::vector<int> Conv2d::GetKernelSize() const { | |||
| auto value_ptr = GetAttr(kKernelSize); | |||
| return GetValue<std::vector<int>>(value_ptr); | |||
| } | |||
| std::vector<int> Conv2d::GetStride() const { | |||
| auto value_ptr = GetAttr(kStride); | |||
| return GetValue<std::vector<int>>(value_ptr); | |||
| } | |||
| std::vector<int> Conv2d::GetDilation() const { | |||
| auto value_ptr = GetAttr(kDilation); | |||
| return GetValue<std::vector<int>>(value_ptr); | |||
| } | |||
| std::string Conv2d::GetPadMode() const { | |||
| auto value_ptr = this->GetAttr(kPadMode); | |||
| return GetValue<string>(value_ptr); | |||
| } | |||
| std::vector<int> Conv2d::GetPad() const { | |||
| auto value_ptr = this->GetAttr(kPad); | |||
| return GetValue<std::vector<int>>(value_ptr); | |||
| } | |||
| int Conv2d::GetMode() const { | |||
| auto value_ptr = this->GetAttr(kMode); | |||
| return GetValue<int>(value_ptr); | |||
| } | |||
| AbstractBasePtr Conv2dInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args) { | |||
| return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args), | |||
| InferShape(primitive, input_args)->shape()); | |||
| int Conv2d::GetGroup() const { | |||
| auto value_ptr = this->GetAttr(kGroup); | |||
| return GetValue<int>(value_ptr); | |||
| } | |||
| int Conv2d::GetOutputChannel() const { | |||
| auto value_ptr = this->GetAttr(kOutputChannel); | |||
| return GetValue<int>(value_ptr); | |||
| } | |||
| void Conv2d::SetKernelSize(const std::vector<int> &kernel_size) { this->AddAttr(kKernelSize, MakeValue(kernel_size)); } | |||
| void Conv2d::SetStride(const std::vector<int> &stride) { this->AddAttr(kStride, MakeValue(stride)); } | |||
| void Conv2d::SetDilation(const std::vector<int> &dilation) { this->AddAttr(kDilation, MakeValue(dilation)); } | |||
| void Conv2d::SetPadMode(const std::string &pad_mode) { this->AddAttr(kPadMode, MakeValue(pad_mode)); } | |||
| void Conv2d::SetPad(const std::vector<int> &pad) { this->AddAttr(kPad, MakeValue(pad)); } | |||
| void Conv2d::SetMode(int mode) { this->AddAttr(kMode, MakeValue(mode)); } | |||
| void Conv2d::SetGroup(int group) { this->AddAttr(kGroup, MakeValue(group)); } | |||
| void Conv2d::SetOutChannel(int output_channel) { this->AddAttr(kOutputChannel, MakeValue(output_channel)); } | |||
| void Conv2d::SetPadList(const std::vector<int> &pad_list) { this->AddAttr(kPadList, MakeValue(pad_list)); } | |||
| } // namespace mindspore | |||
| @@ -16,79 +16,44 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CORE_C_OPS_CONV2D_H | |||
| #define MINDSPORE_CORE_C_OPS_CONV2D_H | |||
| #ifndef MINDSPORE_CORE_C_OPS_CONV2D_H_ | |||
| #define MINDSPORE_CORE_C_OPS_CONV2D_H_ | |||
| #include <map> | |||
| #include <vector> | |||
| #include <string> | |||
| #include <memory> | |||
| #include "c_ops/primitive_c.h" | |||
| #include "abstract/abstract_value.h" | |||
| #include "utils/check_convert_utils.h" | |||
| namespace mindspore { | |||
| class Conv2d : public PrimitiveC { | |||
| public: | |||
| Conv2d() : PrimitiveC(kConv2DName) { InitIOName({"x", "w"}, {"output"}); } | |||
| Conv2d(); | |||
| void Init(int out_channel, const std::vector<int> &kernel_size, int mode = 1, const std::string &pad_mode = "valid", | |||
| const std::vector<int> &pad = {0, 0, 0, 0}, const std::vector<int> &stride = {1, 1, 1, 1}, | |||
| const std::vector<int> &dilation = {1, 1, 1, 1}, int group = 1); | |||
| std::vector<int> GetKernelSize() const { | |||
| auto value_ptr = this->GetAttr(kKernelSize); | |||
| return GetValue<std::vector<int>>(value_ptr); | |||
| } | |||
| std::vector<int> GetStride() const { | |||
| auto value_ptr = GetAttr(kStride); | |||
| return GetValue<std::vector<int>>(value_ptr); | |||
| } | |||
| std::vector<int> GetDilation() const { | |||
| auto value_ptr = GetAttr(kDilation); | |||
| return GetValue<std::vector<int>>(value_ptr); | |||
| } | |||
| std::string GetPadMode() const { | |||
| auto value_ptr = this->GetAttr(kPadMode); | |||
| return GetValue<string>(value_ptr); | |||
| } | |||
| std::vector<int> GetPad() const { | |||
| auto value_ptr = this->GetAttr(kPad); | |||
| return GetValue<std::vector<int>>(value_ptr); | |||
| } | |||
| int GetMode() const { | |||
| auto value_ptr = this->GetAttr(kMode); | |||
| return GetValue<int>(value_ptr); | |||
| } | |||
| int GetGroup() const { | |||
| auto value_ptr = this->GetAttr(kGroup); | |||
| return GetValue<int>(value_ptr); | |||
| } | |||
| int GetOutputChannel() const { | |||
| auto value_ptr = this->GetAttr(kOutputChannel); | |||
| return GetValue<int>(value_ptr); | |||
| } | |||
| void SetKernelSize(const std::vector<int> &kernel_size) { this->AddAttr(kKernelSize, MakeValue(kernel_size)); } | |||
| void SetStride(const std::vector<int> &stride) { this->AddAttr(kStride, MakeValue(stride)); } | |||
| void SetDilation(const std::vector<int> &dilation) { this->AddAttr(kDilation, MakeValue(dilation)); } | |||
| void SetPadMode(const std::string &pad_mode) { this->AddAttr(kPadMode, MakeValue(pad_mode)); } | |||
| void SetPad(const std::vector<int> &pad) { this->AddAttr(kPad, MakeValue(pad)); } | |||
| void SetMode(int mode) { this->AddAttr(kMode, MakeValue(mode)); } | |||
| void SetGroup(int group) { this->AddAttr(kGroup, MakeValue(group)); } | |||
| void SetOutChannel(int output_channel) { this->AddAttr(kOutputChannel, MakeValue(output_channel)); } | |||
| void SetPadList(const std::vector<int> &pad_list) { this->AddAttr(kPadList, MakeValue(pad_list)); } | |||
| private: | |||
| inline static const string kKernelSize = "kernel_size"; | |||
| inline static const string kStride = "stride"; | |||
| inline static const string kDilation = "dilation"; | |||
| inline static const string kPadMode = "pad_mode"; | |||
| inline static const string kPad = "pad"; | |||
| inline static const string kMode = "mode"; | |||
| inline static const string kGroup = "group"; | |||
| inline static const string kOutputChannel = "output channel"; | |||
| inline static const string kPadList = "pad_list"; | |||
| inline static const string kConv2DName = "Conv2D"; | |||
| std::vector<int> GetKernelSize() const; | |||
| std::vector<int> GetStride() const; | |||
| std::vector<int> GetDilation() const; | |||
| std::string GetPadMode() const; | |||
| std::vector<int> GetPad() const; | |||
| int GetMode() const; | |||
| int GetGroup() const; | |||
| int GetOutputChannel() const; | |||
| void SetKernelSize(const std::vector<int> &kernel_size); | |||
| void SetStride(const std::vector<int> &stride); | |||
| void SetDilation(const std::vector<int> &dilation); | |||
| void SetPadMode(const std::string &pad_mode); | |||
| void SetPad(const std::vector<int> &pad); | |||
| void SetMode(int mode); | |||
| void SetGroup(int group); | |||
| void SetOutChannel(int output_channel); | |||
| void SetPadList(const std::vector<int> &pad_list); | |||
| }; | |||
| AbstractBasePtr Conv2dInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args); | |||
| using PrimConv2dPtr = std::shared_ptr<Conv2d>; | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CORE_C_OPS_CONV2D_H | |||
| #endif // MINDSPORE_CORE_C_OPS_CONV2D_H_ | |||
| @@ -16,8 +16,8 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CORE_C_OPS_PRIMITIVE_C_H | |||
| #define MINDSPORE_CORE_C_OPS_PRIMITIVE_C_H | |||
| #ifndef MINDSPORE_CORE_C_OPS_PRIMITIVE_C_H_ | |||
| #define MINDSPORE_CORE_C_OPS_PRIMITIVE_C_H_ | |||
| #include <string> | |||
| #include <vector> | |||
| #include "ir/primitive.h" | |||
| @@ -25,7 +25,7 @@ | |||
| namespace mindspore { | |||
| class PrimitiveC : public Primitive { | |||
| public: | |||
| explicit PrimitiveC(const std::string &name) : Primitive(name) { attrs_ = {}; } | |||
| explicit PrimitiveC(const std::string &name) : Primitive(name) {} | |||
| protected: | |||
| void InitIOName(const std::vector<std::string> &inputs_name, const std::vector<std::string> &outputs_name) { | |||
| @@ -34,4 +34,4 @@ class PrimitiveC : public Primitive { | |||
| } | |||
| }; | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CORE_C_OPS_PRIMITIVE_C_H | |||
| #endif // MINDSPORE_CORE_C_OPS_PRIMITIVE_C_H_ | |||
| @@ -61,6 +61,7 @@ file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | |||
| "../../../mindspore/core/abstract/*.cc" | |||
| "../../../mindspore/core/ir/*.cc" | |||
| "../../../mindspore/core/utils/*.cc" | |||
| "../../../mindspore/core/c_ops/*.cc" | |||
| "../../../mindspore/ccsrc/common/*.cc" | |||
| "../../../mindspore/ccsrc/utils/*.cc" | |||
| "../../../mindspore/ccsrc/pipeline/jit/parse/*.cc" | |||