Merge pull request !217 from zhangbuxue/support_pow_operator (tags/v0.2.0-alpha)
@@ -83,9 +83,9 @@ convert_object_map = {
     T.mul: multitype_ops.mul,
     T.truediv: multitype_ops.div,
     T.getitem: multitype_ops.getitem,
-    T.floordiv: NO_IMPLEMENT,
-    T.mod: F.scalar_mod,
-    T.pow: F.scalar_pow,
+    T.floordiv: multitype_ops.floordiv,
+    T.mod: multitype_ops.mod,
+    T.pow: multitype_ops.pow_,
     T.matmul: F.dot,
     T.lshift: NO_IMPLEMENT,
     T.rshift: NO_IMPLEMENT,
@@ -104,8 +104,8 @@ convert_object_map = {
     T.ge: multitype_ops.greater_equal,
     T.is_: F.is_,
     T.is_not: F.is_not,
-    T.contains: NO_IMPLEMENT,
-    T.not_contains: NO_IMPLEMENT,
+    T.contains: F.in_dict,
+    T.not_contains: F.not_in_dict,
     # system function
     T.len: M.ms_len,
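
With these mappings in place, the Python operators `//`, `%`, and `**` (and `in` / `not in` on dictionaries, mapped just above) resolve to the new multitype graphs inside a `construct` method. A minimal graph-mode sketch, modeled on the tests at the end of this PR (imports assumed to match those test files):

    import numpy as np
    from mindspore import Tensor, context
    from mindspore.nn import Cell

    class ArithNet(Cell):
        """Exercises the newly mapped operators in graph mode."""
        def construct(self, x, y):
            return x ** y + x // y + x % y

    context.set_context(mode=context.GRAPH_MODE)
    net = ArithNet()
    x = Tensor(np.random.randint(low=1, high=10, size=(2, 3), dtype=np.int32))
    y = Tensor(np.random.randint(low=1, high=5, size=(2, 3), dtype=np.int32))
    net(x, y)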
@@ -103,7 +103,7 @@ T InnerScalarMul(T x, T y) {
 }

 template <typename T>
-T InnerScalarDiv(T x, T y) {
+float InnerScalarDiv(T x, T y) {
   if (y == 0) {
     MS_LOG(EXCEPTION) << "Divisor could not be zero";
   }
@@ -111,23 +111,41 @@ T InnerScalarDiv(T x, T y) {
     MS_LOG(EXCEPTION) << "Overflow of the div of two signed number x: " << std::to_string(x)
                       << ", y: " << std::to_string(y) << ".";
   }
-  return x / y;
+  return static_cast<float>(x) / static_cast<float>(y);
 }

-int32_t InnerScalarMod(int32_t x, int32_t y) {
+template <typename T>
+T InnerScalarFloordiv(T x, T y) {
+  auto ret = std::floor(InnerScalarDiv(x, y));
+  if (std::is_integral<T>::value) {
+    return static_cast<int>(ret);
+  }
+  return ret;
+}
+
+template <typename T>
+T InnerScalarMod(T x, T y) {
   if (y == 0) {
     MS_LOG(EXCEPTION) << "Could not mod to zero.";
   }
-  if (IsSignedIntOverflow(x, y, OpType::MOD)) {
+  if (std::is_integral<T>::value && std::is_signed<T>::value && IsSignedIntOverflow(x, y, OpType::MOD)) {
     MS_LOG(EXCEPTION) << "Overflow of the mod of two signed number x: " << std::to_string(x)
                       << ", y: " << std::to_string(y) << ".";
   }
-  return x % y;
+  if (std::is_integral<T>::value) {
+    return static_cast<int>(x) % static_cast<int>(y);
+  }
+  float x_int = std::floor(x);
+  float y_int = std::ceil(y);
+  float max = x_int / y_int;
+  float ret = x - y * max;
+  return ret;
 }
-float InnerScalarMod(float, float) { MS_LOG(EXCEPTION) << "Float does not support mod operator."; }
-double InnerScalarMod(double, double) { MS_LOG(EXCEPTION) << "Double does not support mod operator."; }

+template <typename T, typename U>
+T InnerScalarPow(T x, U y) {
+  return std::pow(x, y);
+}
+
 template <typename T, typename U>
 bool InnerScalarEq(T x, U y) {
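
For intuition, the floor-division template above true-divides, floors, and casts back for integral types; a plain-Python mirror (an illustration only, not MindSpore code):

    import math

    def inner_scalar_floordiv(x, y):
        # Mirrors InnerScalarFloordiv: true-divide, floor, cast back for ints.
        ret = math.floor(x / y)
        return int(ret) if isinstance(x, int) and isinstance(y, int) else float(ret)

    assert inner_scalar_floordiv(9, 2) == 4
    assert inner_scalar_floordiv(-9, 2) == -5  # floor rounds toward negative infinity
    assert inner_scalar_floordiv(7.0, 2.0) == 3.0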
@@ -193,6 +211,8 @@ SCALAR_OP(Sub)
 SCALAR_OP(Mul)
 SCALAR_OP(Div)
 SCALAR_OP(Mod)
+SCALAR_OP(Pow)
+SCALAR_OP(Floordiv)

 #define LOGIC_OP(op_t) \
   ValuePtr Scalar##op_t(const ValuePtrList& list) { \
@@ -227,6 +247,10 @@ SCALAR_OP(Mod)
       bool sum = InnerScalar##op_t(GetValue<float>(x), GetValue<int>(y)); \
       return MakeValue(sum); \
     } \
+    if (x->isa<Int32Imm>() && y->isa<FP32Imm>()) { \
+      bool sum = InnerScalar##op_t(GetValue<int>(x), GetValue<float>(y)); \
+      return MakeValue(sum); \
+    } \
     if (x->isa<Int64Imm>() && y->isa<Int32Imm>()) { \
       bool sum = InnerScalar##op_t(GetValue<int64_t>(x), GetValue<int>(y)); \
       return MakeValue(sum); \
@@ -37,9 +37,10 @@ ValuePtr ScalarSub(const ValuePtrList& list);
 ValuePtr ScalarMul(const ValuePtrList& list);
 ValuePtr ScalarDiv(const ValuePtrList& list);
 ValuePtr ScalarMod(const ValuePtrList& list);
+ValuePtr ScalarPow(const ValuePtrList& list);
+ValuePtr ScalarFloordiv(const ValuePtrList& list);
 ValuePtr ScalarUAdd(const ValuePtrList& list);
 ValuePtr ScalarUSub(const ValuePtrList& list);
-ValuePtr ScalarUSub(const ValuePtrList& list);
 ValuePtr ScalarLog(const ValuePtrList& list);
 ValuePtr ScalarEq(const ValuePtrList& list);
 ValuePtr ScalarLt(const ValuePtrList& list);
@@ -88,14 +88,17 @@ std::map<SignatureEnumDType, size_t> GetMaxDtypeIndex(const std::vector<Signatur
     if (indexs.size() < 2) {
       continue;
     }
-    size_t m_index = indexs[0];
-    for (size_t i = 1; i < indexs.size(); ++i) {
-      if (args_spec_list[indexs[i]]->isa<abstract::AbstractTensor>()) {
-        m_index = indexs[i];
+    for (const auto& index : indexs) {
+      AbstractBasePtr arg_value = args_spec_list[index];
+      if (arg_value->isa<abstract::AbstractRef>()) {
+        arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref();
+      }
+      if (arg_value->isa<abstract::AbstractTensor>()) {
+        (void)dst_type.insert(std::make_pair(type, index));
+        break;
       }
     }
-    if (args_spec_list[m_index]->isa<abstract::AbstractTensor>()) {
-      (void)dst_type.insert(std::make_pair(type, m_index));
-    }
   }
   return dst_type;
@@ -119,15 +122,19 @@ void DoAutoCast(const std::vector<Signature>& signature, const abstract::Abstrac
   (void)std::transform(signature.begin(), signature.end(), std::back_inserter(dtypes),
                        [](const Signature& sig) { return sig.dtype; });
   int empty_dtype_count = std::count(dtypes.begin(), dtypes.end(), SignatureEnumDType::kDTypeEmptyDefaultValue);
-  if (dtypes.size() == 0 || static_cast<int>(dtypes.size()) == empty_dtype_count) {
+  if (dtypes.empty() || static_cast<int>(dtypes.size()) == empty_dtype_count) {
     return;
   }
   // Stat the index of the arguments with the largest type in the same SignatureEnumDType.
   std::map<SignatureEnumDType, size_t> dst_type = GetMaxDtypeIndex(dtypes, args_spec_list);
   // Identify which arg requires auto cast
   for (size_t i = 0; i < args_spec_list.size(); ++i) {
+    AbstractBasePtr arg_value = args_spec_list[i];
+    if (arg_value->isa<abstract::AbstractRef>()) {
+      arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref();
+    }
     auto it = dst_type.find(dtypes[i]);
-    if (it == dst_type.end() || it->second == i || !args_spec_list[i]->isa<abstract::AbstractScalar>()) {
+    if (it == dst_type.end() || it->second == i || !arg_value->isa<abstract::AbstractScalar>()) {
      continue;
    }
    // get source node for cast
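
This auto-cast path (now also unwrapping `AbstractRef` arguments) is what lets a bare Python scalar meet a tensor without an explicit conversion; the `multitype_ops` hunks later in this PR rely on it when they drop their `scalar_to_tensor` calls. A minimal sketch (assumed graph-mode setup as in the tests):

    import numpy as np
    from mindspore import Tensor, context
    from mindspore.nn import Cell

    class MixNet(Cell):
        """Adds a Python scalar to a tensor; DoAutoCast inserts the dtype cast."""
        def construct(self, x):
            return x + 1  # the scalar operand is auto-cast to x's dtype

    context.set_context(mode=context.GRAPH_MODE)
    MixNet()(Tensor(np.array([1.0, 2.0]).astype(np.float32)))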
@@ -28,6 +28,7 @@ const PrimitivePtr kPrimScalarAdd = std::make_shared<Primitive>("scalar_add");
 const PrimitivePtr kPrimScalarSub = std::make_shared<Primitive>("scalar_sub");
 const PrimitivePtr kPrimScalarMul = std::make_shared<Primitive>("scalar_mul");
 const PrimitivePtr kPrimScalarDiv = std::make_shared<Primitive>("scalar_div");
+const PrimitivePtr kPrimScalarFloordiv = std::make_shared<Primitive>("scalar_floordiv");
 const PrimitivePtr kPrimScalarMod = std::make_shared<Primitive>("scalar_mod");
 const PrimitivePtr kPrimScalarPow = std::make_shared<Primitive>("scalar_pow");
 const PrimitivePtr kPrimScalarTrunc = std::make_shared<Primitive>("scalar_trunc");
@@ -78,6 +79,7 @@ const PrimitivePtr kPrimCreateInstance = std::make_shared<Primitive>("create_ins
 // Structure
 const PrimitivePtr kPrimStringEqual = std::make_shared<Primitive>("string_equal");
+const PrimitivePtr kPrimStringConcat = std::make_shared<Primitive>("string_concat");
 const PrimitivePtr kPrimMakeTuple = std::make_shared<Primitive>("make_tuple");
 const PrimitivePtr kPrimMakeList = std::make_shared<Primitive>("make_list");
 const PrimitivePtr kPrimMakeDict = std::make_shared<Primitive>("make_dict");
@@ -221,6 +223,8 @@ const PrimitivePtr kPrimBroadcastGradientArgs = std::make_shared<Primitive>("Bro
 const PrimitivePtr kPrimControlDepend = std::make_shared<Primitive>("ControlDepend");
 const PrimitivePtr kPrimIs_ = std::make_shared<Primitive>("is_");
 const PrimitivePtr kPrimIsNot = std::make_shared<Primitive>("is_not");
+const PrimitivePtr kPrimInDict = std::make_shared<Primitive>("in_dict");
+const PrimitivePtr kPrimNotInDict = std::make_shared<Primitive>("not_in_dict");

 // Comm ops
 const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator");
@@ -34,6 +34,7 @@ extern const PrimitivePtr kPrimScalarAdd;
 extern const PrimitivePtr kPrimScalarSub;
 extern const PrimitivePtr kPrimScalarMul;
 extern const PrimitivePtr kPrimScalarDiv;
+extern const PrimitivePtr kPrimScalarFloordiv;
 extern const PrimitivePtr kPrimScalarMod;
 extern const PrimitivePtr kPrimScalarPow;
 extern const PrimitivePtr kPrimScalarTrunc;
@@ -84,6 +85,7 @@ extern const PrimitivePtr kPrimCreateInstance;
 // Structure
 extern const PrimitivePtr kPrimStringEqual;
+extern const PrimitivePtr kPrimStringConcat;
 extern const PrimitivePtr kPrimMakeTuple;
 extern const PrimitivePtr kPrimMakeList;
 extern const PrimitivePtr kPrimMakeDict;
@@ -227,8 +229,8 @@ extern const PrimitivePtr kPrimBroadcastGradientArgs;
 extern const PrimitivePtr kPrimControlDepend;
 extern const PrimitivePtr kPrimIs_;
 extern const PrimitivePtr kPrimIsNot;
-extern const PrimitivePtr kPrimMinimumGrad;
-extern const PrimitivePtr kPrimMaximumGrad;
+extern const PrimitivePtr kPrimInDict;
+extern const PrimitivePtr kPrimNotInDict;

 // Comm ops
 extern const PrimitivePtr kPrimMirror;
@@ -114,12 +114,12 @@ void FusedBatchNormCheckDim(const PrimitivePtr &primitive, const AbstractBasePtr
     AbstractTensorPtr arg = CheckArg<AbstractTensor>(op_name, args_spec_list, i);
     ShapePtr arg_shape = dyn_cast<Shape>(arg->GetShapeTrack());
     if (arg_shape == nullptr) {
-      MS_LOG(EXCEPTION) << "" << op_name << " type of args[" << i << "] should be Shape, but " << arg->ToString();
+      MS_LOG(EXCEPTION) << op_name << " type of args[" << i << "] should be Shape, but " << arg->ToString();
     }
     if (i == 0) {
       if (arg_shape->shape().size() < 2) {
-        MS_LOG(EXCEPTION) << "" << op_name << " shape of args[" << i
+        MS_LOG(EXCEPTION) << op_name << " shape of args[" << i
                           << "] should be TensorShape with dimension greater than 1, but shape: "
                           << arg_shape->ToString();
       }
@@ -127,7 +127,7 @@ void FusedBatchNormCheckDim(const PrimitivePtr &primitive, const AbstractBasePtr
     }
     if (arg_shape->shape().size() != 1) {
-      MS_LOG(EXCEPTION) << "" << op_name << " shape of args[" << i
+      MS_LOG(EXCEPTION) << op_name << " shape of args[" << i
                         << "] should be TensorShape with dimension: 1, but shape: " << arg_shape->ToString();
     }
   }
@@ -159,7 +159,7 @@ AbstractBasePtr InferImplFusedBatchNorm(const AnalysisEnginePtr &, const Primiti
     MS_LOG(EXCEPTION) << "Arg shape size should >= 1.";
   }
   if (arg_shape_list[0] != input_shape_list[1]) {
-    MS_LOG(EXCEPTION) << "" << op_name << " size of tensor param[" << i << "](which is " << arg_shape_list[0]
+    MS_LOG(EXCEPTION) << op_name << " size of tensor param[" << i << "](which is " << arg_shape_list[0]
                       << ") should match the second dimension of tensor"
                          " param[0](which is "
                       << input_shape_list[1] << ").";
@@ -378,7 +378,7 @@ AbstractBasePtr InferImplDropoutGenMask(const AnalysisEnginePtr &, const Primiti
   TypePtr prob_type = keep_prob->element()->BuildType();
   if ((prob_type->type_id() != kNumberTypeFloat16) && (prob_type->type_id() != kNumberTypeFloat32)) {
-    MS_LOG(EXCEPTION) << "" << op_name << " keep_prob type should be float16 or float32, but " << prob_type->ToString()
+    MS_LOG(EXCEPTION) << op_name << " keep_prob type should be float16 or float32, but " << prob_type->ToString()
                       << ".";
   }
@@ -169,5 +169,36 @@ AbstractBasePtr InferImplIsNot(const AnalysisEnginePtr &, const PrimitivePtr &pr
   return std::make_shared<AbstractScalar>(!(*t == *x));
 }

+bool IsInDict(const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list) {
+  const std::string op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, 2);
+  auto key = CheckArg<AbstractScalar>(op_name, args_spec_list, 0);
+  auto dict = CheckArg<AbstractDictionary>(op_name, args_spec_list, 1);
+
+  ValuePtr key_value = key->BuildValue();
+  if (!key_value->isa<StringImm>()) {
+    MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << key_value->ToString();
+  }
+  auto key_str = GetValue<std::string>(key_value);
+  std::vector<AbstractAttribute> dict_elems = dict->elements();
+  auto it = std::find_if(dict_elems.begin(), dict_elems.end(),
+                         [key_str](const AbstractAttribute &item) { return item.first == key_str; });
+  return it != dict_elems.end();
+}
+
+AbstractBasePtr InferImplInDict(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                const AbstractBasePtrList &args_spec_list) {
+  // statement: x in t
+  // Inputs: x, t
+  return std::make_shared<AbstractScalar>(IsInDict(primitive, args_spec_list));
+}
+
+AbstractBasePtr InferImplNotInDict(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                   const AbstractBasePtrList &args_spec_list) {
+  // statement: x not in t
+  // Inputs: x, t
+  return std::make_shared<AbstractScalar>(!IsInDict(primitive, args_spec_list));
+}
 }  // namespace abstract
 }  // namespace mindspore
@@ -36,7 +36,7 @@ AbstractBasePtr InferImplStringEqual(const AnalysisEnginePtr &, const PrimitiveP
   ValuePtr value_x = scalar_x->BuildValue();
   ValuePtr value_y = scalar_y->BuildValue();
   if (!value_x->isa<StringImm>() || !value_y->isa<StringImm>()) {
-    MS_LOG(EXCEPTION) << "" << op_name << " requires 2 parameters are string, but got param0: " << value_x->ToString()
+    MS_LOG(EXCEPTION) << op_name << " requires 2 parameters are string, but got param0: " << value_x->ToString()
                       << ", param1: " << value_y->ToString();
   }
@@ -44,6 +44,25 @@ AbstractBasePtr InferImplStringEqual(const AnalysisEnginePtr &, const PrimitiveP
   return std::make_shared<AbstractScalar>(ret);
 }

+AbstractBasePtr InferImplStringConcat(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                      const AbstractBasePtrList &args_spec_list) {
+  // Inputs: two scalars whose value is a string.
+  const std::string op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, 2);
+  AbstractScalarPtr scalar_x = CheckArg<AbstractScalar>(op_name, args_spec_list, 0);
+  AbstractScalarPtr scalar_y = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
+
+  ValuePtr value_x = scalar_x->BuildValue();
+  ValuePtr value_y = scalar_y->BuildValue();
+  if (!value_x->isa<StringImm>() || !value_y->isa<StringImm>()) {
+    MS_LOG(EXCEPTION) << op_name << " requires 2 parameters are string, but got param0: " << value_x->ToString()
+                      << ", param1: " << value_y->ToString();
+  }
+
+  std::string ret = (value_x->cast<StringImmPtr>()->value() + value_y->cast<StringImmPtr>()->value());
+  return std::make_shared<AbstractScalar>(ret);
+}
+
 AbstractBasePtr InferImplMakeTuple(const AnalysisEnginePtr &, const PrimitivePtr &,
                                    const AbstractBasePtrList &args_spec_list) {
   return std::make_shared<AbstractTuple>(args_spec_list);
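
Because this evaluator folds the concatenation during type inference, `+` on two constant strings can steer control flow inside a graph, exactly as `test_ops` does at the end of this PR. A minimal sketch (hypothetical net, assumed imports as in the tests):

    from mindspore.nn import Cell

    class ConcatNet(Cell):
        """String '+' is folded at compile time by InferImplStringConcat."""
        def __init__(self):
            super(ConcatNet, self).__init__()
            self.str_a = "hello"
            self.str_b = "world"
        def construct(self, x):
            if self.str_a + self.str_b == "helloworld":
                return x
            return x * 0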
@@ -64,7 +83,7 @@ AbstractBasePtr InferImplMakeDict(const AnalysisEnginePtr &, const PrimitivePtr
   size_t keys_size = keys->size();
   if (values->size() != keys_size) {
-    MS_LOG(EXCEPTION) << "" << op_name << " evaluator keys' size is not equal with values' size";
+    MS_LOG(EXCEPTION) << op_name << " evaluator keys' size is not equal with values' size";
   }
   std::vector<AbstractAttribute> key_value;
@@ -76,7 +95,7 @@ AbstractBasePtr InferImplMakeDict(const AnalysisEnginePtr &, const PrimitivePtr
     ValuePtr keyPtr = key->BuildValue();
     MS_EXCEPTION_IF_NULL(keyPtr);
     if (!keyPtr->isa<StringImm>()) {
-      MS_LOG(EXCEPTION) << "" << op_name << " evaluator keys should be string, but got " << keyPtr->ToString();
+      MS_LOG(EXCEPTION) << op_name << " evaluator keys should be string, but got " << keyPtr->ToString();
     }
     std::string key_string = GetValue<std::string>(keyPtr);
     key_value.emplace_back(key_string, value_list[index]);
@@ -93,7 +112,7 @@ AbstractBasePtr InferImplMakeKwarg(const AnalysisEnginePtr &, const PrimitivePtr
   ValuePtr keyPtr = key->BuildValue();
   if (!keyPtr->isa<StringImm>()) {
-    MS_LOG(EXCEPTION) << "" << op_name << " evaluator key should be string, but got " << keyPtr->ToString();
+    MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << keyPtr->ToString();
   }
   std::string key_string = GetValue<std::string>(keyPtr);
   return std::make_shared<AbstractKeywordArg>(key_string, args_spec_list[1]);
@@ -109,14 +128,13 @@ AbstractBasePtr InferImplExtractKwarg(const AnalysisEnginePtr &, const Primitive
   ValuePtr key_value = key->BuildValue();
   if (!key_value->isa<StringImm>()) {
-    MS_LOG(EXCEPTION) << "" << op_name << " evaluator key should be string, but got " << key_value->ToString();
+    MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << key_value->ToString();
   }
   std::string key_input = GetValue<std::string>(key_value);
   std::string key_actual = kwarg->get_key();
   if (key_actual != key_input) {
-    MS_LOG(EXCEPTION) << "" << op_name
-                      << " evaluator input key should be same as AbstractKeywordArg' key, but input is " << key_input
-                      << ", AbstractKeywordArg' key is " << key_actual;
+    MS_LOG(EXCEPTION) << op_name << " evaluator input key should be same as AbstractKeywordArg' key, but input is "
+                      << key_input << ", AbstractKeywordArg' key is " << key_actual;
   }
   return kwarg->get_arg();
 }
@@ -187,13 +205,12 @@ AbstractBasePtr InferTupleOrListGetItem(const std::string &op_name, const Abstra
   ValuePtr index_value = index->BuildValue();
   if (!index_value->isa<Int32Imm>()) {
-    MS_LOG(EXCEPTION) << "" << op_name << " evaluator index should be an int32 number, but got "
-                      << index_value->ToString();
+    MS_LOG(EXCEPTION) << op_name << " evaluator index should be an int32 number, but got " << index_value->ToString();
   }
   int idx_v = GetValue<int>(index_value);
   std::size_t nelems = queue->elements().size();
   if (idx_v >= SizeToInt(nelems) || idx_v < -SizeToInt(nelems)) {
-    MS_LOG(EXCEPTION) << "" << op_name << " evaluator index should be in range[-" << SizeToInt(nelems) << ", "
+    MS_LOG(EXCEPTION) << op_name << " evaluator index should be in range[-" << SizeToInt(nelems) << ", "
                       << SizeToInt(nelems) << "), but got " << idx_v << ".";
   }
@@ -215,8 +232,7 @@ AbstractBasePtr InferTupleOrListSetItem(const std::string &op_name, const Abstra
   ValuePtr index_value = index->BuildValue();
   if (!index_value->isa<Int32Imm>()) {
-    MS_LOG(EXCEPTION) << "" << op_name << " evaluator index should be an int32 number, but got "
-                      << index_value->ToString();
+    MS_LOG(EXCEPTION) << op_name << " evaluator index should be an int32 number, but got " << index_value->ToString();
   }
   int idx_v = GetValue<int>(index_value);
   if (idx_v < 0) {
@@ -227,8 +243,7 @@ AbstractBasePtr InferTupleOrListSetItem(const std::string &op_name, const Abstra
   AbstractBasePtrList elements = queue->elements();
   std::size_t nelems = elements.size();
   if (uidx_v >= nelems) {
-    MS_LOG(EXCEPTION) << "" << op_name << " evaluator the index: " << uidx_v << " to set out of range: " << nelems - 1
-                      << ".";
+    MS_LOG(EXCEPTION) << op_name << " evaluator the index: " << uidx_v << " to set out of range: " << nelems - 1 << ".";
   }
   elements[uidx_v] = args_spec_list[2];
   return std::make_shared<T>(elements);
@@ -264,12 +279,12 @@ AbstractBasePtr InferImplDictGetItem(const AnalysisEnginePtr &, const PrimitiveP
   ValuePtr key_value = key->BuildValue();
   if (!key_value->isa<StringImm>()) {
-    MS_LOG(EXCEPTION) << "" << op_name << " evaluator key should be string, but got " << key_value->ToString();
+    MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << key_value->ToString();
   }
-  std::string key_str = GetValue<std::string>(key_value);
+  auto key_str = GetValue<std::string>(key_value);
   std::vector<AbstractAttribute> dict_elems = dict->elements();
   auto it = std::find_if(dict_elems.begin(), dict_elems.end(),
-                         [key_str](AbstractAttribute &item) { return item.first == key_str; });
+                         [key_str](const AbstractAttribute &item) { return item.first == key_str; });
   if (it == dict_elems.end()) {
     MS_LOG(EXCEPTION) << "The key " << key_str << " does not exist in the dict:" << args_spec_list[0]->ToString();
@@ -287,7 +302,7 @@ AbstractBasePtr InferImplDictSetItem(const AnalysisEnginePtr &, const PrimitiveP
   ValuePtr key_value = key->BuildValue();
   if (!key_value->isa<StringImm>()) {
-    MS_LOG(EXCEPTION) << "" << op_name << " evaluator key should be string, but got " << key_value->ToString();
+    MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << key_value->ToString();
   }
   std::string key_str = GetValue<std::string>(key_value);
   std::vector<AbstractAttribute> dict_elems = dict->elements();
@@ -446,27 +461,27 @@ AbstractBasePtr InferImplReduceShape(const AnalysisEnginePtr &, const PrimitiveP
   auto x_shp_value = shape_x->BuildValue();
   if (x_shp_value->isa<AnyValue>()) {
-    MS_LOG(EXCEPTION) << "" << op_name
+    MS_LOG(EXCEPTION) << op_name
                       << " evaluator shape's data field can't be anything: " << args_spec_list[1]->ToString();
   }

   // Axis can be scalar, tuple or None
   AbstractTuplePtr axis = nullptr;
   if (args_spec_list[1]->isa<AbstractScalar>()) {
-    MS_LOG(DEBUG) << "" << op_name << " evaluator second parameter is scalar";
+    MS_LOG(DEBUG) << op_name << " evaluator second parameter is scalar";
     AbstractBasePtrList axis_list = {dyn_cast<AbstractScalar>(args_spec_list[1])};
     axis = std::make_shared<AbstractTuple>(axis_list);
   } else if (args_spec_list[1]->isa<AbstractTuple>()) {
-    MS_LOG(DEBUG) << "" << op_name << " evaluator second parameter is tuple";
+    MS_LOG(DEBUG) << op_name << " evaluator second parameter is tuple";
     axis = args_spec_list[1]->cast<AbstractTuplePtr>();
   } else {
-    MS_LOG(EXCEPTION) << "" << op_name << " evaluator second parameter should be a scalar or tuple, but got "
+    MS_LOG(EXCEPTION) << op_name << " evaluator second parameter should be a scalar or tuple, but got "
                       << args_spec_list[1]->ToString();
   }

   auto axis_value = axis->BuildValue();
   if (axis_value->isa<AnyValue>()) {
-    MS_LOG(EXCEPTION) << "" << op_name
+    MS_LOG(EXCEPTION) << op_name
                       << " evaluator shape's data field can't be anything: " << args_spec_list[1]->ToString();
   }
   auto axis_value_ptr = axis_value->cast<ValueTuplePtr>();
@@ -24,36 +24,35 @@ namespace mindspore {
 namespace prim {
 PrimToFunction::PrimToFunction()
-    : prim_func_type_map_({
-        // ONE_ARG prim
-        {"bool_not", kPrimTypeOneArg},
-        {"scalar_cos", kPrimTypeOneArg},
-        {"scalar_exp", kPrimTypeOneArg},
-        {"scalar_floor", kPrimTypeOneArg},
-        {"scalar_log", kPrimTypeOneArg},
-        {"scalar_sin", kPrimTypeOneArg},
-        {"scalar_tan", kPrimTypeOneArg},
-        {"scalar_trunc", kPrimTypeOneArg},
-        {"typeof", kPrimTypeOneArg},
-        {"scalar_uadd", kPrimTypeOneArg},
-        {"scalar_usub", kPrimTypeOneArg},
-        // TWO_ARGS prim
-        {"scalar_add", kPrimTypeTwoArgs},
-        {"bool_and", kPrimTypeTwoArgs},
-        {"bool_eq", kPrimTypeTwoArgs},
-        {"bool_or", kPrimTypeTwoArgs},
-        {"scalar_div", kPrimTypeTwoArgs},
-        {"scalar_eq", kPrimTypeTwoArgs},
-        {"scalar_ge", kPrimTypeTwoArgs},
-        {"scalar_gt", kPrimTypeTwoArgs},
-        {"scalar_le", kPrimTypeTwoArgs},
-        {"scalar_lt", kPrimTypeTwoArgs},
-        {"scalar_ne", kPrimTypeTwoArgs},
-        {"scalar_mod", kPrimTypeTwoArgs},
-        {"scalar_mul", kPrimTypeTwoArgs},
-        {"scalar_pow", kPrimTypeTwoArgs},
-        {"scalar_sub", kPrimTypeTwoArgs},
-      }) {}
+    : prim_func_type_map_({// ONE_ARG prim
+                           {"bool_not", kPrimTypeOneArg},
+                           {"scalar_cos", kPrimTypeOneArg},
+                           {"scalar_exp", kPrimTypeOneArg},
+                           {"scalar_floor", kPrimTypeOneArg},
+                           {"scalar_log", kPrimTypeOneArg},
+                           {"scalar_sin", kPrimTypeOneArg},
+                           {"scalar_tan", kPrimTypeOneArg},
+                           {"scalar_trunc", kPrimTypeOneArg},
+                           {"typeof", kPrimTypeOneArg},
+                           {"scalar_uadd", kPrimTypeOneArg},
+                           {"scalar_usub", kPrimTypeOneArg},
+                           // TWO_ARGS prim
+                           {"scalar_add", kPrimTypeTwoArgs},
+                           {"bool_and", kPrimTypeTwoArgs},
+                           {"bool_eq", kPrimTypeTwoArgs},
+                           {"bool_or", kPrimTypeTwoArgs},
+                           {"scalar_div", kPrimTypeTwoArgs},
+                           {"scalar_eq", kPrimTypeTwoArgs},
+                           {"scalar_ge", kPrimTypeTwoArgs},
+                           {"scalar_gt", kPrimTypeTwoArgs},
+                           {"scalar_le", kPrimTypeTwoArgs},
+                           {"scalar_lt", kPrimTypeTwoArgs},
+                           {"scalar_ne", kPrimTypeTwoArgs},
+                           {"scalar_mod", kPrimTypeTwoArgs},
+                           {"scalar_mul", kPrimTypeTwoArgs},
+                           {"scalar_pow", kPrimTypeTwoArgs},
+                           {"scalar_sub", kPrimTypeTwoArgs},
+                           {"scalar_floordiv", kPrimTypeTwoArgs}}) {}

 bool PrimToFunction::GetFunction(const PrimitivePtr& prim, FunctionPtr* const func) const {
   bool result = false;
@@ -52,6 +52,8 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
     {prim::kPrimSwitch, {InferImplSwitch, true}},
     {prim::kPrimIs_, {InferImplIs_, true}},
     {prim::kPrimIsNot, {InferImplIsNot, true}},
+    {prim::kPrimInDict, {InferImplInDict, true}},
+    {prim::kPrimNotInDict, {InferImplNotInDict, true}},
     // Maths
     {prim::kPrimMaximumGrad, {InferImplMinOrMaxGrad, true}},
     {prim::kPrimMinimumGrad, {InferImplMinOrMaxGrad, true}},
@@ -91,6 +93,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
     {prim::kPrimMakeRange, {InferImplMakeRange, false}},
     {prim::kPrimStopGradient, {InferImplStopGradient, false}},
     {prim::kPrimStringEqual, {InferImplStringEqual, false}},
+    {prim::kPrimStringConcat, {InferImplStringConcat, false}},
     {prim::kPrimDictLen, {InferImplDictLen, false}},
     // NN
     {prim::kPrimPooling, {InferImplPooling, true}},
@@ -988,6 +991,8 @@ PrimitiveToImplMap &GetUniformPrimitiveToImplMap() {
     {prim::kPrimScalarMul, {prim::ScalarMul, true, nullptr, true}},
     {prim::kPrimScalarDiv, {prim::ScalarDiv, true, nullptr, true}},
     {prim::kPrimScalarMod, {prim::ScalarMod, true, nullptr, true}},
+    {prim::kPrimScalarPow, {prim::ScalarPow, true, nullptr, true}},
+    {prim::kPrimScalarFloordiv, {prim::ScalarFloordiv, true, nullptr, true}},
     {prim::kPrimScalarUadd, {prim::ScalarUAdd, true, nullptr, true}},
     {prim::kPrimScalarUsub, {prim::ScalarUSub, true, nullptr, true}},
     {prim::kPrimScalarLog, {prim::ScalarLog, true, nullptr, true}},
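
With ScalarPow and ScalarFloordiv registered in the uniform-implementation map, pure-scalar expressions are evaluated by the compiler itself; in `test_ops` below, `self.x ** self.y` and `self.x // self.y` on the constants 9 and 2 fold to plain numbers before any kernel runs. In Python terms:

    x, y = 9, 2
    assert x ** y == 81 and x // y == 4 and x % y == 1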
@@ -178,6 +178,10 @@ AbstractBasePtr InferImplIs_(const AnalysisEnginePtr &, const PrimitivePtr &,
                              const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplIsNot(const AnalysisEnginePtr &, const PrimitivePtr &,
                                const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplInDict(const AnalysisEnginePtr &, const PrimitivePtr &,
+                                const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplNotInDict(const AnalysisEnginePtr &, const PrimitivePtr &,
+                                   const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplPooling(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                  const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplPoolingGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                      const AbstractBasePtrList &args_spec_list);
@@ -287,6 +291,8 @@ AbstractBasePtr InferImplStopGradient(const AnalysisEnginePtr &, const Primitive
                                       const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplStringEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                      const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplStringConcat(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                      const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplDictLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                  const AbstractBasePtrList &args_spec_list);
@@ -19,6 +19,9 @@ from .add_impl import add
 from .sub_impl import sub
 from .mul_impl import mul
 from .div_impl import div
+from .pow_impl import pow_
+from .floordiv_impl import floordiv
+from .mod_impl import mod
 from .getitem_impl import getitem
 from .zeros_like_impl import zeros_like
 from .ones_like_impl import ones_like
@@ -38,6 +41,9 @@ __all__ = [
     'sub',
     'mul',
     'div',
+    'pow_',
+    'floordiv',
+    'mod',
     'uadd',
     'zeros_like',
     'ones_like',
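
After these exports, the new graphs sit alongside `add` and `div`, matching the `multitype_ops.pow_`, `multitype_ops.floordiv`, and `multitype_ops.mod` references in convert_object_map at the top of this PR. A quick check (import path assumed from the package layout this PR edits):

    from mindspore.ops.composite import multitype_ops

    print(multitype_ops.pow_, multitype_ops.floordiv, multitype_ops.mod)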
@@ -69,6 +69,21 @@ def _scalar_add_scalar(x, y):
     return F.scalar_add(x, y)

+
+@add.register("String", "String")
+def _string_concat_string(x, y):
+    """
+    Concatenate the string y to the string x.
+
+    Args:
+        x (str): The first input string.
+        y (str): The second input string.
+
+    Returns:
+        str, concatenate the y to the x.
+    """
+    return F.string_concat(x, y)

 @add.register("Number", "Tensor")
 def _scalar_add_tensor(x, y):
     """
@@ -81,8 +96,7 @@ def _scalar_add_tensor(x, y):
     Returns:
         Tensor, has the same dtype as x.
     """
-    z = F.scalar_to_tensor(x, F.dtype(y))
-    return F.tensor_add(z, y)
+    return F.tensor_add(x, y)


 @add.register("Tensor", "Number")
@@ -97,8 +111,7 @@ def _tensor_add_scalar(x, y):
     Returns:
         Tensor, has the same dtype as x.
     """
-    z = F.scalar_to_tensor(y, F.dtype(x))
-    return F.tensor_add(x, z)
+    return F.tensor_add(x, y)


 @add.register("Tensor", "Tensor")
@@ -68,8 +68,7 @@ def _scalar_div_tensor(x, y):
     Returns:
         Tensor, has the same dtype as x.
     """
-    z = F.scalar_to_tensor(x, F.dtype(y))
-    return F.tensor_div(z, y)
+    return F.tensor_div(x, y)


 @div.register("Tensor", "Number")
@@ -84,5 +83,4 @@ def _tensor_div_scalar(x, y):
     Returns:
         Tensor, has the same dtype as x.
     """
-    z = F.scalar_to_tensor(y, F.dtype(x))
-    return F.tensor_div(x, z)
+    return F.tensor_div(x, y)
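
As with `add`, the `div` overloads above now pass the scalar operand straight through and rely on the implicit auto-cast shown earlier, instead of materializing an intermediate tensor with `scalar_to_tensor`.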
@@ -0,0 +1,50 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Implementation for internal polymorphism `floordiv` operations."""
+
+from ...composite import base
+from ... import functional as F
+
+
+floordiv = base.MultitypeFuncGraph("floordiv")
+"""
+`floordiv` is a metafuncgraph object which will compute the floordiv of two objects
+using ".register" decorator.
+"""
+
+
+@floordiv.register("Number", "Number")
+def _floordiv_scalar(x, y):
+    """Returns x // y where x and y are all scalars."""
+    return F.scalar_floordiv(x, y)
+
+
+@floordiv.register("Tensor", "Tensor")
+def _floordiv_tensor(x, y):
+    """Returns x // y where x and y are all tensors and have the same dtype."""
+    return F.tensor_floordiv(x, y)
+
+
+@floordiv.register("Tensor", "Number")
+def _tensor_floordiv_scalar(x, y):
+    """Returns x // y where x is a tensor and y is a scalar. x and y should have the same dtype."""
+    return F.tensor_floordiv(x, y)
+
+
+@floordiv.register("Number", "Tensor")
+def _scalar_floordiv_tensor(x, y):
+    """Returns x // y where x is a scalar and y is a tensor. x and y should have the same dtype."""
+    return F.tensor_floordiv(x, y)
@@ -0,0 +1,50 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Implementation for internal polymorphism `mod` operations."""
+
+from ...composite import base
+from ... import functional as F
+
+
+mod = base.MultitypeFuncGraph("mod")
+"""
+`mod` is a metafuncgraph object which will compute the mod of two objects
+using ".register" decorator.
+"""
+
+
+@mod.register("Number", "Number")
+def _mod_scalar(x, y):
+    """Returns x % y where x and y are all scalars."""
+    return F.scalar_mod(x, y)
+
+
+@mod.register("Tensor", "Tensor")
+def _mod_tensor(x, y):
+    """Returns x % y where x and y are all tensors and have the same dtype."""
+    return F.tensor_mod(x, y)
+
+
+@mod.register("Tensor", "Number")
+def _tensor_mod_scalar(x, y):
+    """Returns x % y where x is a tensor and y is a scalar. x and y should have the same dtype."""
+    return F.tensor_mod(x, y)
+
+
+@mod.register("Number", "Tensor")
+def _scalar_mod_tensor(x, y):
+    """Returns x % y where x is a scalar and y is a tensor. x and y should have the same dtype."""
+    return F.tensor_mod(x, y)
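
Note that `F.tensor_mod` is bound to `P.FloorMod` in the functional.py hunk further down, so tensor `%` follows the same floored-modulo convention as Python's own `%`. A plain-Python reminder of that convention:

    # Floored modulo: the result takes the sign of the divisor.
    assert 7 % 3 == 1
    assert -7 % 3 == 2
    assert 7 % -3 == -2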
@@ -56,8 +56,7 @@ def _scalar_mul_tensor(x, y):
     Outputs:
         Tensor, has the same dtype as x.
     """
-    z = F.scalar_to_tensor(x, F.dtype(y))
-    return F.tensor_mul(z, y)
+    return F.tensor_mul(x, y)


 @mul.register("Tensor", "Number")
@@ -68,5 +67,4 @@ def _tensor_mul_scalar(x, y):
     Outputs:
         Tensor, has the same dtype as x.
     """
-    z = F.scalar_to_tensor(y, F.dtype(x))
-    return F.tensor_mul(x, z)
+    return F.tensor_mul(x, y)
@@ -0,0 +1,50 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Implementation for internal polymorphism `pow` operations."""
+
+from ...composite import base
+from ... import functional as F
+
+
+pow_ = base.MultitypeFuncGraph("pow")
+"""
+`pow` is a metafuncgraph object which will compute the pow of two objects
+using ".register" decorator.
+"""
+
+
+@pow_.register("Number", "Number")
+def _pow_scalar(x, y):
+    """Returns x ** y where x and y are all scalars."""
+    return F.scalar_pow(x, y)
+
+
+@pow_.register("Tensor", "Tensor")
+def _pow_tensor(x, y):
+    """Returns x ** y where x and y are all tensors and have the same dtype."""
+    return F.tensor_pow(x, y)
+
+
+@pow_.register("Tensor", "Number")
+def _tensor_pow_scalar(x, y):
+    """Returns x ** y where x is a tensor and y is a scalar. x and y should have the same dtype."""
+    return F.tensor_pow(x, y)
+
+
+@pow_.register("Number", "Tensor")
+def _scalar_pow_tensor(x, y):
+    """Returns x ** y where x is a scalar and y is a tensor. x and y should have the same dtype."""
+    return F.tensor_pow(x, y)
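
All three new files share one shape: a MultitypeFuncGraph plus one registered overload per argument-type pair. Conceptually the dispatch behaves like this toy plain-Python stand-in (an illustration, not the MindSpore implementation):

    class MultiDispatch:
        """Toy stand-in for MultitypeFuncGraph: pick an overload by argument kinds."""
        def __init__(self, name):
            self.name = name
            self._registry = {}

        def register(self, *kinds):
            def deco(fn):
                self._registry[kinds] = fn
                return fn
            return deco

        def __call__(self, x, y):
            kind = lambda v: "Tensor" if hasattr(v, "shape") else "Number"
            return self._registry[(kind(x), kind(y))](x, y)

    pow_ = MultiDispatch("pow")

    @pow_.register("Number", "Number")
    def _pow_scalar(x, y):
        return x ** y

    assert pow_(3, 4) == 81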
@@ -41,12 +41,10 @@ def _sub_tensor(x, y):
 @sub.register("Number", "Tensor")
 def _scalar_sub_tensor(x, y):
     """Returns x - y where x is a scalar and y is a tensor. x and y should have the same dtype."""
-    z = F.scalar_to_tensor(x, F.dtype(y))
-    return F.tensor_sub(z, y)
+    return F.tensor_sub(x, y)


 @sub.register("Tensor", "Number")
 def _tensor_sub_scalar(x, y):
     """Returns x - y where x is a tensor and y is a scalar. x and y should have the same dtype."""
-    z = F.scalar_to_tensor(y, F.dtype(x))
-    return F.tensor_sub(x, z)
+    return F.tensor_sub(x, y)
@@ -48,6 +48,9 @@ tensor_ge = P.GreaterEqual()
 tensor_sub = P.Sub()
 tensor_mul = P.Mul()
 tensor_div = P.RealDiv()
+tensor_floordiv = P.FloorDiv()
+tensor_pow = P.Pow()
+tensor_mod = P.FloorMod()
 strided_slice = P.StridedSlice()
 same_type_shape = P.SameTypeShape()
 equal = P.Equal()
@@ -83,6 +86,7 @@ scalar_add = Primitive('scalar_add')
 scalar_mul = Primitive('scalar_mul')
 scalar_sub = Primitive('scalar_sub')
 scalar_div = Primitive('scalar_div')
+scalar_floordiv = Primitive('scalar_floordiv')
 scalar_log = Primitive('scalar_log')
 scalar_pow = Primitive('scalar_pow')
 scalar_gt = Primitive('scalar_gt')
@@ -95,6 +99,7 @@ scalar_uadd = Primitive('scalar_uadd')
 scalar_usub = Primitive('scalar_usub')
 scalar_mod = Primitive('scalar_mod')
 string_eq = Primitive('string_equal')
+string_concat = Primitive('string_concat')
 bool_not = Primitive("bool_not")
 bool_or = Primitive("bool_or")
 bool_and = Primitive("bool_and")
@@ -104,7 +109,8 @@ logical_not = P.LogicalNot()
 array_to_scalar = Primitive('array_to_scalar')
 is_ = Primitive("is_")
 is_not = Primitive("is_not")
+in_dict = Primitive("in_dict")
+not_in_dict = Primitive("not_in_dict")
 broadcast_gradient_args = Primitive('BroadcastGradientArgs')
 dot = Primitive('dot')
 array_reduce = Primitive('array_reduce')
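
These aliases bind the new operators to primitive instances, so tensor `//`, `**`, and `%` ultimately dispatch to `P.FloorDiv`, `P.Pow`, and `P.FloorMod`. Calling one directly (a sketch; the printed form depends on the backend):

    import numpy as np
    from mindspore import Tensor
    from mindspore.ops import operations as P

    floor_div = P.FloorDiv()
    x = Tensor(np.array([9.0, 8.0, 7.0]).astype(np.float32))
    y = Tensor(np.array([2.0, 3.0, 4.0]).astype(np.float32))
    print(floor_div(x, y))  # expect values [4., 2., 1.]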
@@ -667,8 +667,8 @@ class AddN(PrimitiveWithInfer):
         >>>         return self.addN(z)
         >>>
         >>> net = NetAddN()
-        >>> input_x = Tensor(np.array([1, 2, 3]), mindspore.int32)
-        >>> input_y = Tensor(np.array([4, 5, 6]), mindspore.int32)
+        >>> input_x = Tensor(np.array([1, 2, 3]), mindspore.float32)
+        >>> input_y = Tensor(np.array([4, 5, 6]), mindspore.float32)
         >>> net(input_x, input_y, input_x, input_y)
         Tensor([10, 14, 18], shape=(3,), dtype=mindspore.int32)
     """
@@ -131,3 +131,72 @@ def test_ME_arithmetic_operator_0070():
 def test_ME_logical_operator_0020():
     """ test_ME_logical_operator_0020 """
     logical_operator_base('or')
+
+
+def test_ops():
+    class OpsNet(Cell):
+        """ OpsNet definition """
+        def __init__(self, x, y):
+            super(OpsNet, self).__init__()
+            self.x = x
+            self.y = y
+            self.int = 4
+            self.float = 3.2
+            self.str_a = "hello"
+            self.str_b = "world"
+
+        def construct(self, x, y):
+            h = x // y
+            m = x ** y
+            n = x % y
+            r = self.x // self.y
+            s = self.x ** self.y
+            t = self.x % self.y
+            p = h + m + n
+            q = r + s + t
+            ret_pow = p ** q + q ** p
+            ret_mod = p % q + q % p
+            ret_floor = p // q + q // p
+            ret = ret_pow + ret_mod + ret_floor
+            if self.int > self.float:
+                if self.str_a + self.str_b == "helloworld":
+                    return ret
+            return x
+
+    net = OpsNet(9, 2)
+    x = Tensor(np.random.randint(low=1, high=10, size=(2, 3, 4), dtype=np.int32))
+    y = Tensor(np.random.randint(low=10, high=20, size=(2, 3, 4), dtype=np.int32))
+    context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
+    net(x, y)
+
+
+def test_in_dict():
+    class InDictNet(Cell):
+        """ InDictNet definition """
+        def __init__(self, key_in, key_not_in):
+            super(InDictNet, self).__init__()
+            self.key_in = key_in
+            self.key_not_in = key_not_in
+
+        def construct(self, x, y, z):
+            d = {"a": x, "b": y}
+            ret_in = 1
+            ret_not_in = 2
+            if self.key_in in d:
+                ret_in = d[self.key_in]
+            if self.key_not_in not in d:
+                ret_not_in = z
+            ret = ret_in + ret_not_in
+            return ret
+
+    net = InDictNet("a", "c")
+    x = Tensor(np.random.randint(low=1, high=10, size=(2, 3, 4), dtype=np.int32))
+    y = Tensor(np.random.randint(low=10, high=20, size=(2, 3, 4), dtype=np.int32))
+    z = Tensor(np.random.randint(low=20, high=30, size=(2, 3, 4), dtype=np.int32))
+    context.set_context(mode=context.GRAPH_MODE)
+    net(x, y, z)