Merge pull request !5373 from zhangbuxue/fix_bug_the_const_input_is_broadened_in_PyNative_modetags/v1.0.0
| @@ -147,7 +147,7 @@ def resolve_symbol(namespace, symbol): | |||||
| resolve_ = namespace[symbol] | resolve_ = namespace[symbol] | ||||
| # list and dict is not hashable ,it can not be key for the map, just return the result | # list and dict is not hashable ,it can not be key for the map, just return the result | ||||
| if isinstance(resolve_, (list, dict)): | |||||
| if isinstance(resolve_, (tuple, list, dict)): | |||||
| return resolve_ | return resolve_ | ||||
| # dataclass may not be hashable | # dataclass may not be hashable | ||||
| @@ -642,6 +642,9 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::v | |||||
| inputs.push_back(NewValueNode(prim)); | inputs.push_back(NewValueNode(prim)); | ||||
| size_t size = op_exec_info->op_inputs.size(); | size_t size = op_exec_info->op_inputs.size(); | ||||
| auto const_input_index = prim->get_const_input_indexes(); | |||||
| bool have_const_input = !const_input_index.empty(); | |||||
| bool is_const_prim = prim->is_const_prim(); | |||||
| for (size_t i = 0; i < size; i++) { | for (size_t i = 0; i < size; i++) { | ||||
| auto obj = op_exec_info->op_inputs[i]; | auto obj = op_exec_info->op_inputs[i]; | ||||
| bool op_mask = false; | bool op_mask = false; | ||||
| @@ -669,12 +672,13 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::v | |||||
| abs = node->abstract(); | abs = node->abstract(); | ||||
| } | } | ||||
| MS_LOG(DEBUG) << prim->ToString() << " abs is nullptr " << (abs == nullptr) << " is_const_value " | MS_LOG(DEBUG) << prim->ToString() << " abs is nullptr " << (abs == nullptr) << " is_const_value " | ||||
| << prim->is_const_value(); | |||||
| if (abs == nullptr || prim->is_const_value()) { | |||||
| << prim->is_const_prim(); | |||||
| bool is_const_input = have_const_input && std::count(const_input_index.begin(), const_input_index.end(), i); | |||||
| if (abs == nullptr || is_const_prim || is_const_input) { | |||||
| MS_LOG(DEBUG) << "MakeCnode get node no in map" << id; | MS_LOG(DEBUG) << "MakeCnode get node no in map" << id; | ||||
| ValuePtr input_value = PyAttrValue(obj); | ValuePtr input_value = PyAttrValue(obj); | ||||
| abs = input_value->ToAbstract(); | abs = input_value->ToAbstract(); | ||||
| if (!prim->is_const_value()) { | |||||
| if (!is_const_prim && !is_const_input) { | |||||
| auto config = abstract::AbstractBase::kBroadenTensorOnly; | auto config = abstract::AbstractBase::kBroadenTensorOnly; | ||||
| abs = abs->Broaden(config); | abs = abs->Broaden(config); | ||||
| MS_LOG(DEBUG) << "broaden for " << prim->ToString() << " " << config; | MS_LOG(DEBUG) << "broaden for " << prim->ToString() << " " << config; | ||||
| @@ -885,7 +889,7 @@ py::tuple PynativeExecutor::RunOpInner(const py::args &args) { | |||||
| value_ret[0] = output["value"]; | value_ret[0] = output["value"]; | ||||
| return value_ret; | return value_ret; | ||||
| } | } | ||||
| if (op_exec_info->py_primitive->is_const_value()) { | |||||
| if (op_exec_info->py_primitive->is_const_prim()) { | |||||
| py::tuple value_ret(1); | py::tuple value_ret(1); | ||||
| value_ret[0] = ""; | value_ret[0] = ""; | ||||
| return value_ret; | return value_ret; | ||||
| @@ -1044,7 +1048,7 @@ AnfNodePtr PynativeExecutor::GetInput(const py::object &obj, bool op_mask) { | |||||
| auto tuple = obj.cast<py::tuple>(); | auto tuple = obj.cast<py::tuple>(); | ||||
| // cell((1,2)): support not mix (scalar, tensor) | // cell((1,2)): support not mix (scalar, tensor) | ||||
| if (tuple.size() > 0 && !py::isinstance<tensor::Tensor>(tuple[0])) { | |||||
| if (!tuple.empty() && !py::isinstance<tensor::Tensor>(tuple[0])) { | |||||
| return MakeValueNode(obj, obj_id); | return MakeValueNode(obj, obj_id); | ||||
| } | } | ||||
| @@ -98,22 +98,22 @@ py::tuple check_bprop_out(const py::object &grads_obj, const py::tuple &py_args) | |||||
| << ", and the value is " << py::cast<py::str>(grads[i]) << "."; | << ", and the value is " << py::cast<py::str>(grads[i]) << "."; | ||||
| } | } | ||||
| py::tuple grad_shape = grads[i].attr("shape"); | |||||
| py::object arg_dtype = py_args[i].attr("dtype"); | |||||
| py::object grad_dtype = grads[i].attr("dtype"); | py::object grad_dtype = grads[i].attr("dtype"); | ||||
| py::tuple arg_shape = py_args[i].attr("shape"); | py::tuple arg_shape = py_args[i].attr("shape"); | ||||
| py::object arg_dtype = py_args[i].attr("dtype"); | |||||
| py::tuple grad_shape = grads[i].attr("shape"); | |||||
| if (!grad_dtype.equal(arg_dtype)) { | |||||
| MS_EXCEPTION(TypeError) << "When user defines the net bprop, the gradient of the " << i | |||||
| << "th arg should have the same dtype as the " << i << "th arg, but the " << i | |||||
| << "th arg dtype is: " << py::cast<py::str>(arg_dtype) | |||||
| << ", the gradient dtype is: " << py::cast<py::str>(grad_dtype) << "."; | |||||
| } | |||||
| if (!grad_shape.equal(arg_shape)) { | if (!grad_shape.equal(arg_shape)) { | ||||
| MS_EXCEPTION(ValueError) << "When user defines the net bprop, the gradient of the " << i | MS_EXCEPTION(ValueError) << "When user defines the net bprop, the gradient of the " << i | ||||
| << "th arg should have the same shape as the " << i << "th arg, but the " << i | << "th arg should have the same shape as the " << i << "th arg, but the " << i | ||||
| << "th arg shape is: " << py::cast<py::str>(arg_shape) | << "th arg shape is: " << py::cast<py::str>(arg_shape) | ||||
| << ", the gradient shape is: " << py::cast<py::str>(grad_shape) << "."; | << ", the gradient shape is: " << py::cast<py::str>(grad_shape) << "."; | ||||
| } | } | ||||
| if (!grad_dtype.is(arg_dtype)) { | |||||
| MS_EXCEPTION(TypeError) << "When user defines the net bprop, the gradient of the " << i | |||||
| << "th arg should have the same dtype as the " << i << "th arg, but the " << i | |||||
| << "th arg dtype is: " << py::cast<py::str>(arg_dtype) | |||||
| << ", the gradient dtype is: " << py::cast<py::str>(grad_dtype) << "."; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -239,10 +239,7 @@ py::object PrimitivePy::RunPyComputeFunction(const py::tuple &py_args) const { | |||||
| bool PrimitivePy::HasComputeFunction() const { | bool PrimitivePy::HasComputeFunction() const { | ||||
| auto func = GetComputeFunction(); | auto func = GetComputeFunction(); | ||||
| if (py::isinstance<py::none>(func)) { | |||||
| return false; | |||||
| } | |||||
| return true; | |||||
| return !py::isinstance<py::none>(func); | |||||
| } | } | ||||
| PrimitivePtr PrimitivePy::Clone() { | PrimitivePtr PrimitivePy::Clone() { | ||||
| @@ -272,7 +269,9 @@ REGISTER_PYBIND_DEFINE(Primitive_, ([](const py::module *m) { | |||||
| .def("add_attr", &PrimitivePy::AddPyAttr, "add primitive attr") | .def("add_attr", &PrimitivePy::AddPyAttr, "add primitive attr") | ||||
| .def("get_attr_dict", &PrimitivePy::GetAttrDict, "get primitive attr") | .def("get_attr_dict", &PrimitivePy::GetAttrDict, "get primitive attr") | ||||
| .def("set_prim_type", &PrimitivePy::set_prim_type, "Set primitive type.") | .def("set_prim_type", &PrimitivePy::set_prim_type, "Set primitive type.") | ||||
| .def("set_is_const_value", &PrimitivePy::set_is_const_value, "Set primitive is const value.") | |||||
| .def("set_const_prim", &PrimitivePy::set_const_prim, "Set primitive is const.") | |||||
| .def("set_const_input_indexes", &PrimitivePy::set_const_input_indexes, | |||||
| "Set primitive const input indexes.") | |||||
| .def("set_signatures", &PrimitivePy::set_signatures, "Set primitive inputs signature.") | .def("set_signatures", &PrimitivePy::set_signatures, "Set primitive inputs signature.") | ||||
| .def("register_hook", &PrimitivePy::set_hook, "Set primitive hook function.") | .def("register_hook", &PrimitivePy::set_hook, "Set primitive hook function.") | ||||
| .def("set_instance_name", &PrimitivePy::set_instance_name, "Set primitive instance name."); | .def("set_instance_name", &PrimitivePy::set_instance_name, "Set primitive instance name."); | ||||
| @@ -32,7 +32,7 @@ Primitive::Primitive(const std::string &name, const bool is_base, const PrimType | |||||
| has_signature_(false), | has_signature_(false), | ||||
| prim_type_(prim_type), | prim_type_(prim_type), | ||||
| record_evaluate_add_attr_(false), | record_evaluate_add_attr_(false), | ||||
| is_const_value_(false), | |||||
| is_const_prim_(false), | |||||
| id_(MakeId()) {} | id_(MakeId()) {} | ||||
| Primitive::Primitive(const Primitive &prim) | Primitive::Primitive(const Primitive &prim) | ||||
| @@ -43,7 +43,7 @@ Primitive::Primitive(const Primitive &prim) | |||||
| has_signature_(prim.has_signature_), | has_signature_(prim.has_signature_), | ||||
| prim_type_(prim.prim_type_), | prim_type_(prim.prim_type_), | ||||
| record_evaluate_add_attr_(false), | record_evaluate_add_attr_(false), | ||||
| is_const_value_(false), | |||||
| is_const_prim_(false), | |||||
| id_(prim.id_) {} | id_(prim.id_) {} | ||||
| abstract::AbstractBasePtr Primitive::ToAbstract() { | abstract::AbstractBasePtr Primitive::ToAbstract() { | ||||
| @@ -109,8 +109,12 @@ class Primitive : public Named { | |||||
| bool is_base() const { return is_base_; } | bool is_base() const { return is_base_; } | ||||
| virtual BaseRef RunHookFunction(const VectorRef &args) const { MS_LOG(EXCEPTION) << "call a empty function!"; } | virtual BaseRef RunHookFunction(const VectorRef &args) const { MS_LOG(EXCEPTION) << "call a empty function!"; } | ||||
| virtual void CopyHookFunction(const PrimitivePtr &primitive) { MS_LOG(EXCEPTION) << "call a empty function!"; } | virtual void CopyHookFunction(const PrimitivePtr &primitive) { MS_LOG(EXCEPTION) << "call a empty function!"; } | ||||
| void set_is_const_value(bool value) { is_const_value_ = value; } | |||||
| bool is_const_value() const { return is_const_value_; } | |||||
| void set_const_prim(bool is_const_prim) { is_const_prim_ = is_const_prim; } | |||||
| bool is_const_prim() const { return is_const_prim_; } | |||||
| void set_const_input_indexes(const std::vector<size_t> &const_input_indexes) { | |||||
| const_input_indexes_ = const_input_indexes; | |||||
| } | |||||
| std::vector<size_t> &get_const_input_indexes() { return const_input_indexes_; } | |||||
| std::string id() const { return id_; } | std::string id() const { return id_; } | ||||
| protected: | protected: | ||||
| @@ -123,7 +127,8 @@ class Primitive : public Named { | |||||
| bool has_signature_; | bool has_signature_; | ||||
| PrimType prim_type_; | PrimType prim_type_; | ||||
| bool record_evaluate_add_attr_; | bool record_evaluate_add_attr_; | ||||
| bool is_const_value_; | |||||
| bool is_const_prim_; | |||||
| std::vector<size_t> const_input_indexes_; | |||||
| std::string id_{""}; | std::string id_{""}; | ||||
| }; | }; | ||||
| @@ -28,7 +28,7 @@ hastype = Primitive('hastype') | |||||
| cast = P.Cast() | cast = P.Cast() | ||||
| dtype = P.DType() | dtype = P.DType() | ||||
| isconstant = Primitive('is_constant') | isconstant = Primitive('is_constant') | ||||
| isconstant.set_is_const_value(True) | |||||
| isconstant.set_const_prim(True) | |||||
| issubclass_ = P.IsSubClass() | issubclass_ = P.IsSubClass() | ||||
| isinstance_ = P.IsInstance() | isinstance_ = P.IsInstance() | ||||
| @@ -1089,7 +1089,7 @@ class InvertPermutation(PrimitiveWithInfer): | |||||
| @prim_attr_register | @prim_attr_register | ||||
| def __init__(self): | def __init__(self): | ||||
| """init InvertPermutation""" | """init InvertPermutation""" | ||||
| self.set_is_const_value(True) | |||||
| self.set_const_prim(True) | |||||
| def __infer__(self, x): | def __infer__(self, x): | ||||
| x_shp = x['shape'] | x_shp = x['shape'] | ||||
| @@ -2873,6 +2873,7 @@ class MirrorPad(PrimitiveWithInfer): | |||||
| """Init Pad""" | """Init Pad""" | ||||
| validator.check_string('mode', mode, ['REFLECT', 'SYMMETRIC'], self.name) | validator.check_string('mode', mode, ['REFLECT', 'SYMMETRIC'], self.name) | ||||
| self.mode = mode | self.mode = mode | ||||
| self.set_const_input_indexes([1]) | |||||
| def __infer__(self, input_x, paddings): | def __infer__(self, input_x, paddings): | ||||
| validator.check_subclass("input_x", input_x['dtype'], mstype.tensor, self.name) | validator.check_subclass("input_x", input_x['dtype'], mstype.tensor, self.name) | ||||
| @@ -390,7 +390,7 @@ def constexpr(fn=None, get_instance=True, name=None): | |||||
| def __init__(self): | def __init__(self): | ||||
| op_name = name if name else fn.__name__ | op_name = name if name else fn.__name__ | ||||
| PrimitiveWithInfer.__init__(self, op_name) | PrimitiveWithInfer.__init__(self, op_name) | ||||
| self.set_is_const_value(True) | |||||
| self.set_const_prim(True) | |||||
| def infer_value(self, *args): | def infer_value(self, *args): | ||||
| return fn(*args) | return fn(*args) | ||||