Merge pull request !4204 from fary86/adapt_primitive_dynamic_shape
@@ -49,22 +49,6 @@ using mindspore::parse::PyObjectWrapper;
 std::unordered_set<std::string> prims_to_skip_undetermined_infer{"make_tuple", "make_list", "switch", "env_setitem",
                                                                  "env_getitem"};
-EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) {
-  if (prims_to_skip_undetermined_infer.find(prim_->name()) == prims_to_skip_undetermined_infer.end()) {
-    auto ret_abstract = AbstractEval(args);
-    if (ret_abstract != nullptr) {
-      MS_LOG(DEBUG) << "StandardPrimEvaluator eval Undetermined";
-      return ret_abstract;
-    }
-  }
-  prim_->BeginRecordAddAttr();
-  AbstractBasePtr abs_base = eval_impl_(engine, prim_, args);
-  prim_->EndRecordAddAttr();
-  auto added_attrs = prim_->evaluate_added_attrs();
-  auto infer_result = std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>(added_attrs));
-  return infer_result;
-}
 EvalResultPtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
                                         AnfNodeConfigPtr out_conf) {
   AbstractBasePtrList args_spec_list;
@@ -289,45 +273,45 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) {
   py::dict dic;
   if (abs_base->isa<AbstractTensor>()) {
     auto arg_tensor = dyn_cast<AbstractTensor>(abs_base);
-    dic["shape"] = arg_tensor->shape()->shape();
+    dic[ATTR_SHAPE] = arg_tensor->shape()->shape();
     if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {
       const auto &min_shape = arg_tensor->shape()->min_shape();
       const auto &max_shape = arg_tensor->shape()->max_shape();
       if (!min_shape.empty() && !max_shape.empty()) {
-        dic["min_shape"] = min_shape;
-        dic["max_shape"] = max_shape;
+        dic[ATTR_MIN_SHAPE] = min_shape;
+        dic[ATTR_MAX_SHAPE] = max_shape;
       }
     }
-    dic["dtype"] = arg_tensor->BuildType();
-    dic["value"] = BuildValue(arg_tensor->BuildValue());
+    dic[ATTR_DTYPE] = arg_tensor->BuildType();
+    dic[ATTR_VALUE] = BuildValue(arg_tensor->BuildValue());
   } else if (abs_base->isa<AbstractRowTensor>()) {
     auto arg = dyn_cast<AbstractRowTensor>(abs_base);
-    dic["shape"] = arg->shape()->shape();
-    dic["dtype"] = arg->BuildType();
-    dic["value"] = BuildValue(arg->BuildValue());
+    dic[ATTR_SHAPE] = arg->shape()->shape();
+    dic[ATTR_DTYPE] = arg->BuildType();
+    dic[ATTR_VALUE] = BuildValue(arg->BuildValue());
   } else if (abs_base->isa<AbstractSparseTensor>()) {
     auto arg = dyn_cast<AbstractSparseTensor>(abs_base);
-    dic["shape"] = arg->shape()->shape();
-    dic["dtype"] = arg->BuildType();
-    dic["value"] = BuildValue(arg->BuildValue());
+    dic[ATTR_SHAPE] = arg->shape()->shape();
+    dic[ATTR_DTYPE] = arg->BuildType();
+    dic[ATTR_VALUE] = BuildValue(arg->BuildValue());
   } else if (abs_base->isa<AbstractScalar>() || abs_base->isa<AbstractType>() || abs_base->isa<AbstractRefKey>()) {
     ShapeVector shape;
-    dic["shape"] = shape;
-    dic["dtype"] = abs_base->BuildType();
-    dic["value"] = BuildValue(abs_base->BuildValue());
+    dic[ATTR_SHAPE] = shape;
+    dic[ATTR_DTYPE] = abs_base->BuildType();
+    dic[ATTR_VALUE] = BuildValue(abs_base->BuildValue());
   } else if (abs_base->isa<AbstractSlice>()) {
     auto arg_slice = dyn_cast<AbstractSlice>(abs_base);
     ShapeVector shape;
-    dic["shape"] = shape;
-    dic["dtype"] = arg_slice->BuildType();
-    dic["value"] = BuildValue(arg_slice->BuildValue());
+    dic[ATTR_SHAPE] = shape;
+    dic[ATTR_DTYPE] = arg_slice->BuildType();
+    dic[ATTR_VALUE] = BuildValue(arg_slice->BuildValue());
   } else if (abs_base->isa<AbstractRef>()) {
     auto value = abs_base->cast<AbstractRefPtr>()->ref();
     dic = ConvertAbstractToPython(value);
   } else if (abs_base->isa<AbstractEllipsis>()) {
-    dic["shape"] = py::none();
-    dic["dtype"] = py::ellipsis();
-    dic["value"] = py::ellipsis();
+    dic[ATTR_SHAPE] = py::none();
+    dic[ATTR_DTYPE] = py::ellipsis();
+    dic[ATTR_VALUE] = py::ellipsis();
   } else if (abs_base->isa<AbstractTuple>()) {
     auto arg_tuple = dyn_cast<AbstractTuple>(abs_base);
     size_t len = arg_tuple->size();
@@ -336,12 +320,12 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) {
     for (size_t i = 0; i < len; i++) {
       py::dict out = ConvertAbstractToPython(arg_tuple->elements()[i]);
-      shape_tuple[i] = out["shape"];
-      dtype_tuple[i] = out["dtype"];
+      shape_tuple[i] = out[ATTR_SHAPE];
+      dtype_tuple[i] = out[ATTR_DTYPE];
     }
-    dic["shape"] = shape_tuple;
-    dic["dtype"] = dtype_tuple;
-    dic["value"] = BuildValue(arg_tuple->BuildValue());
+    dic[ATTR_SHAPE] = shape_tuple;
+    dic[ATTR_DTYPE] = dtype_tuple;
+    dic[ATTR_VALUE] = BuildValue(arg_tuple->BuildValue());
   } else if (abs_base->isa<AbstractList>()) {
     auto arg_list = dyn_cast<AbstractList>(abs_base);
     size_t len = arg_list->size();
@@ -350,25 +334,25 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) {
     for (size_t i = 0; i < len; i++) {
       py::dict out = ConvertAbstractToPython(arg_list->elements()[i]);
-      shape_list[i] = out["shape"];
-      dtype_list[i] = out["dtype"];
+      shape_list[i] = out[ATTR_SHAPE];
+      dtype_list[i] = out[ATTR_DTYPE];
     }
-    dic["shape"] = shape_list;
-    dic["dtype"] = dtype_list;
-    dic["value"] = BuildValue(arg_list->BuildValue());
+    dic[ATTR_SHAPE] = shape_list;
+    dic[ATTR_DTYPE] = dtype_list;
+    dic[ATTR_VALUE] = BuildValue(arg_list->BuildValue());
   } else if (abs_base->isa<AbstractNone>()) {
-    dic["shape"] = py::none();
-    dic["dtype"] = py::none();
-    dic["value"] = py::none();
+    dic[ATTR_SHAPE] = py::none();
+    dic[ATTR_DTYPE] = py::none();
+    dic[ATTR_VALUE] = py::none();
   } else if (abs_base->isa<AbstractFunction>()) {
-    dic["shape"] = py::none();
-    dic["dtype"] = abs_base->BuildType();
-    dic["value"] = py::none();
+    dic[ATTR_SHAPE] = py::none();
+    dic[ATTR_DTYPE] = abs_base->BuildType();
+    dic[ATTR_VALUE] = py::none();
   } else if (abs_base->isa<AbstractUndetermined>()) {
     auto arg = dyn_cast<AbstractUndetermined>(abs_base);
-    dic["shape"] = py::none();
-    dic["dtype"] = arg->BuildType();
-    dic["value"] = py::none();
+    dic[ATTR_SHAPE] = py::none();
+    dic[ATTR_DTYPE] = arg->BuildType();
+    dic[ATTR_VALUE] = py::none();
   } else {
     auto value = abs_base->BuildValue();
     if ((*value == *kAnyValue)) {
@@ -409,18 +393,20 @@ py::tuple PreparePyInputs(const PrimitivePyPtr &prim_py, const AbstractBasePtrLi
 AbstractBasePtr PyInferRes2Abstract(const PrimitivePyPtr &prim_py, const py::dict &output) {
   // Convert to AbstractValue based on type and shape
-  auto out_dtype = output["dtype"];
-  if (output["value"].is_none()) {
-    auto out_shape = output["shape"];
-    py::object min_shape = output.contains("min_shape") ? (py::object)output["min_shape"] : (py::object)py::none();
-    py::object max_shape = output.contains("max_shape") ? (py::object)output["max_shape"] : (py::object)py::none();
+  auto out_dtype = output[ATTR_DTYPE];
+  if (output[ATTR_VALUE].is_none()) {
+    auto out_shape = output[ATTR_SHAPE];
+    py::object min_shape =
+      output.contains(py::str(ATTR_MIN_SHAPE)) ? (py::object)output[ATTR_MIN_SHAPE] : (py::object)py::none();
+    py::object max_shape =
+      output.contains(py::str(ATTR_MAX_SHAPE)) ? (py::object)output[ATTR_MAX_SHAPE] : (py::object)py::none();
     return PyListDtype2AbstractTensor(out_shape, out_dtype, min_shape, max_shape);
   }
   // Convert pyobject to Value, then to AbstractValue
   ValuePtr converted_ret = nullptr;
   TypePtr dtype = py::isinstance<Type>(out_dtype) ? out_dtype.cast<TypePtr>() : nullptr;
-  bool converted = parse::ConvertData(output["value"], &converted_ret, false, dtype);
+  bool converted = parse::ConvertData(output[ATTR_VALUE], &converted_ret, false, dtype);
   if (!converted) {
     MS_LOG(EXCEPTION) << "Convert data failed";
   }
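For reference, a Python __infer__ that opts into dynamic shape hands back a dict of the form consumed above; the concrete values here are illustrative only:

out = {
    'shape': [-1, 80],        # -1 marks a dynamic dimension
    'min_shape': [1, 80],     # optional, emitted in graph mode only
    'max_shape': [7800, 80],  # optional, emitted in graph mode only
    'dtype': mstype.float32,
    'value': None,            # None -> build an abstract tensor instead of a constant
}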
| @@ -447,6 +433,73 @@ AbstractBasePtr PyInferRes2Abstract(const PrimitivePyPtr &prim_py, const py::dic | |||
| } | |||
| } // end anonymous namespace | |||
| EvalResultPtr StandardPrimEvaluator::EvalPyCheckPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) { | |||
| auto prim_py = dyn_cast<PrimitivePy>(prim_); | |||
| if (prim_py == nullptr) { | |||
| MS_LOG(EXCEPTION) << "The primitive with type 'kPrimTypePyInferCheck' should be a python primitive."; | |||
| } | |||
| // Call checking method '__check__' for subclass of 'PrimitiveWithCheck' | |||
| MS_LOG(DEBUG) << "Begin input args checking for: " << prim_py->ToString(); | |||
| auto py_args = PreparePyInputs(prim_py, args); | |||
| prim_py->RunCheck(py_args); | |||
| prim_->BeginRecordAddAttr(); | |||
| AbstractBasePtr abs_base = eval_impl_(engine, prim_, args); | |||
| prim_->EndRecordAddAttr(); | |||
| auto added_attrs = prim_->evaluate_added_attrs(); | |||
| if (!py::hasattr(prim_py->GetPyObj(), PY_PRIM_METHOD_INFER_VALUE)) { | |||
| return std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>(added_attrs)); | |||
| } | |||
| // Call method 'infer_value' for primitive with this method for constant propagation | |||
| py::tuple py_vals(py_args.size()); | |||
| for (size_t i = 0; i < py_args.size(); ++i) { | |||
| py_vals[i] = py_args[i][ATTR_VALUE]; | |||
| } | |||
| py::object py_ret = prim_py->RunInferValue(py_vals); | |||
| if (py::isinstance<py::none>(py_ret)) { | |||
| return std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>(added_attrs)); | |||
| } | |||
| // Convert pyobject to Value, then to AbstractValue | |||
| ValuePtr converted_ret = nullptr; | |||
| TypePtr dtype = abs_base->BuildType(); | |||
| bool converted = parse::ConvertData(py_ret, &converted_ret, false, dtype); | |||
| if (!converted) { | |||
| MS_LOG(EXCEPTION) << "Convert data failed"; | |||
| } | |||
| auto res_spec = FromValue(converted_ret); | |||
| MS_EXCEPTION_IF_NULL(res_spec); | |||
| if (res_spec->isa<AbstractTensor>()) { | |||
| // Replace to tensor constant node in specialize | |||
| auto res_tensor = res_spec->cast<AbstractTensorPtr>(); | |||
| res_tensor->set_value(converted_ret); | |||
| } | |||
| return std::make_shared<EvalResult>(res_spec, std::make_shared<AttrValueMap>(added_attrs)); | |||
| } | |||
| EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) { | |||
| if (prims_to_skip_undetermined_infer.find(prim_->name()) == prims_to_skip_undetermined_infer.end()) { | |||
| auto ret_abstract = AbstractEval(args); | |||
| if (ret_abstract != nullptr) { | |||
| MS_LOG(DEBUG) << "StandardPrimEvaluator eval Undetermined"; | |||
| return ret_abstract; | |||
| } | |||
| } | |||
| if (prim_->prim_type() == PrimType::kPrimTypePyInferCheck) { | |||
| return EvalPyCheckPrim(engine, args); | |||
| } | |||
| prim_->BeginRecordAddAttr(); | |||
| AbstractBasePtr abs_base = eval_impl_(engine, prim_, args); | |||
| prim_->EndRecordAddAttr(); | |||
| auto added_attrs = prim_->evaluate_added_attrs(); | |||
| return std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>(added_attrs)); | |||
| } | |||
| EvalResultPtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args) { | |||
| auto ret_abstract = AbstractEval(args); | |||
| if (ret_abstract != nullptr) { | |||
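The control flow of the new EvalPyCheckPrim, sketched in Python for readability; run_cpp_infer and value_to_abstract are hypothetical stand-ins for eval_impl_ and FromValue:

def eval_py_check_prim(prim, args):
    prim.__check__(*args)                  # 1. Python side only validates inputs
    abs_base = run_cpp_infer(prim, args)   # 2. shape/type inference stays in C++ (eval_impl_)
    if not hasattr(prim, 'infer_value'):   # 3. optional constant propagation
        return abs_base
    value = prim.infer_value(*(arg['value'] for arg in args))
    if value is None:
        return abs_base
    return value_to_abstract(value)        # replaced by a constant node during specialization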
@@ -42,6 +42,8 @@ class StandardPrimEvaluator : public TrivialPrimEvaluator {
   std::string ToString() const override { return identifier_ + prim_->name(); }
 private:
+  EvalResultPtr EvalPyCheckPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args);
   PrimitivePtr prim_;
   const StandardPrimitiveEvalImpl eval_impl_;
 };
@@ -308,20 +308,18 @@ void AnalysisEngine::Clear() {
 namespace {
 EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr &engine) {
   // Custom Primitive with python infer_shape, infer_type
-  EvaluatorPtr evaluator = nullptr;
   MS_EXCEPTION_IF_NULL(prim);
   if (prim->isa<prim::DoSignaturePrimitive>()) {
-    evaluator = std::make_shared<DoSignatureEvaluator>(prim);
-    return evaluator;
+    return std::make_shared<DoSignatureEvaluator>(prim);
   }
   if (prim->isa<prim::UnpackGraphPrimitive>()) {
-    evaluator = std::make_shared<UnpackGraphEvaluator>(prim);
-    return evaluator;
+    return std::make_shared<UnpackGraphEvaluator>(prim);
   }
   if (prim->Hash() == prim::kPrimMixedPrecisionCast->Hash() && prim->name() == prim::kPrimMixedPrecisionCast->name()) {
-    evaluator = std::make_shared<MixedPrecisionCastEvaluator>(prim);
-    return evaluator;
+    return std::make_shared<MixedPrecisionCastEvaluator>(prim);
   }
+  EvaluatorPtr evaluator = nullptr;
   if (prim->HasPyEvaluator()) {
     auto prim_py = dyn_cast<PrimitivePy>(prim);
     if (prim_py != nullptr) {
@@ -55,6 +55,10 @@ void ValidateOperation(const AnfNodePtr &node) {
     MS_LOG(DEBUG) << "Primitive " << prim->name() << " has python evaluator.";
     return;
   }
+  if (prim->prim_type() == PrimType::kPrimTypePyInferCheck) {
+    MS_LOG(DEBUG) << "Primitive " << prim->name() << " has python inference checking method.";
+    return;
+  }
   if (prim->name() == "fake_bprop") {
     MS_LOG(EXCEPTION) << "Illegal primitive: " << GetValue<std::string>(prim->GetAttr("info"));
   }
@@ -254,16 +254,33 @@ py::dict PrimitivePy::RunInfer(const py::tuple &args) {
   if (!HasPyObj()) {
     MS_LOG(EXCEPTION) << "[" << this->ToString() << "]: pyobj is empty";
   }
-  auto infer_fuc = python_obj_.attr("__infer__");
+  auto infer_fuc = python_obj_.attr(PY_PRIM_METHOD_INFER);
   return infer_fuc(*args);
 }
+void PrimitivePy::RunCheck(const py::tuple &args) {
+  if (!HasPyObj()) {
+    MS_LOG(EXCEPTION) << "[" << this->ToString() << "]: pyobj is empty";
+  }
+  auto check_func = python_obj_.attr(PY_PRIM_METHOD_CHECK);
+  (void)check_func(*args);
+}
+py::object PrimitivePy::RunInferValue(const py::tuple &args) {
+  if (!HasPyObj()) {
+    MS_LOG(EXCEPTION) << "[" << this->ToString() << "]: pyobj is empty";
+  }
+  auto infer_value = python_obj_.attr(PY_PRIM_METHOD_INFER_VALUE);
+  return infer_value(*args);
+}
 REGISTER_PYBIND_DEFINE(Primitive_, ([](const py::module *m) {
                          (void)py::enum_<PrimType>(*m, "prim_type", py::arithmetic())
                            .value("unknown", PrimType::kPrimTypeUnknown)
                            .value("builtin", PrimType::kPrimTypeBuiltIn)
                            .value("py_infer_shape", PrimType::kPrimTypePyInferShape)
-                           .value("user_custom", PrimType::kPrimTypeUserCustom);
+                           .value("user_custom", PrimType::kPrimTypeUserCustom)
+                           .value("py_infer_check", PrimType::kPrimTypePyInferCheck);
                          (void)py::class_<PrimitivePy, std::shared_ptr<PrimitivePy>>(*m, "Primitive_")
                            .def_readonly(PYTHON_PRIMITIVE_FLAG, &PrimitivePy::parse_info_)
                            .def(py::init<py::str &, py::object>())
@@ -62,6 +62,8 @@ class PrimitivePy : public Primitive {
   const bool parse_info_ = true;
   const py::object &GetPyObj() const { return python_obj_; }
   py::dict RunInfer(const py::tuple &args);
+  void RunCheck(const py::tuple &args);
+  py::object RunInferValue(const py::tuple &args);
   bool ObjHasAttr(const char *attr_name) { return py::hasattr(python_obj_, attr_name); }
   bool HasPyObj() { return python_obj_.operator bool(); }
   PrimitivePtr Clone() override;
@@ -81,6 +81,9 @@ AbstractBasePtr InferImplDropoutGenMask(const AnalysisEnginePtr &, const Primiti
 AbstractBasePtr InferImplMinOrMaxGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                       const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplSqrt(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                              const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplScalarToArray(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                        const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplArrayToScalar(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
@@ -176,6 +179,14 @@ AbstractBasePtr InferImplUniqueGrad(const AnalysisEnginePtr &, const PrimitivePt
                                     const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplUnique(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                 const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplGatherV2(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                  const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplDynamicShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                      const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplSparseApplyFtrl(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                         const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplSparseApplyProximalAdagrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                                    const AbstractBasePtrList &args_spec_list);
 template <typename T>
 AbstractBasePtr InferTupleOrListOrDictLen(const std::string &op_name, const AbstractBasePtrList &args_spec_list) {
@@ -14,6 +14,8 @@
  * limitations under the License.
  */
+#include <algorithm>
+#include <iterator>
 #include "abstract/infer_functions.h"
 #include "abstract/utils.h"
 #include "abstract/param_validator.h"
@@ -226,5 +228,60 @@ AbstractBasePtr InferImplUniqueGrad(const AnalysisEnginePtr &, const PrimitivePt
   // outputs: dx
   return std::make_shared<AbstractTensor>(ids->element(), ids_idx->shape());
 }
+AbstractBasePtr InferImplGatherV2(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                  const AbstractBasePtrList &args_spec_list) {
+  const std::string &op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, 3);
+  AbstractTensorPtr params = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
+  AbstractTensorPtr indices = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
+  AbstractScalarPtr axis = CheckArg<AbstractScalar>(op_name, args_spec_list, 2);
+  auto params_shp = params->shape()->shape();
+  auto indices_shp = indices->shape()->shape();
+  auto axis_val = GetValue<int>(axis->BuildValue());
+  auto params_rank = static_cast<int>(params_shp.size());
+  if (axis_val < 0) {
+    axis_val += params_rank;
+  }
+  auto calc_shape = [axis_val, &params_shp](const ShapeVector &inp_vec) -> ShapeVector {
+    ShapeVector out_vec;
+    std::copy(params_shp.begin(), params_shp.begin() + axis_val, std::back_inserter(out_vec));
+    copy(inp_vec.begin(), inp_vec.end(), std::back_inserter(out_vec));
+    copy(params_shp.begin() + axis_val + 1, params_shp.end(), std::back_inserter(out_vec));
+    return out_vec;
+  };
+  ShapeVector out_shape = calc_shape(indices_shp);
+  if (!indices->shape()->min_shape().empty() && !indices->shape()->max_shape().empty()) {
+    ShapeVector min_shape = calc_shape(indices->shape()->min_shape());
+    ShapeVector max_shape = calc_shape(indices->shape()->max_shape());
+    return std::make_shared<AbstractTensor>(params->element(),
+                                            std::make_shared<Shape>(out_shape, min_shape, max_shape));
+  }
+  return std::make_shared<AbstractTensor>(params->element(), std::make_shared<Shape>(out_shape));
+}
+AbstractBasePtr InferImplDynamicShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                      const AbstractBasePtrList &args_spec_list) {
+  const std::string &op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, 1);
+  AbstractTensorPtr input = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
+  auto shape = input->shape()->shape();
+  bool has_dyn_shape = std::any_of(shape.begin(), shape.end(), [](int dim) { return dim == Shape::SHP_ANY; });
+  std::vector<int> tensor_shp({static_cast<int>(shape.size())});
+  if (has_dyn_shape) {
+    auto elem = std::make_shared<AbstractScalar>(std::make_shared<AnyValue>(), std::make_shared<Int>(32));
+    return std::make_shared<AbstractTensor>(elem, std::make_shared<Shape>(tensor_shp));
+  }
+  auto shp_buf_size = sizeof(int) * shape.size();
+  auto tensor = std::make_shared<tensor::Tensor>(kNumberTypeInt32, tensor_shp, shape.data(), shp_buf_size);
+  return tensor->ToAbstract();
+}
 }  // namespace abstract
 }  // namespace mindspore
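The calc_shape lambda applies the usual gather output rule to the real, min, and max shapes alike; in standalone Python terms (assuming a normalized, non-negative axis):

def gather_out_shape(params_shape, indices_shape, axis):
    # output = params[:axis] + indices + params[axis + 1:]
    return params_shape[:axis] + indices_shape + params_shape[axis + 1:]

assert gather_out_shape([20, 12], [8], 0) == [8, 12]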
@@ -37,5 +37,14 @@ AbstractBasePtr InferImplMinOrMaxGrad(const AnalysisEnginePtr &, const Primitive
   return std::make_shared<AbstractTuple>(AbstractBasePtrList({dx, dy}));
 }
+AbstractBasePtr InferImplSqrt(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                              const AbstractBasePtrList &args_spec_list) {
+  // Inputs: one tensor.
+  const std::string op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, 1);
+  auto inp = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
+  return inp->Clone()->Broaden();
+}
 }  // namespace abstract
 }  // namespace mindspore
@@ -445,5 +445,25 @@ AbstractBasePtr InferImplDropoutGenMask(const AnalysisEnginePtr &, const Primiti
   return std::make_shared<AbstractTensor>(std::make_shared<AbstractScalar>(kAnyValue, kUInt8),
                                           std::make_shared<Shape>(std::vector<int64_t>{shape_y}));
 }
+AbstractBasePtr InferImplSparseApplyFtrl(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                         const AbstractBasePtrList &args_spec_list) {
+  CheckArgsSize(primitive->name(), args_spec_list, 5);
+  AbstractBasePtrList elements;
+  for (size_t i = 0; i < 3; ++i) {
+    elements.push_back(args_spec_list[i]->Clone()->Broaden());
+  }
+  return std::make_shared<AbstractTuple>(elements);
+}
+AbstractBasePtr InferImplSparseApplyProximalAdagrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                                    const AbstractBasePtrList &args_spec_list) {
+  CheckArgsSize(primitive->name(), args_spec_list, 7);
+  AbstractBasePtrList elements;
+  for (size_t i = 0; i < 2; ++i) {
+    elements.push_back(args_spec_list[i]->Clone()->Broaden());
+  }
+  return std::make_shared<AbstractTuple>(elements);
+}
 }  // namespace abstract
 }  // namespace mindspore
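Both new infer functions return a tuple mirroring the leading state inputs (var/accum for ProximalAdagrad, plus linear for FTRL), broadened so concrete values are dropped; schematically (a standalone sketch, not MindSpore API):

def sparse_apply_output(input_specs, num_outputs):
    # outputs reuse the first num_outputs input specs (shape and dtype) unchanged
    return tuple(input_specs[:num_outputs])

specs = [('var', (7800, 80)), ('accum', (7800, 80)), ('linear', (7800, 80)),
         ('grad', (-1, 80)), ('indices', (-1,))]
assert sparse_apply_output(specs, 3) == tuple(specs[:3])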
@@ -37,6 +37,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
     // Maths
     {prim::kPrimMaximumGrad, {InferImplMinOrMaxGrad, true}},
     {prim::kPrimMinimumGrad, {InferImplMinOrMaxGrad, true}},
+    {prim::kPrimSqrt, {InferImplSqrt, true}},
     // Array
     {prim::kPrimScalarToArray, {InferImplScalarToArray, true}},
     {prim::kPrimArrayToScalar, {InferImplArrayToScalar, true}},
@@ -44,6 +45,9 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
     {prim::kPrimPack, {InferImplPack, true}},
     {prim::kPrimUnique, {InferImplUnique, true}},
     {prim::kPrimUniqueGrad, {InferImplUniqueGrad, true}},
+    {prim::kPrimGatherV2, {InferImplGatherV2, true}},
+    {prim::kPrimSparseGatherV2, {InferImplGatherV2, true}},
+    {prim::kPrimDynamicShape, {InferImplDynamicShape, true}},
     // Structure
     {prim::kPrimMakeTuple, {InferImplMakeTuple, true}},
     {prim::kPrimMakeList, {InferImplMakeList, true}},
@@ -77,6 +81,8 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
     {prim::kPrimLayerNorm, {InferImplLayerNorm, true}},
     {prim::kPrimLayerNormGrad, {InferImplLayerNormGrad, true}},
     {prim::kPrimDropoutGenMask, {InferImplDropoutGenMask, true}},
+    {prim::kPrimSparseApplyFtrl, {InferImplSparseApplyFtrl, true}},
+    {prim::kPrimSparseApplyProximalAdagrad, {InferImplSparseApplyProximalAdagrad, true}},
     // Others
     {prim::kPrimIdentity, {InferImplIdentity, true}},
     // Set impl to null as it will use PartialEvaluator;
@@ -84,6 +84,9 @@ inline const PrimitivePtr kPrimConcat = std::make_shared<Primitive>("Concat");
 inline const PrimitivePtr kPrimSqueeze = std::make_shared<Primitive>("Squeeze");
 inline const PrimitivePtr kPrimTranspose = std::make_shared<Primitive>("Transpose");
 inline const PrimitivePtr kPrimGatherV2 = std::make_shared<Primitive>("GatherV2");
+inline const PrimitivePtr kPrimSparseGatherV2 = std::make_shared<Primitive>("SparseGatherV2");
 inline const PrimitivePtr kPrimShape = std::make_shared<Primitive>("Shape");
+inline const PrimitivePtr kPrimDynamicShape = std::make_shared<Primitive>("DynamicShape");
 inline const PrimitivePtr kPrimEmbeddingLookup = std::make_shared<Primitive>("EmbeddingLookup");
 inline const PrimitivePtr kPrimEmbeddingLookupCommGrad = std::make_shared<Primitive>("EmbeddingLookupCommGrad");
 inline const PrimitivePtr kPrimSize = std::make_shared<Primitive>("Size");
@@ -154,6 +157,8 @@ inline const PrimitivePtr kPrimBpropCut = std::make_shared<Primitive>("bprop_cut
 inline const PrimitivePtr kPrimFakeQuantPerLayer = std::make_shared<Primitive>("FakeQuantPerLayer");
 inline const PrimitivePtr kPrimFakeQuantPerChannel = std::make_shared<Primitive>("FakeQuantPerChannel");
 inline const PrimitivePtr kPrimApplyRMSProp = std::make_shared<Primitive>("ApplyRMSProp");
+inline const PrimitivePtr kPrimSparseApplyFtrl = std::make_shared<Primitive>("SparseApplyFtrl");
+inline const PrimitivePtr kPrimSparseApplyProximalAdagrad = std::make_shared<Primitive>("SparseApplyProximalAdagrad");
 // Comm ops
 inline const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator");
@@ -35,7 +35,8 @@ enum PrimType {
   kPrimTypeBuiltIn,        // Built-in primitive operator
   kPrimTypePyInferShape,   // Primitive operator defined by custom
   kPrimTypePyInferTensor,  // Primitive operator defined by custom
-  kPrimTypeUserCustom
+  kPrimTypeUserCustom,
+  kPrimTypePyInferCheck  // Primitive operator with input args checking method
 };
 class Primitive : public Named {
@@ -23,4 +23,19 @@ const char GRAPH_FLAG_HAS_EFFECT[] = "has_effect";
 const char GRAPH_FLAG_EFFECT_PATIAL_ORDER[] = "_effect_patial_order";
 const char GRAPH_FLAG_RANDOM_EFFECT[] = "_random_effect";
 const char GRAPH_FLAG_SIDE_EFFECT[] = "_side_effect";
+// method names of python primitive called from c++ source code
+// 1. infer method name of class 'PrimitiveWithInfer'
+const char PY_PRIM_METHOD_INFER[] = "__infer__";
+// 2. check method name of class 'PrimitiveWithCheck'
+const char PY_PRIM_METHOD_CHECK[] = "__check__";
+// 3. method name of class 'PrimitivePy' for constant propagation
+const char PY_PRIM_METHOD_INFER_VALUE[] = "infer_value";
+// type inference related attributes
+const char ATTR_VALUE[] = "value";
+const char ATTR_DTYPE[] = "dtype";
+const char ATTR_SHAPE[] = "shape";
+const char ATTR_MIN_SHAPE[] = "min_shape";
+const char ATTR_MAX_SHAPE[] = "max_shape";
 }  // namespace mindspore
@@ -23,6 +23,16 @@ extern const char GRAPH_FLAG_HAS_EFFECT[];
 extern const char GRAPH_FLAG_EFFECT_PATIAL_ORDER[];
 extern const char GRAPH_FLAG_RANDOM_EFFECT[];
 extern const char GRAPH_FLAG_SIDE_EFFECT[];
+extern const char PY_PRIM_METHOD_INFER[];
+extern const char PY_PRIM_METHOD_CHECK[];
+extern const char PY_PRIM_METHOD_INFER_VALUE[];
+extern const char ATTR_VALUE[];
+extern const char ATTR_DTYPE[];
+extern const char ATTR_SHAPE[];
+extern const char ATTR_MIN_SHAPE[];
+extern const char ATTR_MAX_SHAPE[];
 }  // namespace mindspore
 #endif  // MINDSPORE_CORE_UTILS_FLAGS_H
@@ -27,7 +27,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
                         Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue,
                         SameTypeShape, ScatterAdd, ScatterSub, ScatterMul, ScatterDiv, ScatterMax, ScatterMin,
                         ScatterUpdate, ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select,
-                        Shape, Size, Slice, Split, TransShape, ParallelConcat, Padding,
+                        Shape, DynamicShape, Size, Slice, Split, TransShape, ParallelConcat, Padding,
                         ScatterNdAdd, ScatterNdSub, ScatterNonAliasingAdd, ReverseV2, Rint,
                         Squeeze, StridedSlice, Tile, TensorScatterUpdate, EditDistance,
                         Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin, UnsortedSegmentProd,
@@ -206,6 +206,7 @@ __all__ = [
     'HookBackward',
     'InvertPermutation',
     'Shape',
+    'DynamicShape',
     'DropoutDoMask',
     'DropoutGenMask',
     'DropoutGrad',
@@ -27,7 +27,7 @@ import numpy as np
 from .._utils import get_concat_offset
 from ..operations.math_ops import _infer_shape_reduce
-from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register, _run_op
+from ..primitive import Primitive, PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register, _run_op
 from ..._c_expression import signature_dtype as sig_dtype
 from ..._c_expression import signature_kind as sig_kind
 from ..._c_expression import signature_rw as sig_rw
@@ -142,6 +142,11 @@ class ExpandDims(PrimitiveWithInfer):
         out = {'shape': x_shape,
                'dtype': x['dtype'],
                'value': value}
+        if 'min_shape' in x and 'max_shape' in x:
+            out['min_shape'] = x['min_shape']
+            out['min_shape'].insert(axis_v, 1)
+            out['max_shape'] = x['max_shape']
+            out['max_shape'].insert(axis_v, 1)
         return out
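A worked example of the propagation above, with hypothetical input metadata: given x == {'shape': [-1, 80], 'min_shape': [1, 80], 'max_shape': [7800, 80]} and axis_v == 1, the result carries

out['min_shape'] == [1, 1, 80]
out['max_shape'] == [7800, 1, 80]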
@@ -277,6 +282,9 @@ class Cast(PrimitiveWithInfer):
         out = {'shape': x['shape'],
                'dtype': mstype.tensor_type(t['value']),
                'value': value}
+        if 'min_shape' in x and 'max_shape' in x:
+            out['min_shape'] = x['min_shape']
+            out['max_shape'] = x['max_shape']
         return out
@@ -445,6 +453,27 @@ class Shape(PrimitiveWithInfer):
         return out
+class DynamicShape(Primitive):
+    """
+    Returns the shape of the input tensor.
+
+    Inputs:
+        - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
+
+    Outputs:
+        Tensor[int], 1-dim Tensor of type int32.
+
+    Examples:
+        >>> input_tensor = Tensor(np.ones(shape=[3, 2, 1]), mindspore.float32)
+        >>> shape = P.DynamicShape()
+        >>> output = shape(input_tensor)
+    """
+
+    @prim_attr_register
+    def __init__(self):
+        """init DynamicShape"""
 class Squeeze(PrimitiveWithInfer):
     """
     Returns a tensor with the same type but dimensions of 1 being removed based on axis.
@@ -578,7 +607,7 @@ class Unique(Primitive):
         self.init_prim_io_names(inputs=['x'], outputs=['output'])
-class GatherV2(PrimitiveWithInfer):
+class GatherV2(PrimitiveWithCheck):
     """
     Returns a slice of input tensor based on the specified indices and axis.
@@ -605,7 +634,7 @@ class GatherV2(PrimitiveWithInfer):
         """init index_select"""
         self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output'])
-    def __infer__(self, params, indices, axis):
+    def __check__(self, params, indices, axis):
         validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)
         validator.check_tensor_type_same({"indices": indices['dtype']}, mstype.int_type, self.name)
         validator.check_subclass("axis", axis['dtype'], mstype.int_, self.name)
@@ -613,13 +642,6 @@ class GatherV2(PrimitiveWithInfer):
         params_shp = params['shape']
         rank = len(params_shp)
         validator.check_int_range("axis", axis_v, -rank, rank, Rel.INC_LEFT, self.name)
-        if axis_v < 0:
-            axis_v += rank
-        out_shape = params_shp[:axis_v] + indices['shape'] + params_shp[axis_v + 1:]
-        out = {'shape': out_shape,
-               'dtype': params['dtype'],
-               'value': None}
-        return out
 class SparseGatherV2(GatherV2):
@@ -26,7 +26,7 @@ from ..._checkparam import Rel
 from ...common import dtype as mstype
 from ...common.tensor import Tensor
 from .._utils import get_broadcast_shape
-from ..primitive import PrimitiveWithInfer, prim_attr_register, _run_op
+from ..primitive import PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register, _run_op
 def _infer_shape_reduce(x, axis, keep_dims, prim_name):
@@ -1257,7 +1257,7 @@ class Rsqrt(PrimitiveWithInfer):
         return None
-class Sqrt(PrimitiveWithInfer):
+class Sqrt(PrimitiveWithCheck):
     """
     Returns square root of a tensor element-wise.
@@ -1279,12 +1279,8 @@ class Sqrt(PrimitiveWithInfer):
         """init Sqrt"""
         self.init_prim_io_names(inputs=['x'], outputs=['output'])
-    def infer_shape(self, x_shape):
-        return x_shape
-    def infer_dtype(self, x_type):
+    def check_dtype(self, x_type):
         validator.check_tensor_type_same({"x": x_type}, mstype.number_type, self.name)
-        return x_type
     def infer_value(self, x):
         if x is not None:
@@ -28,7 +28,7 @@ from ..._c_expression import signature_dtype as sig_dtype
 from ..._checkparam import Validator as validator
 from ..._checkparam import Rel
 from ...common import dtype as mstype
-from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register
+from ..primitive import Primitive, PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register
 from ..operations.math_ops import _infer_shape_reduce
@@ -4354,7 +4354,7 @@ class ApplyProximalAdagrad(PrimitiveWithInfer):
         return var_dtype, accum_dtype
-class SparseApplyProximalAdagrad(PrimitiveWithInfer):
+class SparseApplyProximalAdagrad(PrimitiveWithCheck):
     r"""
     Update relevant entries according to the proximal adagrad algorithm. Compared with ApplyProximalAdagrad,
     an additional index tensor is input.
@@ -4433,11 +4433,10 @@ class SparseApplyProximalAdagrad(PrimitiveWithInfer):
                                 outputs=['var', 'accum'])
         self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)
-    def infer_shape(self, var_shape, accum_shape, lr_shape, l1_shape, l2_shape, grad_shape, indices_shape):
+    def check_shape(self, var_shape, accum_shape, lr_shape, l1_shape, l2_shape, grad_shape, indices_shape):
         validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name)
-        return var_shape, accum_shape
-    def infer_dtype(self, var_dtype, accum_dtype, lr_dtype, l1_dtype, l2_dtype, grad_dtype, indices_dtype):
+    def check_dtype(self, var_dtype, accum_dtype, lr_dtype, l1_dtype, l2_dtype, grad_dtype, indices_dtype):
         args = {'var': var_dtype, 'accum': accum_dtype, 'grad': grad_dtype}
         validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name)
         validator.check_scalar_or_tensor_type_same({"lr": lr_dtype}, [mstype.float16, mstype.float32], self.name)
@@ -4446,7 +4445,6 @@ class SparseApplyProximalAdagrad(PrimitiveWithInfer):
         valid_types = [mstype.int16, mstype.int32, mstype.int64,
                        mstype.uint16, mstype.uint32, mstype.uint64]
         validator.check_tensor_type_same({'indices': indices_dtype}, valid_types, self.name)
-        return var_dtype, accum_dtype
 class ApplyAddSign(PrimitiveWithInfer):
@@ -4978,7 +4976,7 @@ class ApplyFtrl(PrimitiveWithInfer):
         return var_type
-class SparseApplyFtrl(PrimitiveWithInfer):
+class SparseApplyFtrl(PrimitiveWithCheck):
     """
     Update relevant entries according to the FTRL-proximal scheme.
@@ -5053,21 +5051,19 @@ class SparseApplyFtrl(PrimitiveWithInfer):
         self.lr_power = validator.check_number("lr_power", lr_power, 0, Rel.LE, self.name)
         self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)
-    def infer_shape(self, var_shape, accum_shape, linear_shape, grad_shape, indices_shape):
+    def check_shape(self, var_shape, accum_shape, linear_shape, grad_shape, indices_shape):
         validator.check('var shape', var_shape, 'accum shape', accum_shape, Rel.EQ, self.name)
         validator.check('var shape', var_shape, 'linear shape', linear_shape, Rel.EQ, self.name)
         if len(var_shape) > 1:
             validator.check('var_shape[1:]', var_shape[1:], 'grad_shape[1:]', grad_shape[1:], Rel.EQ, self.name)
         validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name)
         validator.check('grad_shape[0]', grad_shape[0], 'indices_shape[0]', indices_shape[0], Rel.EQ, self.name)
-        return var_shape, accum_shape, linear_shape
-    def infer_dtype(self, var_dtype, accum_dtype, linear_dtype, grad_dtype, indices_dtype):
+    def check_dtype(self, var_dtype, accum_dtype, linear_dtype, grad_dtype, indices_dtype):
         args = {"var_dtype": var_dtype, "accum_dtype": accum_dtype,
                 "linear_dtype": linear_dtype, "grad_dtype": grad_dtype}
         validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name)
         validator.check_tensor_type_same({"indices_dtype": indices_dtype}, [mstype.int32], self.name)
-        return var_dtype, accum_dtype, linear_dtype
 class SparseApplyFtrlV2(PrimitiveWithInfer):
@@ -200,6 +200,84 @@ class Primitive(Primitive_):
         return self._update_parameter
+class PrimitiveWithCheck(Primitive):
+    """
+    PrimitiveWithCheck is the base class of primitives in python, which defines functions
+    to check the input arguments of operators, but uses the infer method registered in the
+    C++ source code.
+
+    There are three methods that can be overridden to define the check logic of the primitive:
+    __check__(), check_shape() and check_dtype(). If __check__() is defined in the primitive,
+    it has the highest priority to be called. If __check__() is not defined, check_shape() and
+    check_dtype() can be overridden to describe the check logic of the shape and type.
+
+    Args:
+        name (str): Name of the current Primitive.
+
+    Examples:
+        >>> # init a Primitive class with check
+        >>> class Flatten(PrimitiveWithCheck):
+        >>>     @prim_attr_register
+        >>>     def __init__(self):
+        >>>         pass
+        >>>     def check_shape(self, input_x):
+        >>>         validator.check_integer('input_x rank', len(input_x), 1, Rel.GE, self.name)
+        >>>
+        >>>     def check_dtype(self, input_x):
+        >>>         validator.check_subclass("input_x", input_x, mstype.tensor, self.name)
+        >>>
+        >>> # init a Primitive obj
+        >>> add = Flatten()
+    """
+
+    def __init__(self, name):
+        Primitive.__init__(self, name)
+        self.set_prim_type(prim_type.py_infer_check)
+
+    def _clone(self):
+        """
+        Deeply clones the primitive object.
+
+        Calls the __init__() method with the same arguments. This method is called in parser if the
+        flag self.__setattr_flag__ is True.
+        """
+        cloned_prim = Primitive._clone(self)
+        return cloned_prim
+
+    def check_shape(self, *args):
+        """
+        Check shapes of input args.
+
+        Note:
+            The shape of a scalar is an empty tuple.
+
+        Args:
+            args (tuple(int)): shapes of input tensors.
+
+        Return:
+            None.
+        """
+        return None
+
+    def check_dtype(self, *args):
+        """
+        Check data types of input args.
+
+        Args:
+            args (:class:`mindspore.dtype`): data type of inputs.
+
+        Return:
+            None.
+        """
+        return None
+
+    def __check__(self, *args):
+        """Check shape, type, and value at the same time by using dictionary as arguments."""
+        tracks = ['dtype', 'shape']
+        for track in tracks:
+            fn = getattr(self, 'check_' + track)
+            fn(*(x[track] for x in args))
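Each argument reaching __check__ is the dict built by ConvertAbstractToPython, so the track loop dispatches dtype first, then shape. A standalone illustration of the same dispatch (no MindSpore required):

class DemoCheck:
    def check_shape(self, *shapes):
        print('check_shape', shapes)

    def check_dtype(self, *dtypes):
        print('check_dtype', dtypes)

    def __check__(self, *args):
        for track in ('dtype', 'shape'):
            getattr(self, 'check_' + track)(*(x[track] for x in args))

DemoCheck().__check__({'shape': [20, 12], 'dtype': 'float32', 'value': None})
# prints: check_dtype ('float32',) then check_shape ([20, 12],)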
 class PrimitiveWithInfer(Primitive):
     """
     PrimitiveWithInfer is the base class of primitives in python defines functions for tracking inference in python.
@@ -306,6 +384,18 @@ class PrimitiveWithInfer(Primitive):
         if not is_graph_mode:
             return out
+        # output does not contain dynamic shape, no need to calculate min/max shape
+        def has_dynamic_shape(shp):
+            if isinstance(shp, int):
+                return shp < 0
+            if isinstance(shp, (list, tuple)):
+                return any(has_dynamic_shape(e) for e in shp)
+            return False
+
+        if not has_dynamic_shape(out['shape']):
+            return out
+
+        # calculate min/max shape for output
         def get_specified_shape(elems, attr):
             has_specified_shape = False
             ret_vals = []
@@ -345,6 +435,8 @@ def prim_attr_register(fn):
     def deco(self, *args, **kwargs):
         if isinstance(self, PrimitiveWithInfer):
             PrimitiveWithInfer.__init__(self, self.__class__.__name__)
+        elif isinstance(self, PrimitiveWithCheck):
+            PrimitiveWithCheck.__init__(self, self.__class__.__name__)
         else:
             Primitive.__init__(self, self.__class__.__name__)
         bound_args = inspect.signature(fn).bind(self, *args, **kwargs)
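The has_dynamic_shape helper above treats any negative dimension, at any nesting depth, as dynamic; extracted verbatim with a few checks for illustration:

def has_dynamic_shape(shp):
    if isinstance(shp, int):
        return shp < 0
    if isinstance(shp, (list, tuple)):
        return any(has_dynamic_shape(e) for e in shp)
    return False

assert not has_dynamic_shape([2, 3])
assert has_dynamic_shape([2, -1])
assert has_dynamic_shape(((2, -1), [80]))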
@@ -27,7 +27,7 @@ from mindspore.ops import composite as C
 from mindspore.ops import functional as F
 from mindspore.ops import operations as P
 from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like
-from mindspore.ops.primitive import constexpr
+from mindspore.ops.primitive import constexpr, PrimitiveWithInfer, prim_attr_register
 from mindspore.ops._grad.grad_base import bprop_getters
 from mindspore import Tensor, RowTensor, context
 from mindspore.common.parameter import Parameter, ParameterTuple
@@ -105,10 +105,31 @@ def _generate_inverse_index(x_shape, axis):
     perm = index[1:1 + axis] + (0,) + index[1 + axis:]
     return perm
-class MySparseGatherV2(P.GatherV2):
+# pylint: disable=W0231
+class MySparseGatherV2(PrimitiveWithInfer):
     """
     For test
     """
+    @prim_attr_register
+    def __init__(self):
+        """init index_select"""
+        self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output'])
+
+    def __infer__(self, params, indices, axis):
+        validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)
+        validator.check_tensor_type_same({"indices": indices['dtype']}, mstype.int_type, self.name)
+        validator.check_subclass("axis", axis['dtype'], mstype.int_, self.name)
+        axis_v = axis['value']
+        params_shp = params['shape']
+        rank = len(params_shp)
+        validator.check_int_range("axis", axis_v, -rank, rank, Rel.INC_LEFT, self.name)
+        if axis_v < 0:
+            axis_v += rank
+        out_shape = params_shp[:axis_v] + indices['shape'] + params_shp[axis_v + 1:]
+        out = {'shape': out_shape,
+               'dtype': params['dtype'],
+               'value': None}
+        return out
 @bprop_getters.register(MySparseGatherV2)
 def get_bprop_sparse_gather_v2(self):
@@ -0,0 +1,109 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+""" test dynamic shape """
+from mindspore import Tensor, context, nn, Parameter
+from mindspore.ops import operations as P
+from mindspore import dtype as mstype
+import numpy as np
+
+context.set_context(mode=context.GRAPH_MODE, save_graphs=False)
+
+def test_sparse_apply_proximal_ada_grad():
+    class Net(nn.Cell):
+        def __init__(self):
+            super(Net, self).__init__()
+            self.sparse_apply_proximal_adagrad = P.SparseApplyProximalAdagrad()
+            self.var = Parameter(Tensor(np.random.rand(7800, 80).astype(np.float32)), name="var")
+            self.accum = Parameter(Tensor(np.random.rand(7800, 80).astype(np.float32)), name="accum")
+            self.lr = 0.01
+            self.l1 = 0.0
+            self.l2 = 0.0
+
+        def construct(self, grad, indices):
+            out = self.sparse_apply_proximal_adagrad(self.var, self.accum, self.lr, self.l1, self.l2, grad, indices)
+            return out[0]
+
+    class NetWrapper(nn.Cell):
+        def __init__(self):
+            super(NetWrapper, self).__init__()
+            self.unq = P.Unique()
+            self.add = P.TensorAdd()
+            self.expand_dims = P.ExpandDims()
+            self.cast = P.Cast()
+            self.net = Net()
+
+        def construct(self, grad, inp):
+            ids, _ = self.unq(inp)
+            new_grad = self.expand_dims(ids, 1)
+            new_grad = self.cast(new_grad, mstype.float32) + grad
+            return self.net(new_grad, ids)
+
+    net = NetWrapper()
+    grad = Tensor(np.random.rand(1, 80).astype(np.float32))
+    indices = Tensor(np.ones([7800]), mstype.int32)
+    net(grad, indices)
+
+def test_sparse_apply_ftrl():
+    class SparseApplyFtrlNet(nn.Cell):
+        def __init__(self):
+            super(SparseApplyFtrlNet, self).__init__()
+            self.sparse_apply_ftrl = P.SparseApplyFtrl(lr=0.01, l1=0.0, l2=0.0, lr_power=-0.5)
+            self.var = Parameter(Tensor(np.random.rand(7800, 80).astype(np.float32)), name="var")
+            self.accum = Parameter(Tensor(np.random.rand(7800, 80).astype(np.float32)), name="accum")
+            self.linear = Parameter(Tensor(np.random.rand(7800, 80).astype(np.float32)), name="linear")
+
+        def construct(self, grad, indices):
+            out = self.sparse_apply_ftrl(self.var, self.accum, self.linear, grad, indices)
+            return out[0]
+
+    class NetWrapper(nn.Cell):
+        def __init__(self):
+            super(NetWrapper, self).__init__()
+            self.unq = P.Unique()
+            self.add = P.TensorAdd()
+            self.expand_dims = P.ExpandDims()
+            self.cast = P.Cast()
+            self.net = SparseApplyFtrlNet()
+
+        def construct(self, grad, inp):
+            ids, _ = self.unq(inp)
+            new_grad = self.expand_dims(ids, 1)
+            new_grad = self.cast(new_grad, mstype.float32) + grad
+            return self.net(new_grad, ids)
+
+    net = NetWrapper()
+    grad = Tensor(np.random.rand(1, 80).astype(np.float32))
+    indices = Tensor(np.ones([7800]), mstype.int32)
+    net(grad, indices)
+
+def test_gatherv2():
+    class Net(nn.Cell):
+        def __init__(self):
+            super(Net, self).__init__()
+            self.unq = P.Unique()
+            self.gather = P.GatherV2()
+
+        def construct(self, x, y):
+            u, _ = self.unq(y)
+            z = self.gather(x, u, 0)
+            return z
+
+    x = Tensor(np.ones([20, 12], dtype=np.float32))
+    y = Tensor(np.ones([8], dtype=np.int32))
+    net = Net()
+    net(x, y)
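For ad-hoc runs outside a test runner, a simple driver can be appended (not part of the original patch):

if __name__ == "__main__":
    test_sparse_apply_proximal_ada_grad()
    test_sparse_apply_ftrl()
    test_gatherv2()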