Merge pull request !4204 from fary86/adapt_primitive_dynamic_shape
@@ -49,22 +49,6 @@ using mindspore::parse::PyObjectWrapper;
 std::unordered_set<std::string> prims_to_skip_undetermined_infer{"make_tuple", "make_list", "switch", "env_setitem",
                                                                  "env_getitem"};
-EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) {
-  if (prims_to_skip_undetermined_infer.find(prim_->name()) == prims_to_skip_undetermined_infer.end()) {
-    auto ret_abstract = AbstractEval(args);
-    if (ret_abstract != nullptr) {
-      MS_LOG(DEBUG) << "StandardPrimEvaluator eval Undetermined";
-      return ret_abstract;
-    }
-  }
-  prim_->BeginRecordAddAttr();
-  AbstractBasePtr abs_base = eval_impl_(engine, prim_, args);
-  prim_->EndRecordAddAttr();
-  auto added_attrs = prim_->evaluate_added_attrs();
-  auto infer_result = std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>(added_attrs));
-  return infer_result;
-}
 EvalResultPtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
                                         AnfNodeConfigPtr out_conf) {
   AbstractBasePtrList args_spec_list;
@@ -289,45 +273,45 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) {
   py::dict dic;
   if (abs_base->isa<AbstractTensor>()) {
     auto arg_tensor = dyn_cast<AbstractTensor>(abs_base);
-    dic["shape"] = arg_tensor->shape()->shape();
+    dic[ATTR_SHAPE] = arg_tensor->shape()->shape();
     if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {
       const auto &min_shape = arg_tensor->shape()->min_shape();
       const auto &max_shape = arg_tensor->shape()->max_shape();
       if (!min_shape.empty() && !max_shape.empty()) {
-        dic["min_shape"] = min_shape;
-        dic["max_shape"] = max_shape;
+        dic[ATTR_MIN_SHAPE] = min_shape;
+        dic[ATTR_MAX_SHAPE] = max_shape;
       }
     }
-    dic["dtype"] = arg_tensor->BuildType();
-    dic["value"] = BuildValue(arg_tensor->BuildValue());
+    dic[ATTR_DTYPE] = arg_tensor->BuildType();
+    dic[ATTR_VALUE] = BuildValue(arg_tensor->BuildValue());
   } else if (abs_base->isa<AbstractRowTensor>()) {
     auto arg = dyn_cast<AbstractRowTensor>(abs_base);
-    dic["shape"] = arg->shape()->shape();
-    dic["dtype"] = arg->BuildType();
-    dic["value"] = BuildValue(arg->BuildValue());
+    dic[ATTR_SHAPE] = arg->shape()->shape();
+    dic[ATTR_DTYPE] = arg->BuildType();
+    dic[ATTR_VALUE] = BuildValue(arg->BuildValue());
   } else if (abs_base->isa<AbstractSparseTensor>()) {
     auto arg = dyn_cast<AbstractSparseTensor>(abs_base);
-    dic["shape"] = arg->shape()->shape();
-    dic["dtype"] = arg->BuildType();
-    dic["value"] = BuildValue(arg->BuildValue());
+    dic[ATTR_SHAPE] = arg->shape()->shape();
+    dic[ATTR_DTYPE] = arg->BuildType();
+    dic[ATTR_VALUE] = BuildValue(arg->BuildValue());
  } else if (abs_base->isa<AbstractScalar>() || abs_base->isa<AbstractType>() || abs_base->isa<AbstractRefKey>()) {
     ShapeVector shape;
-    dic["shape"] = shape;
-    dic["dtype"] = abs_base->BuildType();
-    dic["value"] = BuildValue(abs_base->BuildValue());
+    dic[ATTR_SHAPE] = shape;
+    dic[ATTR_DTYPE] = abs_base->BuildType();
+    dic[ATTR_VALUE] = BuildValue(abs_base->BuildValue());
   } else if (abs_base->isa<AbstractSlice>()) {
     auto arg_slice = dyn_cast<AbstractSlice>(abs_base);
     ShapeVector shape;
-    dic["shape"] = shape;
-    dic["dtype"] = arg_slice->BuildType();
-    dic["value"] = BuildValue(arg_slice->BuildValue());
+    dic[ATTR_SHAPE] = shape;
+    dic[ATTR_DTYPE] = arg_slice->BuildType();
+    dic[ATTR_VALUE] = BuildValue(arg_slice->BuildValue());
   } else if (abs_base->isa<AbstractRef>()) {
     auto value = abs_base->cast<AbstractRefPtr>()->ref();
     dic = ConvertAbstractToPython(value);
   } else if (abs_base->isa<AbstractEllipsis>()) {
-    dic["shape"] = py::none();
-    dic["dtype"] = py::ellipsis();
-    dic["value"] = py::ellipsis();
+    dic[ATTR_SHAPE] = py::none();
+    dic[ATTR_DTYPE] = py::ellipsis();
+    dic[ATTR_VALUE] = py::ellipsis();
   } else if (abs_base->isa<AbstractTuple>()) {
     auto arg_tuple = dyn_cast<AbstractTuple>(abs_base);
     size_t len = arg_tuple->size();
@@ -336,12 +320,12 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) {
     for (size_t i = 0; i < len; i++) {
       py::dict out = ConvertAbstractToPython(arg_tuple->elements()[i]);
-      shape_tuple[i] = out["shape"];
-      dtype_tuple[i] = out["dtype"];
+      shape_tuple[i] = out[ATTR_SHAPE];
+      dtype_tuple[i] = out[ATTR_DTYPE];
     }
-    dic["shape"] = shape_tuple;
-    dic["dtype"] = dtype_tuple;
-    dic["value"] = BuildValue(arg_tuple->BuildValue());
+    dic[ATTR_SHAPE] = shape_tuple;
+    dic[ATTR_DTYPE] = dtype_tuple;
+    dic[ATTR_VALUE] = BuildValue(arg_tuple->BuildValue());
   } else if (abs_base->isa<AbstractList>()) {
     auto arg_list = dyn_cast<AbstractList>(abs_base);
     size_t len = arg_list->size();
@@ -350,25 +334,25 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) {
     for (size_t i = 0; i < len; i++) {
       py::dict out = ConvertAbstractToPython(arg_list->elements()[i]);
-      shape_list[i] = out["shape"];
-      dtype_list[i] = out["dtype"];
+      shape_list[i] = out[ATTR_SHAPE];
+      dtype_list[i] = out[ATTR_DTYPE];
     }
-    dic["shape"] = shape_list;
-    dic["dtype"] = dtype_list;
-    dic["value"] = BuildValue(arg_list->BuildValue());
+    dic[ATTR_SHAPE] = shape_list;
+    dic[ATTR_DTYPE] = dtype_list;
+    dic[ATTR_VALUE] = BuildValue(arg_list->BuildValue());
   } else if (abs_base->isa<AbstractNone>()) {
-    dic["shape"] = py::none();
-    dic["dtype"] = py::none();
-    dic["value"] = py::none();
+    dic[ATTR_SHAPE] = py::none();
+    dic[ATTR_DTYPE] = py::none();
+    dic[ATTR_VALUE] = py::none();
   } else if (abs_base->isa<AbstractFunction>()) {
-    dic["shape"] = py::none();
-    dic["dtype"] = abs_base->BuildType();
-    dic["value"] = py::none();
+    dic[ATTR_SHAPE] = py::none();
+    dic[ATTR_DTYPE] = abs_base->BuildType();
+    dic[ATTR_VALUE] = py::none();
   } else if (abs_base->isa<AbstractUndetermined>()) {
     auto arg = dyn_cast<AbstractUndetermined>(abs_base);
-    dic["shape"] = py::none();
-    dic["dtype"] = arg->BuildType();
-    dic["value"] = py::none();
+    dic[ATTR_SHAPE] = py::none();
+    dic[ATTR_DTYPE] = arg->BuildType();
+    dic[ATTR_VALUE] = py::none();
   } else {
     auto value = abs_base->BuildValue();
     if ((*value == *kAnyValue)) {
@@ -409,18 +393,20 @@ py::tuple PreparePyInputs(const PrimitivePyPtr &prim_py, const AbstractBasePtrLi
 AbstractBasePtr PyInferRes2Abstract(const PrimitivePyPtr &prim_py, const py::dict &output) {
   // Convert to AbstractValue based on type and shape
-  auto out_dtype = output["dtype"];
-  if (output["value"].is_none()) {
-    auto out_shape = output["shape"];
-    py::object min_shape = output.contains("min_shape") ? (py::object)output["min_shape"] : (py::object)py::none();
-    py::object max_shape = output.contains("max_shape") ? (py::object)output["max_shape"] : (py::object)py::none();
+  auto out_dtype = output[ATTR_DTYPE];
+  if (output[ATTR_VALUE].is_none()) {
+    auto out_shape = output[ATTR_SHAPE];
+    py::object min_shape =
+      output.contains(py::str(ATTR_MIN_SHAPE)) ? (py::object)output[ATTR_MIN_SHAPE] : (py::object)py::none();
+    py::object max_shape =
+      output.contains(py::str(ATTR_MAX_SHAPE)) ? (py::object)output[ATTR_MAX_SHAPE] : (py::object)py::none();
     return PyListDtype2AbstractTensor(out_shape, out_dtype, min_shape, max_shape);
   }
   // Convert pyobject to Value, then to AbstractValue
   ValuePtr converted_ret = nullptr;
   TypePtr dtype = py::isinstance<Type>(out_dtype) ? out_dtype.cast<TypePtr>() : nullptr;
-  bool converted = parse::ConvertData(output["value"], &converted_ret, false, dtype);
+  bool converted = parse::ConvertData(output[ATTR_VALUE], &converted_ret, false, dtype);
   if (!converted) {
     MS_LOG(EXCEPTION) << "Convert data failed";
   }
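For orientation, a minimal sketch of the attribute dictionary that `ConvertAbstractToPython` hands across the C++/Python boundary, now keyed by the `ATTR_*` constants this PR introduces. The concrete values below are invented for illustration; `min_shape`/`max_shape` only appear in graph mode when the shape carries bounds, and -1 marks a dynamic dimension:

    # Illustrative only: what the converted dict for a dynamic-shape tensor
    # argument might look like on the Python side.
    abs_as_dict = {
        'shape': [-1, 4],      # ATTR_SHAPE
        'dtype': 'float32',    # ATTR_DTYPE (a mindspore type object in reality)
        'value': None,         # ATTR_VALUE, None when no constant is known
        'min_shape': [1, 4],   # ATTR_MIN_SHAPE
        'max_shape': [64, 4],  # ATTR_MAX_SHAPE
    }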
@@ -447,6 +433,73 @@ AbstractBasePtr PyInferRes2Abstract(const PrimitivePyPtr &prim_py, const py::dic
   }
 }  // end anonymous namespace
+EvalResultPtr StandardPrimEvaluator::EvalPyCheckPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) {
+  auto prim_py = dyn_cast<PrimitivePy>(prim_);
+  if (prim_py == nullptr) {
+    MS_LOG(EXCEPTION) << "The primitive with type 'kPrimTypePyInferCheck' should be a python primitive.";
+  }
+  // Call checking method '__check__' for subclass of 'PrimitiveWithCheck'
+  MS_LOG(DEBUG) << "Begin input args checking for: " << prim_py->ToString();
+  auto py_args = PreparePyInputs(prim_py, args);
+  prim_py->RunCheck(py_args);
+  prim_->BeginRecordAddAttr();
+  AbstractBasePtr abs_base = eval_impl_(engine, prim_, args);
+  prim_->EndRecordAddAttr();
+  auto added_attrs = prim_->evaluate_added_attrs();
+  if (!py::hasattr(prim_py->GetPyObj(), PY_PRIM_METHOD_INFER_VALUE)) {
+    return std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>(added_attrs));
+  }
+  // Call method 'infer_value' for primitive with this method for constant propagation
+  py::tuple py_vals(py_args.size());
+  for (size_t i = 0; i < py_args.size(); ++i) {
+    py_vals[i] = py_args[i][ATTR_VALUE];
+  }
+  py::object py_ret = prim_py->RunInferValue(py_vals);
+  if (py::isinstance<py::none>(py_ret)) {
+    return std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>(added_attrs));
+  }
+  // Convert pyobject to Value, then to AbstractValue
+  ValuePtr converted_ret = nullptr;
+  TypePtr dtype = abs_base->BuildType();
+  bool converted = parse::ConvertData(py_ret, &converted_ret, false, dtype);
+  if (!converted) {
+    MS_LOG(EXCEPTION) << "Convert data failed";
+  }
+  auto res_spec = FromValue(converted_ret);
+  MS_EXCEPTION_IF_NULL(res_spec);
+  if (res_spec->isa<AbstractTensor>()) {
+    // Replace to tensor constant node in specialize
+    auto res_tensor = res_spec->cast<AbstractTensorPtr>();
+    res_tensor->set_value(converted_ret);
+  }
+  return std::make_shared<EvalResult>(res_spec, std::make_shared<AttrValueMap>(added_attrs));
+}
+EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) {
+  if (prims_to_skip_undetermined_infer.find(prim_->name()) == prims_to_skip_undetermined_infer.end()) {
+    auto ret_abstract = AbstractEval(args);
+    if (ret_abstract != nullptr) {
+      MS_LOG(DEBUG) << "StandardPrimEvaluator eval Undetermined";
+      return ret_abstract;
+    }
+  }
+  if (prim_->prim_type() == PrimType::kPrimTypePyInferCheck) {
+    return EvalPyCheckPrim(engine, args);
+  }
+  prim_->BeginRecordAddAttr();
+  AbstractBasePtr abs_base = eval_impl_(engine, prim_, args);
+  prim_->EndRecordAddAttr();
+  auto added_attrs = prim_->evaluate_added_attrs();
+  return std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>(added_attrs));
+}
 EvalResultPtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args) {
   auto ret_abstract = AbstractEval(args);
   if (ret_abstract != nullptr) {
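A rough Python-level sketch of the control flow `EvalPyCheckPrim` adds. The helpers `prepare_py_inputs`, `cpp_infer_impl`, and `from_value` are placeholders standing in for the C++ calls (`PreparePyInputs`, `eval_impl_`, `FromValue`), not real APIs:

    def eval_py_check_prim(prim, engine, args):
        # Convert abstract args to {'shape', 'dtype', 'value'} dicts for Python.
        py_args = prepare_py_inputs(prim, args)
        # PrimitiveWithCheck only validates inputs; it returns no shape/type.
        prim.__check__(*py_args)
        # Shape/type inference runs in the C++ infer function registered for prim.
        abs_base = cpp_infer_impl(engine, prim, args)
        if not hasattr(prim, 'infer_value'):
            return abs_base
        # Optional constant propagation through the Python 'infer_value' method.
        ret = prim.infer_value(*(a['value'] for a in py_args))
        return abs_base if ret is None else from_value(ret)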
@@ -42,6 +42,8 @@ class StandardPrimEvaluator : public TrivialPrimEvaluator {
   std::string ToString() const override { return identifier_ + prim_->name(); }
  private:
+  EvalResultPtr EvalPyCheckPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args);
   PrimitivePtr prim_;
   const StandardPrimitiveEvalImpl eval_impl_;
 };
@@ -308,20 +308,18 @@ void AnalysisEngine::Clear() {
 namespace {
 EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr &engine) {
   // Custom Primitive with python infer_shape, infer_type
-  EvaluatorPtr evaluator = nullptr;
   MS_EXCEPTION_IF_NULL(prim);
   if (prim->isa<prim::DoSignaturePrimitive>()) {
-    evaluator = std::make_shared<DoSignatureEvaluator>(prim);
-    return evaluator;
+    return std::make_shared<DoSignatureEvaluator>(prim);
   }
   if (prim->isa<prim::UnpackGraphPrimitive>()) {
-    evaluator = std::make_shared<UnpackGraphEvaluator>(prim);
-    return evaluator;
+    return std::make_shared<UnpackGraphEvaluator>(prim);
   }
   if (prim->Hash() == prim::kPrimMixedPrecisionCast->Hash() && prim->name() == prim::kPrimMixedPrecisionCast->name()) {
-    evaluator = std::make_shared<MixedPrecisionCastEvaluator>(prim);
-    return evaluator;
+    return std::make_shared<MixedPrecisionCastEvaluator>(prim);
   }
+  EvaluatorPtr evaluator = nullptr;
   if (prim->HasPyEvaluator()) {
     auto prim_py = dyn_cast<PrimitivePy>(prim);
     if (prim_py != nullptr) {
@@ -55,6 +55,10 @@ void ValidateOperation(const AnfNodePtr &node) {
     MS_LOG(DEBUG) << "Primitive " << prim->name() << " has python evaluator.";
     return;
   }
+  if (prim->prim_type() == PrimType::kPrimTypePyInferCheck) {
+    MS_LOG(DEBUG) << "Primitive " << prim->name() << " has python inference checking method.";
+    return;
+  }
   if (prim->name() == "fake_bprop") {
     MS_LOG(EXCEPTION) << "Illegal primitive: " << GetValue<std::string>(prim->GetAttr("info"));
   }
@@ -254,16 +254,33 @@ py::dict PrimitivePy::RunInfer(const py::tuple &args) {
   if (!HasPyObj()) {
     MS_LOG(EXCEPTION) << "[" << this->ToString() << "]: pyobj is empty";
   }
-  auto infer_fuc = python_obj_.attr("__infer__");
+  auto infer_fuc = python_obj_.attr(PY_PRIM_METHOD_INFER);
   return infer_fuc(*args);
 }
+void PrimitivePy::RunCheck(const py::tuple &args) {
+  if (!HasPyObj()) {
+    MS_LOG(EXCEPTION) << "[" << this->ToString() << "]: pyobj is empty";
+  }
+  auto check_func = python_obj_.attr(PY_PRIM_METHOD_CHECK);
+  (void)check_func(*args);
+}
+py::object PrimitivePy::RunInferValue(const py::tuple &args) {
+  if (!HasPyObj()) {
+    MS_LOG(EXCEPTION) << "[" << this->ToString() << "]: pyobj is empty";
+  }
+  auto infer_value = python_obj_.attr(PY_PRIM_METHOD_INFER_VALUE);
+  return infer_value(*args);
+}
 REGISTER_PYBIND_DEFINE(Primitive_, ([](const py::module *m) {
                          (void)py::enum_<PrimType>(*m, "prim_type", py::arithmetic())
                            .value("unknown", PrimType::kPrimTypeUnknown)
                            .value("builtin", PrimType::kPrimTypeBuiltIn)
                            .value("py_infer_shape", PrimType::kPrimTypePyInferShape)
-                           .value("user_custom", PrimType::kPrimTypeUserCustom);
+                           .value("user_custom", PrimType::kPrimTypeUserCustom)
+                           .value("py_infer_check", PrimType::kPrimTypePyInferCheck);
                          (void)py::class_<PrimitivePy, std::shared_ptr<PrimitivePy>>(*m, "Primitive_")
                            .def_readonly(PYTHON_PRIMITIVE_FLAG, &PrimitivePy::parse_info_)
                            .def(py::init<py::str &, py::object>())
@@ -62,6 +62,8 @@ class PrimitivePy : public Primitive {
   const bool parse_info_ = true;
   const py::object &GetPyObj() const { return python_obj_; }
   py::dict RunInfer(const py::tuple &args);
+  void RunCheck(const py::tuple &args);
+  py::object RunInferValue(const py::tuple &args);
   bool ObjHasAttr(const char *attr_name) { return py::hasattr(python_obj_, attr_name); }
   bool HasPyObj() { return python_obj_.operator bool(); }
   PrimitivePtr Clone() override;
@@ -81,6 +81,9 @@ AbstractBasePtr InferImplDropoutGenMask(const AnalysisEnginePtr &, const Primiti
 AbstractBasePtr InferImplMinOrMaxGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                       const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplSqrt(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                              const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplScalarToArray(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                        const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplArrayToScalar(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
@@ -176,6 +179,14 @@ AbstractBasePtr InferImplUniqueGrad(const AnalysisEnginePtr &, const PrimitivePt
                                     const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplUnique(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                 const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplGatherV2(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                  const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplDynamicShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                      const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplSparseApplyFtrl(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                         const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplSparseApplyProximalAdagrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                                    const AbstractBasePtrList &args_spec_list);
 template <typename T>
 AbstractBasePtr InferTupleOrListOrDictLen(const std::string &op_name, const AbstractBasePtrList &args_spec_list) {
@@ -14,6 +14,8 @@
  * limitations under the License.
  */
+#include <algorithm>
+#include <iterator>
 #include "abstract/infer_functions.h"
 #include "abstract/utils.h"
 #include "abstract/param_validator.h"
@@ -226,5 +228,60 @@ AbstractBasePtr InferImplUniqueGrad(const AnalysisEnginePtr &, const PrimitivePt
   // outputs: dx
   return std::make_shared<AbstractTensor>(ids->element(), ids_idx->shape());
 }
+AbstractBasePtr InferImplGatherV2(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                  const AbstractBasePtrList &args_spec_list) {
+  const std::string &op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, 3);
+  AbstractTensorPtr params = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
+  AbstractTensorPtr indices = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
+  AbstractScalarPtr axis = CheckArg<AbstractScalar>(op_name, args_spec_list, 2);
+  auto params_shp = params->shape()->shape();
+  auto indices_shp = indices->shape()->shape();
+  auto axis_val = GetValue<int>(axis->BuildValue());
+  auto params_rank = static_cast<int>(params_shp.size());
+  if (axis_val < 0) {
+    axis_val += params_rank;
+  }
+  auto calc_shape = [axis_val, &params_shp](const ShapeVector &inp_vec) -> ShapeVector {
+    ShapeVector out_vec;
+    std::copy(params_shp.begin(), params_shp.begin() + axis_val, std::back_inserter(out_vec));
+    std::copy(inp_vec.begin(), inp_vec.end(), std::back_inserter(out_vec));
+    std::copy(params_shp.begin() + axis_val + 1, params_shp.end(), std::back_inserter(out_vec));
+    return out_vec;
+  };
+  ShapeVector out_shape = calc_shape(indices_shp);
+  if (!indices->shape()->min_shape().empty() && !indices->shape()->max_shape().empty()) {
+    ShapeVector min_shape = calc_shape(indices->shape()->min_shape());
+    ShapeVector max_shape = calc_shape(indices->shape()->max_shape());
+    return std::make_shared<AbstractTensor>(params->element(),
+                                            std::make_shared<Shape>(out_shape, min_shape, max_shape));
+  }
+  return std::make_shared<AbstractTensor>(params->element(), std::make_shared<Shape>(out_shape));
+}
+AbstractBasePtr InferImplDynamicShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                      const AbstractBasePtrList &args_spec_list) {
+  const std::string &op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, 1);
+  AbstractTensorPtr input = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
+  auto shape = input->shape()->shape();
+  bool has_dyn_shape = std::any_of(shape.begin(), shape.end(), [](int dim) { return dim == Shape::SHP_ANY; });
+  std::vector<int> tensor_shp({static_cast<int>(shape.size())});
+  if (has_dyn_shape) {
+    auto elem = std::make_shared<AbstractScalar>(std::make_shared<AnyValue>(), std::make_shared<Int>(32));
+    return std::make_shared<AbstractTensor>(elem, std::make_shared<Shape>(tensor_shp));
+  }
+  auto shp_buf_size = sizeof(int) * shape.size();
+  auto tensor = std::make_shared<tensor::Tensor>(kNumberTypeInt32, tensor_shp, shape.data(), shp_buf_size);
+  return tensor->ToAbstract();
+}
 }  // namespace abstract
 }  // namespace mindspore
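A small Python sketch of the shape rule `InferImplGatherV2` implements. The same `calc_shape` rule is applied to the plain, min, and max shapes of `indices`, which is what lets GatherV2 carry dynamic-shape bounds through inference:

    def gather_v2_out_shape(params_shape, indices_shape, axis):
        """Output = params shape with the axis dim replaced by the indices shape."""
        if axis < 0:
            axis += len(params_shape)
        return params_shape[:axis] + indices_shape + params_shape[axis + 1:]

    # e.g. params (3, 4, 5), indices (2, 6), axis 1 -> (3, 2, 6, 5)
    assert gather_v2_out_shape([3, 4, 5], [2, 6], 1) == [3, 2, 6, 5]
    # With dynamic indices, apply the same rule to the min/max bounds:
    assert gather_v2_out_shape([3, 4, 5], [1], 1) == [3, 1, 5]  # min_shape
    assert gather_v2_out_shape([3, 4, 5], [8], 1) == [3, 8, 5]  # max_shape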
@@ -37,5 +37,14 @@ AbstractBasePtr InferImplMinOrMaxGrad(const AnalysisEnginePtr &, const Primitive
   return std::make_shared<AbstractTuple>(AbstractBasePtrList({dx, dy}));
 }
+AbstractBasePtr InferImplSqrt(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                              const AbstractBasePtrList &args_spec_list) {
+  // Inputs: one tensor.
+  const std::string op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, 1);
+  auto inp = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
+  return inp->Clone()->Broaden();
+}
 }  // namespace abstract
 }  // namespace mindspore
@@ -445,5 +445,25 @@ AbstractBasePtr InferImplDropoutGenMask(const AnalysisEnginePtr &, const Primiti
   return std::make_shared<AbstractTensor>(std::make_shared<AbstractScalar>(kAnyValue, kUInt8),
                                           std::make_shared<Shape>(std::vector<int64_t>{shape_y}));
 }
+AbstractBasePtr InferImplSparseApplyFtrl(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                         const AbstractBasePtrList &args_spec_list) {
+  CheckArgsSize(primitive->name(), args_spec_list, 5);
+  AbstractBasePtrList elements;
+  for (size_t i = 0; i < 3; ++i) {
+    elements.push_back(args_spec_list[i]->Clone()->Broaden());
+  }
+  return std::make_shared<AbstractTuple>(elements);
+}
+AbstractBasePtr InferImplSparseApplyProximalAdagrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                                    const AbstractBasePtrList &args_spec_list) {
+  CheckArgsSize(primitive->name(), args_spec_list, 7);
+  AbstractBasePtrList elements;
+  for (size_t i = 0; i < 2; ++i) {
+    elements.push_back(args_spec_list[i]->Clone()->Broaden());
+  }
+  return std::make_shared<AbstractTuple>(elements);
+}
 }  // namespace abstract
 }  // namespace mindspore
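To make the `Clone()->Broaden()` pattern above concrete, here is an illustrative Python picture of its effect on an abstract argument; the dicts merely stand in for abstract values (this is a reading of the intent, not a MindSpore API):

    # Broadening keeps shape and dtype but drops the concrete value, so the
    # output abstracts of SparseApplyFtrl track its first three inputs
    # (var, accum, linear) without pinning them to constants.
    var_abstract = {'shape': [8, 4], 'dtype': 'float32', 'value': [[1.0] * 4] * 8}
    broadened = dict(var_abstract, value=None)
    assert broadened == {'shape': [8, 4], 'dtype': 'float32', 'value': None}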
@@ -37,6 +37,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
     // Maths
     {prim::kPrimMaximumGrad, {InferImplMinOrMaxGrad, true}},
     {prim::kPrimMinimumGrad, {InferImplMinOrMaxGrad, true}},
+    {prim::kPrimSqrt, {InferImplSqrt, true}},
     // Array
     {prim::kPrimScalarToArray, {InferImplScalarToArray, true}},
     {prim::kPrimArrayToScalar, {InferImplArrayToScalar, true}},
@@ -44,6 +45,9 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
     {prim::kPrimPack, {InferImplPack, true}},
     {prim::kPrimUnique, {InferImplUnique, true}},
     {prim::kPrimUniqueGrad, {InferImplUniqueGrad, true}},
+    {prim::kPrimGatherV2, {InferImplGatherV2, true}},
+    {prim::kPrimSparseGatherV2, {InferImplGatherV2, true}},
+    {prim::kPrimDynamicShape, {InferImplDynamicShape, true}},
     // Structure
     {prim::kPrimMakeTuple, {InferImplMakeTuple, true}},
     {prim::kPrimMakeList, {InferImplMakeList, true}},
@@ -77,6 +81,8 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
     {prim::kPrimLayerNorm, {InferImplLayerNorm, true}},
     {prim::kPrimLayerNormGrad, {InferImplLayerNormGrad, true}},
     {prim::kPrimDropoutGenMask, {InferImplDropoutGenMask, true}},
+    {prim::kPrimSparseApplyFtrl, {InferImplSparseApplyFtrl, true}},
+    {prim::kPrimSparseApplyProximalAdagrad, {InferImplSparseApplyProximalAdagrad, true}},
     // Others
     {prim::kPrimIdentity, {InferImplIdentity, true}},
     // Set impl to null as it will use PartialEvaluator;
@@ -84,6 +84,9 @@ inline const PrimitivePtr kPrimConcat = std::make_shared<Primitive>("Concat");
 inline const PrimitivePtr kPrimSqueeze = std::make_shared<Primitive>("Squeeze");
 inline const PrimitivePtr kPrimTranspose = std::make_shared<Primitive>("Transpose");
 inline const PrimitivePtr kPrimGatherV2 = std::make_shared<Primitive>("GatherV2");
+inline const PrimitivePtr kPrimSparseGatherV2 = std::make_shared<Primitive>("SparseGatherV2");
+inline const PrimitivePtr kPrimShape = std::make_shared<Primitive>("Shape");
+inline const PrimitivePtr kPrimDynamicShape = std::make_shared<Primitive>("DynamicShape");
 inline const PrimitivePtr kPrimEmbeddingLookup = std::make_shared<Primitive>("EmbeddingLookup");
 inline const PrimitivePtr kPrimEmbeddingLookupCommGrad = std::make_shared<Primitive>("EmbeddingLookupCommGrad");
 inline const PrimitivePtr kPrimSize = std::make_shared<Primitive>("Size");
@@ -154,6 +157,8 @@ inline const PrimitivePtr kPrimBpropCut = std::make_shared<Primitive>("bprop_cut
 inline const PrimitivePtr kPrimFakeQuantPerLayer = std::make_shared<Primitive>("FakeQuantPerLayer");
 inline const PrimitivePtr kPrimFakeQuantPerChannel = std::make_shared<Primitive>("FakeQuantPerChannel");
 inline const PrimitivePtr kPrimApplyRMSProp = std::make_shared<Primitive>("ApplyRMSProp");
+inline const PrimitivePtr kPrimSparseApplyFtrl = std::make_shared<Primitive>("SparseApplyFtrl");
+inline const PrimitivePtr kPrimSparseApplyProximalAdagrad = std::make_shared<Primitive>("SparseApplyProximalAdagrad");
 // Comm ops
 inline const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator");
@@ -35,7 +35,8 @@ enum PrimType {
   kPrimTypeBuiltIn,        // Built-in primitive operator
   kPrimTypePyInferShape,   // Primitive operator defined by custom
   kPrimTypePyInferTensor,  // Primitive operator defined by custom
-  kPrimTypeUserCustom
+  kPrimTypeUserCustom,
+  kPrimTypePyInferCheck  // Primitive operator with input args checking method
 };
 class Primitive : public Named {
@@ -23,4 +23,19 @@ const char GRAPH_FLAG_HAS_EFFECT[] = "has_effect";
 const char GRAPH_FLAG_EFFECT_PATIAL_ORDER[] = "_effect_patial_order";
 const char GRAPH_FLAG_RANDOM_EFFECT[] = "_random_effect";
 const char GRAPH_FLAG_SIDE_EFFECT[] = "_side_effect";
+// method names of python primitive called from c++ source code
+// 1. infer method name of class 'PrimitiveWithInfer'
+const char PY_PRIM_METHOD_INFER[] = "__infer__";
+// 2. check method name of class 'PrimitiveWithCheck'
+const char PY_PRIM_METHOD_CHECK[] = "__check__";
+// 3. method name of class 'PrimitivePy' for constant propagation
+const char PY_PRIM_METHOD_INFER_VALUE[] = "infer_value";
+// type inference related attributes
+const char ATTR_VALUE[] = "value";
+const char ATTR_DTYPE[] = "dtype";
+const char ATTR_SHAPE[] = "shape";
+const char ATTR_MIN_SHAPE[] = "min_shape";
+const char ATTR_MAX_SHAPE[] = "max_shape";
 }  // namespace mindspore
@@ -23,6 +23,16 @@ extern const char GRAPH_FLAG_HAS_EFFECT[];
 extern const char GRAPH_FLAG_EFFECT_PATIAL_ORDER[];
 extern const char GRAPH_FLAG_RANDOM_EFFECT[];
 extern const char GRAPH_FLAG_SIDE_EFFECT[];
+extern const char PY_PRIM_METHOD_INFER[];
+extern const char PY_PRIM_METHOD_CHECK[];
+extern const char PY_PRIM_METHOD_INFER_VALUE[];
+extern const char ATTR_VALUE[];
+extern const char ATTR_DTYPE[];
+extern const char ATTR_SHAPE[];
+extern const char ATTR_MIN_SHAPE[];
+extern const char ATTR_MAX_SHAPE[];
 }  // namespace mindspore
 #endif  // MINDSPORE_CORE_UTILS_FLAGS_H
@@ -27,7 +27,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
                         Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue,
                         SameTypeShape, ScatterAdd, ScatterSub, ScatterMul, ScatterDiv, ScatterMax, ScatterMin,
                         ScatterUpdate, ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select,
-                        Shape, Size, Slice, Split, TransShape, ParallelConcat, Padding,
+                        Shape, DynamicShape, Size, Slice, Split, TransShape, ParallelConcat, Padding,
                         ScatterNdAdd, ScatterNdSub, ScatterNonAliasingAdd, ReverseV2, Rint,
                         Squeeze, StridedSlice, Tile, TensorScatterUpdate, EditDistance,
                         Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin, UnsortedSegmentProd,
@@ -206,6 +206,7 @@ __all__ = [
     'HookBackward',
     'InvertPermutation',
     'Shape',
+    'DynamicShape',
     'DropoutDoMask',
     'DropoutGenMask',
     'DropoutGrad',
@@ -27,7 +27,7 @@ import numpy as np
 from .._utils import get_concat_offset
 from ..operations.math_ops import _infer_shape_reduce
-from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register, _run_op
+from ..primitive import Primitive, PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register, _run_op
 from ..._c_expression import signature_dtype as sig_dtype
 from ..._c_expression import signature_kind as sig_kind
 from ..._c_expression import signature_rw as sig_rw
@@ -142,6 +142,11 @@ class ExpandDims(PrimitiveWithInfer):
         out = {'shape': x_shape,
                'dtype': x['dtype'],
                'value': value}
+        if 'min_shape' in x and 'max_shape' in x:
+            out['min_shape'] = x['min_shape']
+            out['min_shape'].insert(axis_v, 1)
+            out['max_shape'] = x['max_shape']
+            out['max_shape'].insert(axis_v, 1)
         return out
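A quick illustration of the min/max bound propagation added to `ExpandDims.__infer__` (the values are invented for the example): the inserted axis always has length 1, so both bounds gain a 1 at `axis_v`:

    x = {'shape': [-1, 4], 'min_shape': [1, 4], 'max_shape': [8, 4]}
    axis_v = 0
    out = {'shape': [1] + x['shape']}  # shape after expand_dims(axis=0)
    if 'min_shape' in x and 'max_shape' in x:
        out['min_shape'] = x['min_shape'][:]
        out['min_shape'].insert(axis_v, 1)
        out['max_shape'] = x['max_shape'][:]
        out['max_shape'].insert(axis_v, 1)
    assert out == {'shape': [1, -1, 4], 'min_shape': [1, 1, 4], 'max_shape': [1, 8, 4]}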
@@ -277,6 +282,9 @@ class Cast(PrimitiveWithInfer):
         out = {'shape': x['shape'],
                'dtype': mstype.tensor_type(t['value']),
                'value': value}
+        if 'min_shape' in x and 'max_shape' in x:
+            out['min_shape'] = x['min_shape']
+            out['max_shape'] = x['max_shape']
         return out
@@ -445,6 +453,27 @@ class Shape(PrimitiveWithInfer):
         return out
+class DynamicShape(Primitive):
+    """
+    Returns the shape of the input tensor.
+    Inputs:
+        - **input_x** (Tensor) - The shape of the tensor is :math:`(x_1, x_2, ..., x_R)`.
+    Outputs:
+        Tensor[int], a 1-dim Tensor of type int32.
+    Examples:
+        >>> input_tensor = Tensor(np.ones(shape=[3, 2, 1]), mindspore.float32)
+        >>> shape = P.DynamicShape()
+        >>> output = shape(input_tensor)
+    """
+    @prim_attr_register
+    def __init__(self):
+        """init DynamicShape"""
 class Squeeze(PrimitiveWithInfer):
     """
     Returns a tensor with the same type but dimensions of 1 being removed based on axis.
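A hedged usage sketch contrasting the two shape operators; the contrast is inferred from this PR's design rather than stated in it. `Shape` produces a compile-time tuple, while `DynamicShape` produces a runtime int32 Tensor, which remains usable when dimensions are only known at run time:

    import numpy as np
    import mindspore
    from mindspore import Tensor
    from mindspore.ops import operations as P

    x = Tensor(np.ones((3, 2, 1)), mindspore.float32)
    static_shape = P.Shape()(x)          # (3, 2, 1): a Python tuple fixed at compile time
    runtime_shape = P.DynamicShape()(x)  # Tensor([3, 2, 1]) of int32, produced at run time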
@@ -578,7 +607,7 @@ class Unique(Primitive):
         self.init_prim_io_names(inputs=['x'], outputs=['output'])
-class GatherV2(PrimitiveWithInfer):
+class GatherV2(PrimitiveWithCheck):
     """
     Returns a slice of input tensor based on the specified indices and axis.
@@ -605,7 +634,7 @@
         """init index_select"""
         self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output'])
-    def __infer__(self, params, indices, axis):
+    def __check__(self, params, indices, axis):
         validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)
         validator.check_tensor_type_same({"indices": indices['dtype']}, mstype.int_type, self.name)
         validator.check_subclass("axis", axis['dtype'], mstype.int_, self.name)
@@ -613,13 +642,6 @@
         params_shp = params['shape']
         rank = len(params_shp)
         validator.check_int_range("axis", axis_v, -rank, rank, Rel.INC_LEFT, self.name)
-        if axis_v < 0:
-            axis_v += rank
-        out_shape = params_shp[:axis_v] + indices['shape'] + params_shp[axis_v + 1:]
-        out = {'shape': out_shape,
-               'dtype': params['dtype'],
-               'value': None}
-        return out
 class SparseGatherV2(GatherV2):
@@ -26,7 +26,7 @@ from ..._checkparam import Rel
 from ...common import dtype as mstype
 from ...common.tensor import Tensor
 from .._utils import get_broadcast_shape
-from ..primitive import PrimitiveWithInfer, prim_attr_register, _run_op
+from ..primitive import PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register, _run_op
 def _infer_shape_reduce(x, axis, keep_dims, prim_name):
@@ -1257,7 +1257,7 @@ class Rsqrt(PrimitiveWithInfer):
         return None
-class Sqrt(PrimitiveWithInfer):
+class Sqrt(PrimitiveWithCheck):
     """
     Returns square root of a tensor element-wise.
@@ -1279,12 +1279,8 @@ class Sqrt(PrimitiveWithInfer):
         """init Sqrt"""
         self.init_prim_io_names(inputs=['x'], outputs=['output'])
-    def infer_shape(self, x_shape):
-        return x_shape
-    def infer_dtype(self, x_type):
+    def check_dtype(self, x_type):
         validator.check_tensor_type_same({"x": x_type}, mstype.number_type, self.name)
-        return x_type
     def infer_value(self, x):
         if x is not None:
@@ -28,7 +28,7 @@ from ..._c_expression import signature_dtype as sig_dtype
 from ..._checkparam import Validator as validator
 from ..._checkparam import Rel
 from ...common import dtype as mstype
-from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register
+from ..primitive import Primitive, PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register
 from ..operations.math_ops import _infer_shape_reduce
@@ -4354,7 +4354,7 @@ class ApplyProximalAdagrad(PrimitiveWithInfer):
         return var_dtype, accum_dtype
-class SparseApplyProximalAdagrad(PrimitiveWithInfer):
+class SparseApplyProximalAdagrad(PrimitiveWithCheck):
     r"""
     Update relevant entries according to the proximal adagrad algorithm. Compared with ApplyProximalAdagrad,
     an additional index tensor is input.
@@ -4433,11 +4433,10 @@
                                 outputs=['var', 'accum'])
         self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)
-    def infer_shape(self, var_shape, accum_shape, lr_shape, l1_shape, l2_shape, grad_shape, indices_shape):
+    def check_shape(self, var_shape, accum_shape, lr_shape, l1_shape, l2_shape, grad_shape, indices_shape):
         validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name)
-        return var_shape, accum_shape
-    def infer_dtype(self, var_dtype, accum_dtype, lr_dtype, l1_dtype, l2_dtype, grad_dtype, indices_dtype):
+    def check_dtype(self, var_dtype, accum_dtype, lr_dtype, l1_dtype, l2_dtype, grad_dtype, indices_dtype):
         args = {'var': var_dtype, 'accum': accum_dtype, 'grad': grad_dtype}
         validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name)
         validator.check_scalar_or_tensor_type_same({"lr": lr_dtype}, [mstype.float16, mstype.float32], self.name)
@@ -4446,7 +4445,6 @@
         valid_types = [mstype.int16, mstype.int32, mstype.int64,
                        mstype.uint16, mstype.uint32, mstype.uint64]
         validator.check_tensor_type_same({'indices': indices_dtype}, valid_types, self.name)
-        return var_dtype, accum_dtype
 class ApplyAddSign(PrimitiveWithInfer):
@@ -4978,7 +4976,7 @@ class ApplyFtrl(PrimitiveWithInfer):
         return var_type
-class SparseApplyFtrl(PrimitiveWithInfer):
+class SparseApplyFtrl(PrimitiveWithCheck):
     """
     Update relevant entries according to the FTRL-proximal scheme.
@@ -5053,21 +5051,19 @@
         self.lr_power = validator.check_number("lr_power", lr_power, 0, Rel.LE, self.name)
         self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)
-    def infer_shape(self, var_shape, accum_shape, linear_shape, grad_shape, indices_shape):
+    def check_shape(self, var_shape, accum_shape, linear_shape, grad_shape, indices_shape):
         validator.check('var shape', var_shape, 'accum shape', accum_shape, Rel.EQ, self.name)
         validator.check('var shape', var_shape, 'linear shape', linear_shape, Rel.EQ, self.name)
         if len(var_shape) > 1:
             validator.check('var_shape[1:]', var_shape[1:], 'grad_shape[1:]', grad_shape[1:], Rel.EQ, self.name)
         validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name)
         validator.check('grad_shape[0]', grad_shape[0], 'indices_shape[0]', indices_shape[0], Rel.EQ, self.name)
-        return var_shape, accum_shape, linear_shape
-    def infer_dtype(self, var_dtype, accum_dtype, linear_dtype, grad_dtype, indices_dtype):
+    def check_dtype(self, var_dtype, accum_dtype, linear_dtype, grad_dtype, indices_dtype):
         args = {"var_dtype": var_dtype, "accum_dtype": accum_dtype,
                 "linear_dtype": linear_dtype, "grad_dtype": grad_dtype}
         validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name)
         validator.check_tensor_type_same({"indices_dtype": indices_dtype}, [mstype.int32], self.name)
-        return var_dtype, accum_dtype, linear_dtype
 class SparseApplyFtrlV2(PrimitiveWithInfer):
@@ -200,6 +200,84 @@ class Primitive(Primitive_):
         return self._update_parameter
+class PrimitiveWithCheck(Primitive):
+    """
+    PrimitiveWithCheck is the base class of primitives in python, which defines functions for checking the
+    input arguments of operators but uses the infer method registered in the c++ source code.
+    There are three methods that can be overridden to define the check logic of the primitive: __check__(),
+    check_shape(), and check_dtype(). If __check__() is defined in the primitive, it has the highest priority
+    to be called. If __check__() is not defined, check_shape() and check_dtype() can be defined to describe
+    the check logic of the shape and type.
+    Args:
+        name (str): Name of the current Primitive.
+    Examples:
+        >>> # init a Primitive class with check
+        >>> class Flatten(PrimitiveWithCheck):
+        >>>     @prim_attr_register
+        >>>     def __init__(self):
+        >>>         pass
+        >>>     def check_shape(self, input_x):
+        >>>         validator.check_integer('input_x rank', len(input_x), 1, Rel.GE, self.name)
+        >>>
+        >>>     def check_dtype(self, input_x):
+        >>>         validator.check_subclass("input_x", input_x, mstype.tensor, self.name)
+        >>>
+        >>> # init a Primitive obj
+        >>> add = Flatten()
+    """
+    def __init__(self, name):
+        Primitive.__init__(self, name)
+        self.set_prim_type(prim_type.py_infer_check)
+    def _clone(self):
+        """
+        Deeply clones the primitive object.
+        Calls the __init__() method with the same arguments. This method is called in the parser if the
+        flag self.__setattr_flag__ is True.
+        """
+        cloned_prim = Primitive._clone(self)
+        return cloned_prim
+    def check_shape(self, *args):
+        """
+        Check shapes of input args.
+        Note:
+            The shape of a scalar is an empty tuple.
+        Args:
+            args (tuple(int)): shapes of input tensors.
+        Return:
+            None.
+        """
+        return None
+    def check_dtype(self, *args):
+        """
+        Check data types of input args.
+        Args:
+            args (:class:`mindspore.dtype`): data type of inputs.
+        Return:
+            None.
+        """
+        return None
+    def __check__(self, *args):
+        """Check shape, type, and value at the same time by using dictionary as arguments."""
+        tracks = ['dtype', 'shape']
+        for track in tracks:
+            fn = getattr(self, 'check_' + track)
+            fn(*(x[track] for x in args))
 class PrimitiveWithInfer(Primitive):
     """
     PrimitiveWithInfer is the base class of primitives in python defines functions for tracking inference in python.
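To show how the `__check__` dispatch above behaves, here is a standalone toy that reproduces the same track-by-track mechanism without MindSpore; each argument arrives as a dict keyed by 'dtype', 'shape', and 'value', exactly the shape of what `PreparePyInputs` produces:

    class ToyCheck:
        """Toy stand-in for PrimitiveWithCheck; validators replaced by asserts."""
        def check_shape(self, x_shape):
            assert len(x_shape) >= 1, "input must be at least rank 1"
        def check_dtype(self, x_dtype):
            assert x_dtype in ('float16', 'float32'), "unsupported dtype"
        def __check__(self, *args):
            # Peel one track at a time off every argument dict and dispatch.
            for track in ('dtype', 'shape'):
                fn = getattr(self, 'check_' + track)
                fn(*(x[track] for x in args))

    ToyCheck().__check__({'shape': [3, 4], 'dtype': 'float32', 'value': None})  # passes

Note that nothing is returned: shape and type inference itself is done by the C++ function registered in `GetPrimitiveToEvalImplMap()`.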
@@ -306,6 +384,18 @@ class PrimitiveWithInfer(Primitive):
         if not is_graph_mode:
             return out
+        # output does not contain dynamic shape, no need to calculate min/max shape
+        def has_dynamic_shape(shp):
+            if isinstance(shp, int):
+                return shp < 0
+            if isinstance(shp, (list, tuple)):
+                return any(has_dynamic_shape(e) for e in shp)
+            return False
+        if not has_dynamic_shape(out['shape']):
+            return out
+        # calculate min/max shape for output
         def get_specified_shape(elems, attr):
             has_specified_shape = False
             ret_vals = []
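The recursion handles nested shapes (tuples of shapes for multi-output operators) as well as flat ones; a negative dimension marks a dynamic axis. A few hypothetical probes of the helper as defined above:

    def has_dynamic_shape(shp):
        if isinstance(shp, int):
            return shp < 0
        if isinstance(shp, (list, tuple)):
            return any(has_dynamic_shape(e) for e in shp)
        return False

    assert not has_dynamic_shape([3, 2, 1])   # fully static
    assert has_dynamic_shape([-1, 4])         # one dynamic axis
    assert has_dynamic_shape(([3], [-1, 2]))  # nested multi-output shapes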
| @@ -345,6 +435,8 @@ def prim_attr_register(fn): | |||||
| def deco(self, *args, **kwargs): | def deco(self, *args, **kwargs): | ||||
| if isinstance(self, PrimitiveWithInfer): | if isinstance(self, PrimitiveWithInfer): | ||||
| PrimitiveWithInfer.__init__(self, self.__class__.__name__) | PrimitiveWithInfer.__init__(self, self.__class__.__name__) | ||||
| elif isinstance(self, PrimitiveWithCheck): | |||||
| PrimitiveWithCheck.__init__(self, self.__class__.__name__) | |||||
| else: | else: | ||||
| Primitive.__init__(self, self.__class__.__name__) | Primitive.__init__(self, self.__class__.__name__) | ||||
| bound_args = inspect.signature(fn).bind(self, *args, **kwargs) | bound_args = inspect.signature(fn).bind(self, *args, **kwargs) | ||||
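With the new elif branch, @prim_attr_register also initializes PrimitiveWithCheck subclasses through the correct base class. A hedged usage sketch (the Identity operator below is hypothetical, not an existing MindSpore op):

    # Hypothetical operator: the branch added above routes the base-class
    # __init__ to PrimitiveWithCheck instead of plain Primitive.
    from mindspore.ops.primitive import PrimitiveWithCheck, prim_attr_register

    class Identity(PrimitiveWithCheck):
        @prim_attr_register
        def __init__(self):
            pass

        def check_shape(self, x_shape):
            pass  # accept any shape; the C++ infer computes the real output shape

        def check_dtype(self, x_dtype):
            pass  # accept any dtype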
| @@ -27,7 +27,7 @@ from mindspore.ops import composite as C | |||||
| from mindspore.ops import functional as F | from mindspore.ops import functional as F | ||||
| from mindspore.ops import operations as P | from mindspore.ops import operations as P | ||||
| from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like | from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like | ||||
| from mindspore.ops.primitive import constexpr | |||||
| from mindspore.ops.primitive import constexpr, PrimitiveWithInfer, prim_attr_register | |||||
| from mindspore.ops._grad.grad_base import bprop_getters | from mindspore.ops._grad.grad_base import bprop_getters | ||||
| from mindspore import Tensor, RowTensor, context | from mindspore import Tensor, RowTensor, context | ||||
| from mindspore.common.parameter import Parameter, ParameterTuple | from mindspore.common.parameter import Parameter, ParameterTuple | ||||
| @@ -105,10 +105,31 @@ def _generate_inverse_index(x_shape, axis): | |||||
| perm = index[1:1 + axis] + (0,) + index[1 + axis:] | perm = index[1:1 + axis] + (0,) + index[1 + axis:] | ||||
| return perm | return perm | ||||
| class MySparseGatherV2(P.GatherV2): | |||||
| # pylint: disable=W0231 | |||||
| class MySparseGatherV2(PrimitiveWithInfer): | |||||
| """ | """ | ||||
| For test | For test | ||||
| """ | """ | ||||
| @prim_attr_register | |||||
| def __init__(self): | |||||
| """init index_select""" | |||||
| self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output']) | |||||
| def __infer__(self, params, indices, axis): | |||||
| validator.check_subclass("params", params['dtype'], mstype.tensor, self.name) | |||||
| validator.check_tensor_type_same({"indices": indices['dtype']}, mstype.int_type, self.name) | |||||
| validator.check_subclass("axis", axis['dtype'], mstype.int_, self.name) | |||||
| axis_v = axis['value'] | |||||
| params_shp = params['shape'] | |||||
| rank = len(params_shp) | |||||
| validator.check_int_range("axis", axis_v, -rank, rank, Rel.INC_LEFT, self.name) | |||||
| if axis_v < 0: | |||||
| axis_v += rank | |||||
| out_shape = params_shp[:axis_v] + indices['shape'] + params_shp[axis_v + 1:] | |||||
| out = {'shape': out_shape, | |||||
| 'dtype': params['dtype'], | |||||
| 'value': None} | |||||
| return out | |||||
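For concreteness, the out_shape rule from __infer__ replayed on sample shapes (the numbers are illustrative only):

    # Replaying the out_shape computation from __infer__ above.
    params_shp = [20, 12]  # rank 2
    indices_shape = [8]
    axis_v = 0             # already normalized to be non-negative
    out_shape = params_shp[:axis_v] + indices_shape + params_shp[axis_v + 1:]
    assert out_shape == [8, 12]  # gathered rows keep the trailing dimensions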
| @bprop_getters.register(MySparseGatherV2) | @bprop_getters.register(MySparseGatherV2) | ||||
| def get_bprop_sparse_gather_v2(self): | def get_bprop_sparse_gather_v2(self): | ||||
| @@ -0,0 +1,109 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ test dynamic shape """ | |||||
| import numpy as np | |||||
| from mindspore import Tensor, context, nn, Parameter | |||||
| from mindspore import dtype as mstype | |||||
| from mindspore.ops import operations as P | |||||
| context.set_context(mode=context.GRAPH_MODE, save_graphs=False) | |||||
| def test_sparse_apply_proximal_ada_grad(): | |||||
| class Net(nn.Cell): | |||||
| def __init__(self): | |||||
| super(Net, self).__init__() | |||||
| self.sparse_apply_proximal_adagrad = P.SparseApplyProximalAdagrad() | |||||
| self.var = Parameter(Tensor(np.random.rand(7800, 80).astype(np.float32)), name="var") | |||||
| self.accum = Parameter(Tensor(np.random.rand(7800, 80).astype(np.float32)), name="accum") | |||||
| self.lr = 0.01 | |||||
| self.l1 = 0.0 | |||||
| self.l2 = 0.0 | |||||
| def construct(self, grad, indices): | |||||
| out = self.sparse_apply_proximal_adagrad(self.var, self.accum, self.lr, self.l1, self.l2, grad, indices) | |||||
| return out[0] | |||||
| class NetWrapper(nn.Cell): | |||||
| def __init__(self): | |||||
| super(NetWrapper, self).__init__() | |||||
| self.unq = P.Unique() | |||||
| self.add = P.TensorAdd() | |||||
| self.expand_dims = P.ExpandDims() | |||||
| self.cast = P.Cast() | |||||
| self.net = Net() | |||||
| def construct(self, grad, inp): | |||||
| ids, _ = self.unq(inp) | |||||
| new_grad = self.expand_dims(ids, 1) | |||||
| new_grad = self.cast(new_grad, mstype.float32) + grad | |||||
| return self.net(new_grad, ids) | |||||
| net = NetWrapper() | |||||
| grad = Tensor(np.random.rand(1, 80).astype(np.float32)) | |||||
| indices = Tensor(np.ones([7800]), mstype.int32) | |||||
| net(grad, indices) | |||||
| def test_sparse_apply_ftrl(): | |||||
| class SparseApplyFtrlNet(nn.Cell): | |||||
| def __init__(self): | |||||
| super(SparseApplyFtrlNet, self).__init__() | |||||
| self.sparse_apply_ftrl = P.SparseApplyFtrl(lr=0.01, l1=0.0, l2=0.0, lr_power=-0.5) | |||||
| self.var = Parameter(Tensor(np.random.rand(7800, 80).astype(np.float32)), name="var") | |||||
| self.accum = Parameter(Tensor(np.random.rand(7800, 80).astype(np.float32)), name="accum") | |||||
| self.linear = Parameter(Tensor(np.random.rand(7800, 80).astype(np.float32)), name="linear") | |||||
| def construct(self, grad, indices): | |||||
| out = self.sparse_apply_ftrl(self.var, self.accum, self.linear, grad, indices) | |||||
| return out[0] | |||||
| class NetWrapper(nn.Cell): | |||||
| def __init__(self): | |||||
| super(NetWrapper, self).__init__() | |||||
| self.unq = P.Unique() | |||||
| self.add = P.TensorAdd() | |||||
| self.expand_dims = P.ExpandDims() | |||||
| self.cast = P.Cast() | |||||
| self.net = SparseApplyFtrlNet() | |||||
| def construct(self, grad, inp): | |||||
| ids, _ = self.unq(inp) | |||||
| new_grad = self.expand_dims(ids, 1) | |||||
| new_grad = self.cast(new_grad, mstype.float32) + grad | |||||
| return self.net(new_grad, ids) | |||||
| net = NetWrapper() | |||||
| grad = Tensor(np.random.rand(1, 80).astype(np.float32)) | |||||
| indices = Tensor(np.ones([7800]), mstype.int32) | |||||
| net(grad, indices) | |||||
| def test_gatherv2(): | |||||
| class Net(nn.Cell): | |||||
| def __init__(self): | |||||
| super(Net, self).__init__() | |||||
| self.unq = P.Unique() | |||||
| self.gather = P.GatherV2() | |||||
| def construct(self, x, y): | |||||
| u, _ = self.unq(y) | |||||
| z = self.gather(x, u, 0) | |||||
| return z | |||||
| x = Tensor(np.ones([20, 12], dtype=np.float32)) | |||||
| y = Tensor(np.ones([8], dtype=np.int32)) | |||||
| net = Net() | |||||
| net(x, y) | |||||
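All three tests follow the same pattern: P.Unique produces a first dimension whose size depends on the input values, so every downstream consumer must tolerate a dynamic shape. A NumPy sketch of the same dataflow, for intuition only (NumPy resolves the shape eagerly, whereas graph mode has to infer min/max bounds ahead of time):

    # NumPy sketch of the Unique -> GatherV2 dataflow in test_gatherv2: the
    # length of u depends on the values in y, not just on y's shape.
    import numpy as np

    x = np.ones([20, 12], dtype=np.float32)
    y = np.ones([8], dtype=np.int32)
    u = np.unique(y)           # len(u) == 1 here, but it is data dependent
    z = np.take(x, u, axis=0)  # gather along axis 0
    assert z.shape == (1, 12)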