diff --git a/mindspore/_extends/parse/resources.py b/mindspore/_extends/parse/resources.py index 5dd24ccf80..9fb357597e 100644 --- a/mindspore/_extends/parse/resources.py +++ b/mindspore/_extends/parse/resources.py @@ -83,9 +83,9 @@ convert_object_map = { T.mul: multitype_ops.mul, T.truediv: multitype_ops.div, T.getitem: multitype_ops.getitem, - T.floordiv: NO_IMPLEMENT, - T.mod: F.scalar_mod, - T.pow: F.scalar_pow, + T.floordiv: multitype_ops.floordiv, + T.mod: multitype_ops.mod, + T.pow: multitype_ops.pow_, T.matmul: F.dot, T.lshift: NO_IMPLEMENT, T.rshift: NO_IMPLEMENT, @@ -104,8 +104,8 @@ convert_object_map = { T.ge: multitype_ops.greater_equal, T.is_: F.is_, T.is_not: F.is_not, - T.contains: NO_IMPLEMENT, - T.not_contains: NO_IMPLEMENT, + T.contains: F.in_dict, + T.not_contains: F.not_in_dict, # system function T.len: M.ms_len, diff --git a/mindspore/ccsrc/operator/cc_implementations.cc b/mindspore/ccsrc/operator/cc_implementations.cc index 5ff49758b4..62b23b346f 100644 --- a/mindspore/ccsrc/operator/cc_implementations.cc +++ b/mindspore/ccsrc/operator/cc_implementations.cc @@ -103,7 +103,7 @@ T InnerScalarMul(T x, T y) { } template -T InnerScalarDiv(T x, T y) { +float InnerScalarDiv(T x, T y) { if (y == 0) { MS_LOG(EXCEPTION) << "Divisor could not be zero"; } @@ -111,23 +111,41 @@ T InnerScalarDiv(T x, T y) { MS_LOG(EXCEPTION) << "Overflow of the div of two signed number x: " << std::to_string(x) << ", y: " << std::to_string(y) << "."; } - return x / y; + return static_cast(x) / static_cast(y); } -int32_t InnerScalarMod(int32_t x, int32_t y) { +template +T InnerScalarFloordiv(T x, T y) { + auto ret = std::floor(InnerScalarDiv(x, y)); + if (std::is_integral::value) { + return static_cast(ret); + } + return ret; +} + +template +T InnerScalarMod(T x, T y) { if (y == 0) { MS_LOG(EXCEPTION) << "Could not mod to zero."; } - if (IsSignedIntOverflow(x, y, OpType::MOD)) { + if (std::is_integral::value && std::is_signed::value && IsSignedIntOverflow(x, y, OpType::MOD)) 
{ MS_LOG(EXCEPTION) << "Overflow of the mod of two signed number x: " << std::to_string(x) << ", y: " << std::to_string(y) << "."; } - return x % y; + if (std::is_integral::value) { + return static_cast(x) % static_cast(y); + } + float x_int = std::floor(x); + float y_int = std::ceil(y); + float max = x_int / y_int; + float ret = x - y * max; + return ret; } -float InnerScalarMod(float, float) { MS_LOG(EXCEPTION) << "Float does not support mod operator."; } - -double InnerScalarMod(double, double) { MS_LOG(EXCEPTION) << "Double does not support mod operator."; } +template +T InnerScalarPow(T x, U y) { + return std::pow(x, y); +} template bool InnerScalarEq(T x, U y) { @@ -193,6 +211,8 @@ SCALAR_OP(Sub) SCALAR_OP(Mul) SCALAR_OP(Div) SCALAR_OP(Mod) +SCALAR_OP(Pow) +SCALAR_OP(Floordiv) #define LOGIC_OP(op_t) \ ValuePtr Scalar##op_t(const ValuePtrList& list) { \ @@ -227,6 +247,10 @@ SCALAR_OP(Mod) bool sum = InnerScalar##op_t(GetValue(x), GetValue(y)); \ return MakeValue(sum); \ } \ + if (x->isa() && y->isa()) { \ + bool sum = InnerScalar##op_t(GetValue(x), GetValue(y)); \ + return MakeValue(sum); \ + } \ if (x->isa() && y->isa()) { \ bool sum = InnerScalar##op_t(GetValue(x), GetValue(y)); \ return MakeValue(sum); \ diff --git a/mindspore/ccsrc/operator/cc_implementations.h b/mindspore/ccsrc/operator/cc_implementations.h index 2c2936fc92..69981cea7d 100644 --- a/mindspore/ccsrc/operator/cc_implementations.h +++ b/mindspore/ccsrc/operator/cc_implementations.h @@ -37,9 +37,10 @@ ValuePtr ScalarSub(const ValuePtrList& list); ValuePtr ScalarMul(const ValuePtrList& list); ValuePtr ScalarDiv(const ValuePtrList& list); ValuePtr ScalarMod(const ValuePtrList& list); +ValuePtr ScalarPow(const ValuePtrList& list); +ValuePtr ScalarFloordiv(const ValuePtrList& list); ValuePtr ScalarUAdd(const ValuePtrList& list); ValuePtr ScalarUSub(const ValuePtrList& list); -ValuePtr ScalarUSub(const ValuePtrList& list); ValuePtr ScalarLog(const ValuePtrList& list); ValuePtr ScalarEq(const 
ValuePtrList& list); ValuePtr ScalarLt(const ValuePtrList& list); diff --git a/mindspore/ccsrc/operator/composite/do_signature.cc b/mindspore/ccsrc/operator/composite/do_signature.cc index 62de1c71f2..a4a26377f5 100644 --- a/mindspore/ccsrc/operator/composite/do_signature.cc +++ b/mindspore/ccsrc/operator/composite/do_signature.cc @@ -88,14 +88,17 @@ std::map GetMaxDtypeIndex(const std::vectorisa()) { - m_index = indexs[i]; + + for (const auto& index : indexs) { + AbstractBasePtr arg_value = args_spec_list[index]; + if (arg_value->isa()) { + arg_value = arg_value->cast()->ref(); + } + + if (arg_value->isa()) { + (void)dst_type.insert(std::make_pair(type, index)); + break; } - } - if (args_spec_list[m_index]->isa()) { - (void)dst_type.insert(std::make_pair(type, m_index)); } } return dst_type; @@ -119,15 +122,19 @@ void DoAutoCast(const std::vector& signature, const abstract::Abstrac (void)std::transform(signature.begin(), signature.end(), std::back_inserter(dtypes), [](const Signature& sig) { return sig.dtype; }); int empty_dtype_count = std::count(dtypes.begin(), dtypes.end(), SignatureEnumDType::kDTypeEmptyDefaultValue); - if (dtypes.size() == 0 || static_cast(dtypes.size()) == empty_dtype_count) { + if (dtypes.empty() || static_cast(dtypes.size()) == empty_dtype_count) { return; } // Stat the index of the arguments with the largest type in the same SignatureEnumDType. 
std::map dst_type = GetMaxDtypeIndex(dtypes, args_spec_list); // Identify which arg requires auto cast for (size_t i = 0; i < args_spec_list.size(); ++i) { + AbstractBasePtr arg_value = args_spec_list[i]; + if (arg_value->isa()) { + arg_value = arg_value->cast()->ref(); + } auto it = dst_type.find(dtypes[i]); - if (it == dst_type.end() || it->second == i || !args_spec_list[i]->isa()) { + if (it == dst_type.end() || it->second == i || !arg_value->isa()) { continue; } // get source node for cast diff --git a/mindspore/ccsrc/operator/ops.cc b/mindspore/ccsrc/operator/ops.cc index f3053cac7d..e190d7d0b2 100755 --- a/mindspore/ccsrc/operator/ops.cc +++ b/mindspore/ccsrc/operator/ops.cc @@ -28,6 +28,7 @@ const PrimitivePtr kPrimScalarAdd = std::make_shared("scalar_add"); const PrimitivePtr kPrimScalarSub = std::make_shared("scalar_sub"); const PrimitivePtr kPrimScalarMul = std::make_shared("scalar_mul"); const PrimitivePtr kPrimScalarDiv = std::make_shared("scalar_div"); +const PrimitivePtr kPrimScalarFloordiv = std::make_shared("scalar_floordiv"); const PrimitivePtr kPrimScalarMod = std::make_shared("scalar_mod"); const PrimitivePtr kPrimScalarPow = std::make_shared("scalar_pow"); const PrimitivePtr kPrimScalarTrunc = std::make_shared("scalar_trunc"); @@ -78,6 +79,7 @@ const PrimitivePtr kPrimCreateInstance = std::make_shared("create_ins // Structure const PrimitivePtr kPrimStringEqual = std::make_shared("string_equal"); +const PrimitivePtr kPrimStringConcat = std::make_shared("string_concat"); const PrimitivePtr kPrimMakeTuple = std::make_shared("make_tuple"); const PrimitivePtr kPrimMakeList = std::make_shared("make_list"); const PrimitivePtr kPrimMakeDict = std::make_shared("make_dict"); @@ -221,6 +223,8 @@ const PrimitivePtr kPrimBroadcastGradientArgs = std::make_shared("Bro const PrimitivePtr kPrimControlDepend = std::make_shared("ControlDepend"); const PrimitivePtr kPrimIs_ = std::make_shared("is_"); const PrimitivePtr kPrimIsNot = std::make_shared("is_not"); 
+const PrimitivePtr kPrimInDict = std::make_shared("in_dict"); +const PrimitivePtr kPrimNotInDict = std::make_shared("not_in_dict"); // Comm ops const PrimitivePtr kPrimMirror = std::make_shared("_MirrorOperator"); diff --git a/mindspore/ccsrc/operator/ops.h b/mindspore/ccsrc/operator/ops.h index 2dc7072972..0148e073e0 100755 --- a/mindspore/ccsrc/operator/ops.h +++ b/mindspore/ccsrc/operator/ops.h @@ -34,6 +34,7 @@ extern const PrimitivePtr kPrimScalarAdd; extern const PrimitivePtr kPrimScalarSub; extern const PrimitivePtr kPrimScalarMul; extern const PrimitivePtr kPrimScalarDiv; +extern const PrimitivePtr kPrimScalarFloordiv; extern const PrimitivePtr kPrimScalarMod; extern const PrimitivePtr kPrimScalarPow; extern const PrimitivePtr kPrimScalarTrunc; @@ -84,6 +85,7 @@ extern const PrimitivePtr kPrimCreateInstance; // Structure extern const PrimitivePtr kPrimStringEqual; +extern const PrimitivePtr kPrimStringConcat; extern const PrimitivePtr kPrimMakeTuple; extern const PrimitivePtr kPrimMakeList; extern const PrimitivePtr kPrimMakeDict; @@ -227,8 +229,8 @@ extern const PrimitivePtr kPrimBroadcastGradientArgs; extern const PrimitivePtr kPrimControlDepend; extern const PrimitivePtr kPrimIs_; extern const PrimitivePtr kPrimIsNot; -extern const PrimitivePtr kPrimMinimumGrad; -extern const PrimitivePtr kPrimMaximumGrad; +extern const PrimitivePtr kPrimInDict; +extern const PrimitivePtr kPrimNotInDict; // Comm ops extern const PrimitivePtr kPrimMirror; diff --git a/mindspore/ccsrc/operator/prim_nn.cc b/mindspore/ccsrc/operator/prim_nn.cc index 892bf2921e..3591168187 100644 --- a/mindspore/ccsrc/operator/prim_nn.cc +++ b/mindspore/ccsrc/operator/prim_nn.cc @@ -114,12 +114,12 @@ void FusedBatchNormCheckDim(const PrimitivePtr &primitive, const AbstractBasePtr AbstractTensorPtr arg = CheckArg(op_name, args_spec_list, i); ShapePtr arg_shape = dyn_cast(arg->GetShapeTrack()); if (arg_shape == nullptr) { - MS_LOG(EXCEPTION) << "" << op_name << " type of args[" << i << "] 
should be Shape, but " << arg->ToString(); + MS_LOG(EXCEPTION) << op_name << " type of args[" << i << "] should be Shape, but " << arg->ToString(); } if (i == 0) { if (arg_shape->shape().size() < 2) { - MS_LOG(EXCEPTION) << "" << op_name << " shape of args[" << i + MS_LOG(EXCEPTION) << op_name << " shape of args[" << i << "] should be TensorShape with dimension greater than 1, but shape: " << arg_shape->ToString(); } @@ -127,7 +127,7 @@ void FusedBatchNormCheckDim(const PrimitivePtr &primitive, const AbstractBasePtr } if (arg_shape->shape().size() != 1) { - MS_LOG(EXCEPTION) << "" << op_name << " shape of args[" << i + MS_LOG(EXCEPTION) << op_name << " shape of args[" << i << "] should be TensorShape with dimension: 1, but shape: " << arg_shape->ToString(); } } @@ -159,7 +159,7 @@ AbstractBasePtr InferImplFusedBatchNorm(const AnalysisEnginePtr &, const Primiti MS_LOG(EXCEPTION) << "Arg shape size should >= 1."; } if (arg_shape_list[0] != input_shape_list[1]) { - MS_LOG(EXCEPTION) << "" << op_name << " size of tensor param[" << i << "](which is " << arg_shape_list[0] + MS_LOG(EXCEPTION) << op_name << " size of tensor param[" << i << "](which is " << arg_shape_list[0] << ") should match the second dimension of tensor" " param[0](which is " << input_shape_list[1] << ")."; @@ -378,7 +378,7 @@ AbstractBasePtr InferImplDropoutGenMask(const AnalysisEnginePtr &, const Primiti TypePtr prob_type = keep_prob->element()->BuildType(); if ((prob_type->type_id() != kNumberTypeFloat16) && (prob_type->type_id() != kNumberTypeFloat32)) { - MS_LOG(EXCEPTION) << "" << op_name << " keep_prob type should be float16 or float32, but " << prob_type->ToString() + MS_LOG(EXCEPTION) << op_name << " keep_prob type should be float16 or float32, but " << prob_type->ToString() << "."; } diff --git a/mindspore/ccsrc/operator/prim_statement.cc b/mindspore/ccsrc/operator/prim_statement.cc index 7d5038d4e1..239aed5bde 100644 --- a/mindspore/ccsrc/operator/prim_statement.cc +++ 
b/mindspore/ccsrc/operator/prim_statement.cc @@ -169,5 +169,36 @@ AbstractBasePtr InferImplIsNot(const AnalysisEnginePtr &, const PrimitivePtr &pr return std::make_shared(!(*t == *x)); } + +bool IsInDict(const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list) { + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 2); + auto key = CheckArg(op_name, args_spec_list, 0); + auto dict = CheckArg(op_name, args_spec_list, 1); + + ValuePtr key_value = key->BuildValue(); + if (!key_value->isa()) { + MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << key_value->ToString(); + } + auto key_str = GetValue(key_value); + std::vector dict_elems = dict->elements(); + auto it = std::find_if(dict_elems.begin(), dict_elems.end(), + [key_str](const AbstractAttribute &item) { return item.first == key_str; }); + return it != dict_elems.end(); +} + +AbstractBasePtr InferImplInDict(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // statement: x in t + // Inputs: x, t + return std::make_shared(IsInDict(primitive, args_spec_list)); +} + +AbstractBasePtr InferImplNotInDict(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // statement: x not in t + // Inputs: x, t + return std::make_shared(!IsInDict(primitive, args_spec_list)); +} } // namespace abstract } // namespace mindspore diff --git a/mindspore/ccsrc/operator/prim_structures.cc b/mindspore/ccsrc/operator/prim_structures.cc index 88699c4d38..31d2bff43d 100644 --- a/mindspore/ccsrc/operator/prim_structures.cc +++ b/mindspore/ccsrc/operator/prim_structures.cc @@ -36,7 +36,7 @@ AbstractBasePtr InferImplStringEqual(const AnalysisEnginePtr &, const PrimitiveP ValuePtr value_x = scalar_x->BuildValue(); ValuePtr value_y = scalar_y->BuildValue(); if (!value_x->isa() || !value_y->isa()) { - MS_LOG(EXCEPTION) << "" << op_name << " requires 2 parameters 
are string, but got param0: " << value_x->ToString() + MS_LOG(EXCEPTION) << op_name << " requires 2 parameters are string, but got param0: " << value_x->ToString() << ", param1: " << value_y->ToString(); } @@ -44,6 +44,25 @@ AbstractBasePtr InferImplStringEqual(const AnalysisEnginePtr &, const PrimitiveP return std::make_shared(ret); } +AbstractBasePtr InferImplStringConcat(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: two scalars whose value is a string. + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 2); + AbstractScalarPtr scalar_x = CheckArg(op_name, args_spec_list, 0); + AbstractScalarPtr scalar_y = CheckArg(op_name, args_spec_list, 1); + + ValuePtr value_x = scalar_x->BuildValue(); + ValuePtr value_y = scalar_y->BuildValue(); + if (!value_x->isa() || !value_y->isa()) { + MS_LOG(EXCEPTION) << op_name << " requires 2 parameters are string, but got param0: " << value_x->ToString() + << ", param1: " << value_y->ToString(); + } + + std::string ret = (value_x->cast()->value() + value_y->cast()->value()); + return std::make_shared(ret); +} + AbstractBasePtr InferImplMakeTuple(const AnalysisEnginePtr &, const PrimitivePtr &, const AbstractBasePtrList &args_spec_list) { return std::make_shared(args_spec_list); @@ -64,7 +83,7 @@ AbstractBasePtr InferImplMakeDict(const AnalysisEnginePtr &, const PrimitivePtr size_t keys_size = keys->size(); if (values->size() != keys_size) { - MS_LOG(EXCEPTION) << "" << op_name << " evaluator keys' size is not equal with values' size"; + MS_LOG(EXCEPTION) << op_name << " evaluator keys' size is not equal with values' size"; } std::vector key_value; @@ -76,7 +95,7 @@ AbstractBasePtr InferImplMakeDict(const AnalysisEnginePtr &, const PrimitivePtr ValuePtr keyPtr = key->BuildValue(); MS_EXCEPTION_IF_NULL(keyPtr); if (!keyPtr->isa()) { - MS_LOG(EXCEPTION) << "" << op_name << " evaluator keys should be string, but got " << 
keyPtr->ToString(); + MS_LOG(EXCEPTION) << op_name << " evaluator keys should be string, but got " << keyPtr->ToString(); } std::string key_string = GetValue(keyPtr); key_value.emplace_back(key_string, value_list[index]); @@ -93,7 +112,7 @@ AbstractBasePtr InferImplMakeKwarg(const AnalysisEnginePtr &, const PrimitivePtr ValuePtr keyPtr = key->BuildValue(); if (!keyPtr->isa()) { - MS_LOG(EXCEPTION) << "" << op_name << " evaluator key should be string, but got " << keyPtr->ToString(); + MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << keyPtr->ToString(); } std::string key_string = GetValue(keyPtr); return std::make_shared(key_string, args_spec_list[1]); @@ -109,14 +128,13 @@ AbstractBasePtr InferImplExtractKwarg(const AnalysisEnginePtr &, const Primitive ValuePtr key_value = key->BuildValue(); if (!key_value->isa()) { - MS_LOG(EXCEPTION) << "" << op_name << " evaluator key should be string, but got " << key_value->ToString(); + MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << key_value->ToString(); } std::string key_input = GetValue(key_value); std::string key_actual = kwarg->get_key(); if (key_actual != key_input) { - MS_LOG(EXCEPTION) << "" << op_name - << " evaluator input key should be same as AbstractKeywordArg' key, but input is " << key_input - << ", AbstractKeywordArg' key is " << key_actual; + MS_LOG(EXCEPTION) << op_name << " evaluator input key should be same as AbstractKeywordArg' key, but input is " + << key_input << ", AbstractKeywordArg' key is " << key_actual; } return kwarg->get_arg(); } @@ -187,13 +205,12 @@ AbstractBasePtr InferTupleOrListGetItem(const std::string &op_name, const Abstra ValuePtr index_value = index->BuildValue(); if (!index_value->isa()) { - MS_LOG(EXCEPTION) << "" << op_name << " evaluator index should be an int32 number, but got " - << index_value->ToString(); + MS_LOG(EXCEPTION) << op_name << " evaluator index should be an int32 number, but got " << 
index_value->ToString(); } int idx_v = GetValue(index_value); std::size_t nelems = queue->elements().size(); if (idx_v >= SizeToInt(nelems) || idx_v < -SizeToInt(nelems)) { - MS_LOG(EXCEPTION) << "" << op_name << " evaluator index should be in range[-" << SizeToInt(nelems) << ", " + MS_LOG(EXCEPTION) << op_name << " evaluator index should be in range[-" << SizeToInt(nelems) << ", " << SizeToInt(nelems) << "), but got " << idx_v << "."; } @@ -215,8 +232,7 @@ AbstractBasePtr InferTupleOrListSetItem(const std::string &op_name, const Abstra ValuePtr index_value = index->BuildValue(); if (!index_value->isa()) { - MS_LOG(EXCEPTION) << "" << op_name << " evaluator index should be an int32 number, but got " - << index_value->ToString(); + MS_LOG(EXCEPTION) << op_name << " evaluator index should be an int32 number, but got " << index_value->ToString(); } int idx_v = GetValue(index_value); if (idx_v < 0) { @@ -227,8 +243,7 @@ AbstractBasePtr InferTupleOrListSetItem(const std::string &op_name, const Abstra AbstractBasePtrList elements = queue->elements(); std::size_t nelems = elements.size(); if (uidx_v >= nelems) { - MS_LOG(EXCEPTION) << "" << op_name << " evaluator the index: " << uidx_v << " to set out of range: " << nelems - 1 - << "."; + MS_LOG(EXCEPTION) << op_name << " evaluator the index: " << uidx_v << " to set out of range: " << nelems - 1 << "."; } elements[uidx_v] = args_spec_list[2]; return std::make_shared(elements); @@ -264,12 +279,12 @@ AbstractBasePtr InferImplDictGetItem(const AnalysisEnginePtr &, const PrimitiveP ValuePtr key_value = key->BuildValue(); if (!key_value->isa()) { - MS_LOG(EXCEPTION) << "" << op_name << " evaluator key should be string, but got " << key_value->ToString(); + MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << key_value->ToString(); } - std::string key_str = GetValue(key_value); + auto key_str = GetValue(key_value); std::vector dict_elems = dict->elements(); auto it = std::find_if(dict_elems.begin(), 
dict_elems.end(), - [key_str](AbstractAttribute &item) { return item.first == key_str; }); + [key_str](const AbstractAttribute &item) { return item.first == key_str; }); if (it == dict_elems.end()) { MS_LOG(EXCEPTION) << "The key " << key_str << " does not exist in the dict:" << args_spec_list[0]->ToString(); @@ -287,7 +302,7 @@ AbstractBasePtr InferImplDictSetItem(const AnalysisEnginePtr &, const PrimitiveP ValuePtr key_value = key->BuildValue(); if (!key_value->isa()) { - MS_LOG(EXCEPTION) << "" << op_name << " evaluator key should be string, but got " << key_value->ToString(); + MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << key_value->ToString(); } std::string key_str = GetValue(key_value); std::vector dict_elems = dict->elements(); @@ -446,27 +461,27 @@ AbstractBasePtr InferImplReduceShape(const AnalysisEnginePtr &, const PrimitiveP auto x_shp_value = shape_x->BuildValue(); if (x_shp_value->isa()) { - MS_LOG(EXCEPTION) << "" << op_name + MS_LOG(EXCEPTION) << op_name << " evaluator shape's data field can't be anything: " << args_spec_list[1]->ToString(); } // Axis can be scalar, tuple or None AbstractTuplePtr axis = nullptr; if (args_spec_list[1]->isa()) { - MS_LOG(DEBUG) << "" << op_name << " evaluator second parameter is scalar"; + MS_LOG(DEBUG) << op_name << " evaluator second parameter is scalar"; AbstractBasePtrList axis_list = {dyn_cast(args_spec_list[1])}; axis = std::make_shared(axis_list); } else if (args_spec_list[1]->isa()) { - MS_LOG(DEBUG) << "" << op_name << " evaluator second parameter is tuple"; + MS_LOG(DEBUG) << op_name << " evaluator second parameter is tuple"; axis = args_spec_list[1]->cast(); } else { - MS_LOG(EXCEPTION) << "" << op_name << " evaluator second parameter should be a scalar or tuple, but got " + MS_LOG(EXCEPTION) << op_name << " evaluator second parameter should be a scalar or tuple, but got " << args_spec_list[1]->ToString(); } auto axis_value = axis->BuildValue(); if (axis_value->isa()) { - 
MS_LOG(EXCEPTION) << "" << op_name + MS_LOG(EXCEPTION) << op_name << " evaluator shape's data field can't be anything: " << args_spec_list[1]->ToString(); } auto axis_value_ptr = axis_value->cast(); diff --git a/mindspore/ccsrc/operator/prim_to_function.cc b/mindspore/ccsrc/operator/prim_to_function.cc index 234c829d44..bdfe48157c 100644 --- a/mindspore/ccsrc/operator/prim_to_function.cc +++ b/mindspore/ccsrc/operator/prim_to_function.cc @@ -24,36 +24,35 @@ namespace mindspore { namespace prim { PrimToFunction::PrimToFunction() - : prim_func_type_map_({ - // ONE_ARG prim - {"bool_not", kPrimTypeOneArg}, - {"scalar_cos", kPrimTypeOneArg}, - {"scalar_exp", kPrimTypeOneArg}, - {"scalar_floor", kPrimTypeOneArg}, - {"scalar_log", kPrimTypeOneArg}, - {"scalar_sin", kPrimTypeOneArg}, - {"scalar_tan", kPrimTypeOneArg}, - {"scalar_trunc", kPrimTypeOneArg}, - {"typeof", kPrimTypeOneArg}, - {"scalar_uadd", kPrimTypeOneArg}, - {"scalar_usub", kPrimTypeOneArg}, - // TWO_ARGS prim - {"scalar_add", kPrimTypeTwoArgs}, - {"bool_and", kPrimTypeTwoArgs}, - {"bool_eq", kPrimTypeTwoArgs}, - {"bool_or", kPrimTypeTwoArgs}, - {"scalar_div", kPrimTypeTwoArgs}, - {"scalar_eq", kPrimTypeTwoArgs}, - {"scalar_ge", kPrimTypeTwoArgs}, - {"scalar_gt", kPrimTypeTwoArgs}, - {"scalar_le", kPrimTypeTwoArgs}, - {"scalar_lt", kPrimTypeTwoArgs}, - {"scalar_ne", kPrimTypeTwoArgs}, - {"scalar_mod", kPrimTypeTwoArgs}, - {"scalar_mul", kPrimTypeTwoArgs}, - {"scalar_pow", kPrimTypeTwoArgs}, - {"scalar_sub", kPrimTypeTwoArgs}, - }) {} + : prim_func_type_map_({// ONE_ARG prim + {"bool_not", kPrimTypeOneArg}, + {"scalar_cos", kPrimTypeOneArg}, + {"scalar_exp", kPrimTypeOneArg}, + {"scalar_floor", kPrimTypeOneArg}, + {"scalar_log", kPrimTypeOneArg}, + {"scalar_sin", kPrimTypeOneArg}, + {"scalar_tan", kPrimTypeOneArg}, + {"scalar_trunc", kPrimTypeOneArg}, + {"typeof", kPrimTypeOneArg}, + {"scalar_uadd", kPrimTypeOneArg}, + {"scalar_usub", kPrimTypeOneArg}, + // TWO_ARGS prim + {"scalar_add", kPrimTypeTwoArgs}, + 
{"bool_and", kPrimTypeTwoArgs}, + {"bool_eq", kPrimTypeTwoArgs}, + {"bool_or", kPrimTypeTwoArgs}, + {"scalar_div", kPrimTypeTwoArgs}, + {"scalar_eq", kPrimTypeTwoArgs}, + {"scalar_ge", kPrimTypeTwoArgs}, + {"scalar_gt", kPrimTypeTwoArgs}, + {"scalar_le", kPrimTypeTwoArgs}, + {"scalar_lt", kPrimTypeTwoArgs}, + {"scalar_ne", kPrimTypeTwoArgs}, + {"scalar_mod", kPrimTypeTwoArgs}, + {"scalar_mul", kPrimTypeTwoArgs}, + {"scalar_pow", kPrimTypeTwoArgs}, + {"scalar_sub", kPrimTypeTwoArgs}, + {"scalar_floordiv", kPrimTypeTwoArgs}}) {} bool PrimToFunction::GetFunction(const PrimitivePtr& prim, FunctionPtr* const func) const { bool result = false; diff --git a/mindspore/ccsrc/pipeline/static_analysis/prim.cc b/mindspore/ccsrc/pipeline/static_analysis/prim.cc index 403bbdf433..1512596cb4 100644 --- a/mindspore/ccsrc/pipeline/static_analysis/prim.cc +++ b/mindspore/ccsrc/pipeline/static_analysis/prim.cc @@ -52,6 +52,8 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { {prim::kPrimSwitch, {InferImplSwitch, true}}, {prim::kPrimIs_, {InferImplIs_, true}}, {prim::kPrimIsNot, {InferImplIsNot, true}}, + {prim::kPrimInDict, {InferImplInDict, true}}, + {prim::kPrimNotInDict, {InferImplNotInDict, true}}, // Maths {prim::kPrimMaximumGrad, {InferImplMinOrMaxGrad, true}}, {prim::kPrimMinimumGrad, {InferImplMinOrMaxGrad, true}}, @@ -91,6 +93,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { {prim::kPrimMakeRange, {InferImplMakeRange, false}}, {prim::kPrimStopGradient, {InferImplStopGradient, false}}, {prim::kPrimStringEqual, {InferImplStringEqual, false}}, + {prim::kPrimStringConcat, {InferImplStringConcat, false}}, {prim::kPrimDictLen, {InferImplDictLen, false}}, // NN {prim::kPrimPooling, {InferImplPooling, true}}, @@ -988,6 +991,8 @@ PrimitiveToImplMap &GetUniformPrimitiveToImplMap() { {prim::kPrimScalarMul, {prim::ScalarMul, true, nullptr, true}}, {prim::kPrimScalarDiv, {prim::ScalarDiv, true, nullptr, true}}, {prim::kPrimScalarMod, {prim::ScalarMod, true, nullptr, true}}, + 
{prim::kPrimScalarPow, {prim::ScalarPow, true, nullptr, true}}, + {prim::kPrimScalarFloordiv, {prim::ScalarFloordiv, true, nullptr, true}}, {prim::kPrimScalarUadd, {prim::ScalarUAdd, true, nullptr, true}}, {prim::kPrimScalarUsub, {prim::ScalarUSub, true, nullptr, true}}, {prim::kPrimScalarLog, {prim::ScalarLog, true, nullptr, true}}, diff --git a/mindspore/ccsrc/pipeline/static_analysis/prim.h b/mindspore/ccsrc/pipeline/static_analysis/prim.h index e154473dbb..be71f3200a 100644 --- a/mindspore/ccsrc/pipeline/static_analysis/prim.h +++ b/mindspore/ccsrc/pipeline/static_analysis/prim.h @@ -178,6 +178,10 @@ AbstractBasePtr InferImplIs_(const AnalysisEnginePtr &, const PrimitivePtr &, const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplIsNot(const AnalysisEnginePtr &, const PrimitivePtr &, const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplInDict(const AnalysisEnginePtr &, const PrimitivePtr &, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplNotInDict(const AnalysisEnginePtr &, const PrimitivePtr &, + const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplPooling(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplPoolingGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, @@ -287,6 +291,8 @@ AbstractBasePtr InferImplStopGradient(const AnalysisEnginePtr &, const Primitive const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplStringEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplStringConcat(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplDictLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); diff --git a/mindspore/ops/composite/multitype_ops/__init__.py 
b/mindspore/ops/composite/multitype_ops/__init__.py index db28b1b5f6..40bf71d49a 100644 --- a/mindspore/ops/composite/multitype_ops/__init__.py +++ b/mindspore/ops/composite/multitype_ops/__init__.py @@ -19,6 +19,9 @@ from .add_impl import add from .sub_impl import sub from .mul_impl import mul from .div_impl import div +from .pow_impl import pow_ +from .floordiv_impl import floordiv +from .mod_impl import mod from .getitem_impl import getitem from .zeros_like_impl import zeros_like from .ones_like_impl import ones_like @@ -38,6 +41,9 @@ __all__ = [ 'sub', 'mul', 'div', + 'pow_', + 'floordiv', + 'mod', 'uadd', 'zeros_like', 'ones_like', diff --git a/mindspore/ops/composite/multitype_ops/add_impl.py b/mindspore/ops/composite/multitype_ops/add_impl.py index 2b1f83679e..2ad81bfc93 100644 --- a/mindspore/ops/composite/multitype_ops/add_impl.py +++ b/mindspore/ops/composite/multitype_ops/add_impl.py @@ -69,6 +69,21 @@ def _scalar_add_scalar(x, y): return F.scalar_add(x, y) +@add.register("String", "String") +def _string_concat_string(x, y): + """ + Concatenate the string y to the string x. + + Args: + x (str): The first input string. + y (str): the second input string. + + Returns: + str, concatenate the y to the x. + """ + return F.string_concat(x, y) + + @add.register("Number", "Tensor") def _scalar_add_tensor(x, y): """ @@ -81,8 +96,7 @@ def _scalar_add_tensor(x, y): Returns: Tensor, has the same dtype as x. """ - z = F.scalar_to_tensor(x, F.dtype(y)) - return F.tensor_add(z, y) + return F.tensor_add(x, y) @add.register("Tensor", "Number") @@ -97,8 +111,7 @@ def _tensor_add_scalar(x, y): Returns: Tensor, has the same dtype as x. 
""" - z = F.scalar_to_tensor(y, F.dtype(x)) - return F.tensor_add(x, z) + return F.tensor_add(x, y) @add.register("Tensor", "Tensor") diff --git a/mindspore/ops/composite/multitype_ops/div_impl.py b/mindspore/ops/composite/multitype_ops/div_impl.py index 3edf3c8d9f..c37fcb9c36 100644 --- a/mindspore/ops/composite/multitype_ops/div_impl.py +++ b/mindspore/ops/composite/multitype_ops/div_impl.py @@ -68,8 +68,7 @@ def _scalar_div_tensor(x, y): Returns: Tensor, has the same dtype as x. """ - z = F.scalar_to_tensor(x, F.dtype(y)) - return F.tensor_div(z, y) + return F.tensor_div(x, y) @div.register("Tensor", "Number") @@ -84,5 +83,4 @@ def _tensor_div_scalar(x, y): Returns: Tensor, has the same dtype as x. """ - z = F.scalar_to_tensor(y, F.dtype(x)) - return F.tensor_div(x, z) + return F.tensor_div(x, y) diff --git a/mindspore/ops/composite/multitype_ops/floordiv_impl.py b/mindspore/ops/composite/multitype_ops/floordiv_impl.py new file mode 100644 index 0000000000..c1a47f881f --- /dev/null +++ b/mindspore/ops/composite/multitype_ops/floordiv_impl.py @@ -0,0 +1,50 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Implementation for internal polymorphism `floordiv` operations.""" + +from ...composite import base +from ... 
import functional as F + + +floordiv = base.MultitypeFuncGraph("floordiv") +""" +`floordiv` is a metafuncgraph object which will compute the floordiv of two objects +using ".register" decorator. +""" + + +@floordiv.register("Number", "Number") +def _floordiv_scalar(x, y): + """Returns x // y where x and y are all scalars.""" + return F.scalar_floordiv(x, y) + + +@floordiv.register("Tensor", "Tensor") +def _floordiv_tensor(x, y): + """Returns x // y where x and y are all tensors and have save dtype.""" + return F.tensor_floordiv(x, y) + + +@floordiv.register("Tensor", "Number") +def _tensor_floordiv_scalar(x, y): + """Returns x // y where x is a tensor and y is a scalar. x and y should have same dtype.""" + return F.tensor_floordiv(x, y) + + +@floordiv.register("Number", "Tensor") +def _scalar_floordiv_tensor(x, y): + """Returns x // y where x is a scalar and y is a tensor. x and y should have same dtype.""" + return F.tensor_floordiv(x, y) diff --git a/mindspore/ops/composite/multitype_ops/mod_impl.py b/mindspore/ops/composite/multitype_ops/mod_impl.py new file mode 100644 index 0000000000..e9947677ac --- /dev/null +++ b/mindspore/ops/composite/multitype_ops/mod_impl.py @@ -0,0 +1,50 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Implementation for internal polymorphism `mod` operations.""" + +from ...composite import base +from ... 
import functional as F + + +mod = base.MultitypeFuncGraph("mod") +""" +`mod` is a metafuncgraph object which will compute the mod of two objects +using ".register" decorator. +""" + + +@mod.register("Number", "Number") +def _mod_scalar(x, y): + """Returns x % y where x and y are all scalars.""" + return F.scalar_mod(x, y) + + +@mod.register("Tensor", "Tensor") +def _mod_tensor(x, y): + """Returns x % y where x and y are all tensors and have same dtype.""" + return F.tensor_mod(x, y) + + +@mod.register("Tensor", "Number") +def _tensor_mod_scalar(x, y): + """Returns x % y where x is a tensor and y is a scalar. x and y should have same dtype.""" + return F.tensor_mod(x, y) + + +@mod.register("Number", "Tensor") +def _scalar_mod_tensor(x, y): + """Returns x % y where x is a scalar and y is a tensor. x and y should have same dtype.""" + return F.tensor_mod(x, y) diff --git a/mindspore/ops/composite/multitype_ops/mul_impl.py b/mindspore/ops/composite/multitype_ops/mul_impl.py index 1d4733a46b..ce9ec391af 100644 --- a/mindspore/ops/composite/multitype_ops/mul_impl.py +++ b/mindspore/ops/composite/multitype_ops/mul_impl.py @@ -56,8 +56,7 @@ def _scalar_mul_tensor(x, y): Outputs: Tensor, has the same dtype as x. """ - z = F.scalar_to_tensor(x, F.dtype(y)) - return F.tensor_mul(z, y) + return F.tensor_mul(x, y) @mul.register("Tensor", "Number") @@ -68,5 +67,4 @@ def _tensor_mul_scalar(x, y): Outputs: Tensor, has the same dtype as x. """ - z = F.scalar_to_tensor(y, F.dtype(x)) - return F.tensor_mul(x, z) + return F.tensor_mul(x, y) diff --git a/mindspore/ops/composite/multitype_ops/pow_impl.py b/mindspore/ops/composite/multitype_ops/pow_impl.py new file mode 100644 index 0000000000..8d73335c98 --- /dev/null +++ b/mindspore/ops/composite/multitype_ops/pow_impl.py @@ -0,0 +1,50 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Implementation for internal polymorphism `pow` operations.""" + +from ...composite import base +from ... import functional as F + + +pow_ = base.MultitypeFuncGraph("pow") +""" +`pow` is a metafuncgraph object which will compute the pow of two objects +using ".register" decorator. +""" + + +@pow_.register("Number", "Number") +def _pow_scalar(x, y): + """Returns x ** y where x and y are all scalars.""" + return F.scalar_pow(x, y) + + +@pow_.register("Tensor", "Tensor") +def _pow_tensor(x, y): + """Returns x ** y where x and y are all tensors and have same dtype.""" + return F.tensor_pow(x, y) + + +@pow_.register("Tensor", "Number") +def _tensor_pow_scalar(x, y): + """Returns x ** y where x is a tensor and y is a scalar. x and y should have same dtype.""" + return F.tensor_pow(x, y) + + +@pow_.register("Number", "Tensor") +def _scalar_pow_tensor(x, y): + """Returns x ** y where x is a scalar and y is a tensor. x and y should have same dtype.""" + return F.tensor_pow(x, y) diff --git a/mindspore/ops/composite/multitype_ops/sub_impl.py b/mindspore/ops/composite/multitype_ops/sub_impl.py index 4a3224a859..431a58b991 100644 --- a/mindspore/ops/composite/multitype_ops/sub_impl.py +++ b/mindspore/ops/composite/multitype_ops/sub_impl.py @@ -41,12 +41,10 @@ def _sub_tensor(x, y): @sub.register("Number", "Tensor") def _scalar_sub_tensor(x, y): """Returns x - y where x is a scalar and y is a tensor.
x and y should have same dtype.""" - z = F.scalar_to_tensor(x, F.dtype(y)) - return F.tensor_sub(z, y) + return F.tensor_sub(x, y) @sub.register("Tensor", "Number") def _tensor_sub_scalar(x, y): """Returns x - y where x is a tensor and y is a scalar. x and y should have same dtype.""" - z = F.scalar_to_tensor(y, F.dtype(x)) - return F.tensor_sub(x, z) + return F.tensor_sub(x, y) diff --git a/mindspore/ops/functional.py b/mindspore/ops/functional.py index 4da725145f..611c569553 100644 --- a/mindspore/ops/functional.py +++ b/mindspore/ops/functional.py @@ -48,6 +48,9 @@ tensor_ge = P.GreaterEqual() tensor_sub = P.Sub() tensor_mul = P.Mul() tensor_div = P.RealDiv() +tensor_floordiv = P.FloorDiv() +tensor_pow = P.Pow() +tensor_mod = P.FloorMod() strided_slice = P.StridedSlice() same_type_shape = P.SameTypeShape() equal = P.Equal() @@ -83,6 +86,7 @@ scalar_add = Primitive('scalar_add') scalar_mul = Primitive('scalar_mul') scalar_sub = Primitive('scalar_sub') scalar_div = Primitive('scalar_div') +scalar_floordiv = Primitive('scalar_floordiv') scalar_log = Primitive('scalar_log') scalar_pow = Primitive('scalar_pow') scalar_gt = Primitive('scalar_gt') @@ -95,6 +99,7 @@ scalar_uadd = Primitive('scalar_uadd') scalar_usub = Primitive('scalar_usub') scalar_mod = Primitive('scalar_mod') string_eq = Primitive('string_equal') +string_concat = Primitive('string_concat') bool_not = Primitive("bool_not") bool_or = Primitive("bool_or") bool_and = Primitive("bool_and") @@ -104,7 +109,8 @@ logical_not = P.LogicalNot() array_to_scalar = Primitive('array_to_scalar') is_ = Primitive("is_") is_not = Primitive("is_not") - +in_dict = Primitive("in_dict") +not_in_dict = Primitive("not_in_dict") broadcast_gradient_args = Primitive('BroadcastGradientArgs') dot = Primitive('dot') array_reduce = Primitive('array_reduce') diff --git a/mindspore/ops/operations/math_ops.py b/mindspore/ops/operations/math_ops.py index e390b6b589..ce3449a8b7 100644 --- a/mindspore/ops/operations/math_ops.py +++ 
b/mindspore/ops/operations/math_ops.py @@ -667,8 +667,8 @@ class AddN(PrimitiveWithInfer): >>> return self.addN(z) >>> >>> net = NetAddN() - >>> input_x = Tensor(np.array([1, 2, 3]), mindspore.int32) - >>> input_y = Tensor(np.array([4, 5, 6]), mindspore.int32) + >>> input_x = Tensor(np.array([1, 2, 3]), mindspore.float32) + >>> input_y = Tensor(np.array([4, 5, 6]), mindspore.float32) >>> net(input_x, input_y, input_x, input_y) Tensor([10, 14, 18], shape=(3,), dtype=mindspore.int32) """ diff --git a/tests/ut/python/pipeline/parse/test_operator.py b/tests/ut/python/pipeline/parse/test_operator.py index a3412a6f8f..a3c5f7e422 100644 --- a/tests/ut/python/pipeline/parse/test_operator.py +++ b/tests/ut/python/pipeline/parse/test_operator.py @@ -131,3 +131,72 @@ def test_ME_arithmetic_operator_0070(): def test_ME_logical_operator_0020(): """ test_ME_logical_operator_0020 """ logical_operator_base('or') + + +def test_ops(): + class OpsNet(Cell): + """ OpsNet definition """ + + def __init__(self, x, y): + super(OpsNet, self).__init__() + self.x = x + self.y = y + self.int = 4 + self.float = 3.2 + self.str_a = "hello" + self.str_b = "world" + + def construct(self, x, y): + h = x // y + m = x ** y + n = x % y + r = self.x // self.y + s = self.x ** self.y + t = self.x % self.y + p = h + m + n + q = r + s + t + ret_pow = p ** q + q ** p + ret_mod = p % q + q % p + ret_floor = p // q + q // p + ret = ret_pow + ret_mod + ret_floor + if self.int > self.float: + if self.str_a + self.str_b == "helloworld": + return ret + return x + + net = OpsNet(9, 2) + x = Tensor(np.random.randint(low=1, high=10, size=(2, 3, 4), dtype=np.int32)) + y = Tensor(np.random.randint(low=10, high=20, size=(2, 3, 4), dtype=np.int32)) + context.set_context(mode=context.GRAPH_MODE, save_graphs=True) + net(x, y) + + +def test_in_dict(): + class InDictNet(Cell): + """ InDictNet definition """ + + def __init__(self, key_in, key_not_in): + super(InDictNet, self).__init__() + self.key_in = key_in + 
self.key_not_in = key_not_in + + def construct(self, x, y, z): + d = {"a": x, "b": y} + ret_in = 1 + ret_not_in = 2 + if self.key_in in d: + ret_in = d[self.key_in] + if self.key_not_in not in d: + ret_not_in = z + ret = ret_in + ret_not_in + return ret + + net = InDictNet("a", "c") + x = Tensor(np.random.randint(low=1, high=10, size=(2, 3, 4), dtype=np.int32)) + y = Tensor(np.random.randint(low=10, high=20, size=(2, 3, 4), dtype=np.int32)) + z = Tensor(np.random.randint(low=20, high=30, size=(2, 3, 4), dtype=np.int32)) + context.set_context(mode=context.GRAPH_MODE) + net(x, y, z) + + +