Merge pull request !3271 from vlne-v1/ref_demotags/v1.0.0
| @@ -185,14 +185,23 @@ class Validator: | |||
| raise TypeError(f"{msg_prefix} `{arg_name}` must be float.") | |||
| @staticmethod | |||
| def check_subclass(arg_name, type_, template_type, prim_name): | |||
| def check_subclass(arg_name, type_, template_types, prim_name): | |||
| """Checks whether some type is subclass of another type""" | |||
| if not isinstance(template_type, Iterable): | |||
| template_type = (template_type,) | |||
| if not any([mstype.issubclass_(type_, x) for x in template_type]): | |||
| if not isinstance(template_types, Iterable): | |||
| template_types = (template_types,) | |||
| hit = False | |||
| for template_type in template_types: | |||
| if isinstance(template_type, mstype.Type): | |||
| if mstype.issubclass_(type_, template_type): | |||
| hit = True | |||
| break | |||
| elif type_ is template_type: | |||
| hit = True | |||
| break | |||
| if not hit: | |||
| type_str = (type(type_).__name__ if isinstance(type_, (tuple, list)) else "") + str(type_) | |||
| raise TypeError(f'For \'{prim_name}\' the type of `{arg_name}` should be subclass' | |||
| f' of {",".join((str(x) for x in template_type))}, but got {type_str}.') | |||
| f' of {",".join((str(x) for x in template_types))}, but got {type_str}.') | |||
| @staticmethod | |||
| def check_const_input(arg_name, arg_value, prim_name): | |||
| @@ -206,13 +215,7 @@ class Validator: | |||
| def _check_tensor_type(arg): | |||
| arg_key, arg_val = arg | |||
| elem_type = arg_val | |||
| if not elem_type in valid_values: | |||
| type_names = [] | |||
| for t in valid_values: | |||
| type_names.append(str(t)) | |||
| types_info = '[' + ', '.join(type_names) + ']' | |||
| raise TypeError(f'For \'{prim_name}\' type of `{arg_key}` should be in {types_info},' | |||
| f' but got {elem_type}.') | |||
| Validator.check_subclass(arg_key, elem_type, valid_values, prim_name) | |||
| return (arg_key, elem_type) | |||
| def _check_types_same(arg1, arg2): | |||
| @@ -335,12 +338,6 @@ class Validator: | |||
| class ParamValidator: | |||
| """Parameter validator. NOTICE: this class will be replaced by `class Validator`""" | |||
| @staticmethod | |||
| def equal(arg_name, arg_value, cond_str, cond): | |||
| """Judging valid value.""" | |||
| if not cond: | |||
| raise ValueError(f'The `{arg_name}` must be {cond_str}, but got {arg_value}.') | |||
| @staticmethod | |||
| def check(arg_name, arg_value, value_name, value, rel=Rel.EQ): | |||
| """This method is only used for check int values, since when compare float values, | |||
| @@ -360,27 +357,6 @@ class ParamValidator: | |||
| raise ValueError(f'The `{arg_name}` should be an int and must {rel_str}, but got {arg_value}.') | |||
| return arg_value | |||
| @staticmethod | |||
| def check_shape_length(arg_name, arg_value, value, rel): | |||
| """Shape length judgment.""" | |||
| rel_fn = Rel.get_fns(rel) | |||
| type_mismatch = not isinstance(arg_value, int) | |||
| if type_mismatch or not rel_fn(arg_value, value): | |||
| rel_str = Rel.get_strs(rel).format(value) | |||
| raise ValueError(f'The length of `{arg_name}` should be an int and must {rel_str}, but got {arg_value}') | |||
| return arg_value | |||
| @staticmethod | |||
| def check_int_range(arg_name, arg_value, lower_limit, upper_limit, rel): | |||
| """This method is only used for check int values, | |||
| since when compare float values, we need consider float error.""" | |||
| rel_fn = Rel.get_fns(rel) | |||
| type_mismatch = not isinstance(arg_value, int) | |||
| if type_mismatch or not rel_fn(arg_value, lower_limit, upper_limit): | |||
| rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit) | |||
| raise ValueError(f'The `{arg_name}` should be an int in range {rel_str}, but got {arg_value}.') | |||
| return arg_value | |||
| @staticmethod | |||
| def check_isinstance(arg_name, arg_value, classes): | |||
| """Check arg isinstance of classes""" | |||
| @@ -388,33 +364,6 @@ class ParamValidator: | |||
| raise ValueError(f'The `{arg_name}` should be isinstance of {classes}, but got {arg_value}.') | |||
| return arg_value | |||
| @staticmethod | |||
| def check_number_range(arg_name, arg_value, lower_limit, upper_limit, rel): | |||
| """Is it necessary to consider error when comparing float values.""" | |||
| rel_fn = Rel.get_fns(rel) | |||
| if not rel_fn(arg_value, lower_limit, upper_limit): | |||
| rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit) | |||
| raise ValueError(f'The `{arg_name}` should be in range {rel_str}, but got {arg_value}.') | |||
| return arg_value | |||
| @staticmethod | |||
| def check_subclass(arg_name, type_, template_type, with_type_of=True): | |||
| """Check whether some type is subclass of another type""" | |||
| if not isinstance(template_type, Iterable): | |||
| template_type = (template_type,) | |||
| if not any([mstype.issubclass_(type_, x) for x in template_type]): | |||
| type_str = (type(type_).__name__ if isinstance(type_, (tuple, list)) else "") + str(type_) | |||
| raise TypeError(f'The {"type of" if with_type_of else ""} `{arg_name}` should be subclass' | |||
| f' of {",".join((str(x) for x in template_type))}, but got {type_str}.') | |||
| @staticmethod | |||
| def check_args_tensor(args): | |||
| """Check whether args are all tensor.""" | |||
| if not isinstance(args, dict): | |||
| raise TypeError("The args should be a dict.") | |||
| for arg, value in args.items(): | |||
| ParamValidator.check_subclass(arg, value, mstype.tensor) | |||
| @staticmethod | |||
| def check_bool(arg_name, arg_value): | |||
| """Check arg isinstance of bool""" | |||
| @@ -442,113 +391,6 @@ class ParamValidator: | |||
| return arg_value | |||
| raise_error_msg() | |||
| @staticmethod | |||
| def check_typename(arg_name, arg_type, valid_types): | |||
| """Does it contain the _name_ attribute.""" | |||
| def get_typename(t): | |||
| return t.__name__ if hasattr(t, '__name__') else str(t) | |||
| if isinstance(arg_type, type(mstype.tensor)): | |||
| arg_type = arg_type.element_type() | |||
| if arg_type in valid_types: | |||
| return arg_type | |||
| type_names = [get_typename(t) for t in valid_types] | |||
| if len(valid_types) == 1: | |||
| raise ValueError(f'The type of `{arg_name}` should be {type_names[0]},' | |||
| f' but got {get_typename(arg_type)}.') | |||
| raise ValueError(f'The type of `{arg_name}` should be one of {type_names},' | |||
| f' but got {get_typename(arg_type)}.') | |||
| @staticmethod | |||
| def check_string(arg_name, arg_value, valid_values): | |||
| """String type judgment.""" | |||
| if isinstance(arg_value, str) and arg_value in valid_values: | |||
| return arg_value | |||
| if len(valid_values) == 1: | |||
| raise ValueError(f'The `{arg_name}` should be str and must be {valid_values[0]},' | |||
| f' but got {arg_value}.') | |||
| raise ValueError(f'The `{arg_name}` should be str and must be one of {valid_values},' | |||
| f' but got {arg_value}.') | |||
| @staticmethod | |||
| def check_type_same(args, valid_values): | |||
| """Determine whether the types are the same.""" | |||
| name = list(args.keys())[0] | |||
| value = list(args.values())[0] | |||
| if isinstance(value, type(mstype.tensor)): | |||
| value = value.element_type() | |||
| for arg_name, arg_value in args.items(): | |||
| if isinstance(arg_value, type(mstype.tensor)): | |||
| arg_value = arg_value.element_type() | |||
| if arg_value not in valid_values: | |||
| raise TypeError(f'The `{arg_name}` should be in {valid_values},' | |||
| f' but `{arg_name}` is {arg_value}.') | |||
| if arg_value != value: | |||
| raise TypeError(f'`{arg_name}` should be same as `{name}`,' | |||
| f' but `{arg_name}` is {arg_value}, `{name}` is {value}.') | |||
| @staticmethod | |||
| def check_two_types_same(arg1_name, arg1_type, arg2_name, arg2_type): | |||
| """Determine whether the types of two variables are the same.""" | |||
| if arg1_type != arg2_type: | |||
| raise TypeError(f'The type of `{arg1_name}` and `{arg2_name}` should be same.') | |||
| @staticmethod | |||
| def check_value_on_integer(arg_name, arg_value, value, rel): | |||
| """Judging integer type.""" | |||
| rel_fn = Rel.get_fns(rel) | |||
| type_match = isinstance(arg_value, int) | |||
| if type_match and (not rel_fn(arg_value, value)): | |||
| rel_str = Rel.get_strs(rel).format(value) | |||
| raise ValueError(f'The `{arg_name}` should be an int and must {rel_str}, but got {arg_value}.') | |||
| return arg_value | |||
| @staticmethod | |||
| def check_param_equal(param1_name, param1_value, param2_name, param2_value): | |||
| """Judging the equality of parameters.""" | |||
| if param1_value != param2_value: | |||
| raise ValueError(f"`{param1_name}` must equal `{param2_name}`," | |||
| f" but got `{param1_name}` = {param1_value}," | |||
| f" `{param2_name}` = {param2_value}.") | |||
| @staticmethod | |||
| def check_const_input(arg_name, arg_value): | |||
| """Check valid value.""" | |||
| if arg_value is None: | |||
| raise ValueError(f'The `{arg_name}` must be a const input, but got {arg_value}.') | |||
| @staticmethod | |||
| def check_float_positive(arg_name, arg_value): | |||
| """Float type judgment.""" | |||
| if isinstance(arg_value, float): | |||
| if arg_value > 0: | |||
| return arg_value | |||
| raise ValueError(f"The `{arg_name}` must be positive, but got {arg_value}.") | |||
| raise TypeError(f"`{arg_name}` must be float!") | |||
| @staticmethod | |||
| def check_pad_value_by_mode(op_name, pad_mode, padding): | |||
| """Validate value of padding according to pad_mode""" | |||
| if pad_mode != 'pad' and padding != 0: | |||
| raise ValueError(f"For op '{op_name}', padding must be zero when pad_mode is '{pad_mode}'.") | |||
| return padding | |||
| @staticmethod | |||
| def check_empty_shape_input(arg_name, arg_value): | |||
| """Check zeros value.""" | |||
| if 0 in arg_value: | |||
| raise ValueError(f"Input `{arg_name}` cannot be empty.") | |||
| @staticmethod | |||
| def check_scalar_shape_input(arg_name, arg_value): | |||
| """Check scalar shape input.""" | |||
| if arg_value != []: | |||
| raise ValueError(f"Input `{arg_name}` shape should be (). got {arg_value}") | |||
| def check_int(input_param): | |||
| """Int type judgment.""" | |||
| @@ -592,11 +592,6 @@ TypeId AnfRuntimeAlgorithm::GetOutputInferDataType(const AnfNodePtr &node, size_ | |||
| return get_single_type((*tuple_ptr)[output_idx]); | |||
| }; | |||
| TypePtr type_ptr = node->Type(); | |||
| if (type_ptr->isa<RefType>()) { | |||
| auto ref_type_ptr = type_ptr->cast<RefTypePtr>(); | |||
| MS_EXCEPTION_IF_NULL(ref_type_ptr); | |||
| return get_tuple_type(ref_type_ptr->subtype(), output_idx); | |||
| } | |||
| return get_tuple_type(type_ptr, output_idx); | |||
| } | |||
| @@ -20,6 +20,7 @@ | |||
| #include "abstract/abstract_value.h" | |||
| #include "ir/anf.h" | |||
| #include "ir/dtype.h" | |||
| #include "abstract/dshape.h" | |||
| #include "abstract/param_validator.h" | |||
| #include "frontend/operator/cc_implementations.h" | |||
| @@ -43,15 +44,15 @@ const std::vector<Signature> &GetSignature(const ValuePtr &function) { | |||
| return empty; | |||
| } | |||
| void ProcessDefault(const std::string &func_name, const AbstractBasePtrList &args_spec_list, | |||
| const std::vector<Signature> &signature, bool has_var, std::vector<AnfNodePtr> *const op_inputs) { | |||
| void ProcessDefault(const std::string &func_name, size_t actual_param_number, const std::vector<Signature> &signature, | |||
| bool has_var, std::vector<AnfNodePtr> *const op_inputs) { | |||
| std::size_t sig_size = signature.size(); | |||
| auto positional_size = sig_size; | |||
| if (has_var) { | |||
| positional_size = sig_size - 1; | |||
| } | |||
| if (args_spec_list.size() < positional_size) { | |||
| for (size_t i = args_spec_list.size(); i < sig_size; ++i) { | |||
| if (actual_param_number < positional_size) { | |||
| for (size_t i = actual_param_number; i < sig_size; ++i) { | |||
| auto default_value = signature[i].default_value; | |||
| if (default_value == nullptr) { | |||
| MS_LOG(EXCEPTION) << "Function " << func_name << "'s input length is not equal to Signature length."; | |||
| @@ -67,23 +68,11 @@ void SetMaxType(TypeId *max_type_id, size_t *max_type_number, const TypeId type_ | |||
| *max_type_number = type_number; | |||
| } | |||
| bool GetTensorOrScalarTypeInfo(AbstractBasePtr arg_value, bool is_write, TypeId *arg_type_id, | |||
| bool GetTensorOrScalarTypeInfo(TypePtr arg_type_origin, bool is_write, TypeId *arg_type_id, | |||
| TypeId *arg_type = nullptr) { | |||
| if (arg_value->isa<abstract::AbstractRef>()) { | |||
| auto ref = arg_value->cast<abstract::AbstractRefPtr>(); | |||
| arg_value = ref->ref(); | |||
| if (!is_write && ref->need_cast()) { | |||
| auto tensor_type = ref->target_type(); | |||
| *arg_type_id = tensor_type->type_id(); | |||
| if (arg_type != nullptr) { | |||
| *arg_type = kObjectTypeTensorType; | |||
| } | |||
| return true; | |||
| } | |||
| } | |||
| if (arg_value->isa<abstract::AbstractTensor>()) { | |||
| auto tensor = arg_value->cast<abstract::AbstractTensorPtr>(); | |||
| auto tensor_type = tensor->element()->BuildType(); | |||
| if (arg_type_origin->isa<TensorType>()) { | |||
| auto tensor = arg_type_origin->cast<TensorTypePtr>(); | |||
| auto tensor_type = tensor->element(); | |||
| MS_EXCEPTION_IF_NULL(tensor_type); | |||
| *arg_type_id = tensor_type->type_id(); | |||
| if (arg_type != nullptr) { | |||
| @@ -91,9 +80,8 @@ bool GetTensorOrScalarTypeInfo(AbstractBasePtr arg_value, bool is_write, TypeId | |||
| } | |||
| return true; | |||
| } | |||
| if (arg_value->isa<abstract::AbstractScalar>()) { | |||
| auto scalar = arg_value->cast<abstract::AbstractScalarPtr>(); | |||
| auto scalar_type = scalar->BuildType(); | |||
| if (arg_type_origin->isa<Number>()) { | |||
| auto scalar_type = arg_type_origin->cast<NumberPtr>(); | |||
| MS_EXCEPTION_IF_NULL(scalar_type); | |||
| *arg_type_id = scalar_type->type_id(); | |||
| if (arg_type != nullptr) { | |||
| @@ -104,7 +92,7 @@ bool GetTensorOrScalarTypeInfo(AbstractBasePtr arg_value, bool is_write, TypeId | |||
| return false; | |||
| } | |||
| TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::vector<size_t> indices, | |||
| TypeId GetMaxTypeId(const std::vector<TypePtr> &input_types, std::vector<size_t> indices, | |||
| const std::set<size_t> &write_indices) { | |||
| TypeId max_type_id = kTypeUnknown; | |||
| size_t max_type_number = 0; | |||
| @@ -115,7 +103,7 @@ TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::ve | |||
| TypeId arg_type_id = kTypeUnknown; | |||
| TypeId arg_type = kTypeUnknown; | |||
| auto is_write = (write_indices.find(index) != write_indices.end()); | |||
| if (!GetTensorOrScalarTypeInfo(args_spec_list[index], is_write, &arg_type_id, &arg_type)) { | |||
| if (!GetTensorOrScalarTypeInfo(input_types[index], is_write, &arg_type_id, &arg_type)) { | |||
| continue; | |||
| } | |||
| if (arg_type != kObjectTypeTensorType) { | |||
| @@ -161,8 +149,8 @@ TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::ve | |||
| // Get the largest type of index in the same SignatureEnumDType of arguments. | |||
| using MaxTypeMap = std::map<SignatureEnumDType, TypeId>; | |||
| MaxTypeMap GetMaxDtype(const std::vector<SignatureEnumDType> &dtypes, | |||
| const abstract::AbstractBasePtrList &args_spec_list, const std::set<size_t> &write_indices) { | |||
| MaxTypeMap GetMaxDtype(const std::vector<SignatureEnumDType> &dtypes, const std::vector<TypePtr> &input_types, | |||
| const std::set<size_t> &write_indices) { | |||
| // record index for signature.dtypes of the same type | |||
| // eg. [T, T1, T, T2, T, T1, T3] -> {{T:(0,2,4)}, {T1:(1,5)}, {T2:(3)}, {T3:(6)}} | |||
| std::map<SignatureEnumDType, std::vector<size_t>> type_indices; | |||
| @@ -184,11 +172,8 @@ MaxTypeMap GetMaxDtype(const std::vector<SignatureEnumDType> &dtypes, | |||
| } | |||
| bool has_tensor = false; | |||
| for (const auto &index : indices) { | |||
| AbstractBasePtr arg_value = args_spec_list[index]; | |||
| if (arg_value->isa<abstract::AbstractRef>()) { | |||
| arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref(); | |||
| } | |||
| if (arg_value->isa<abstract::AbstractTensor>()) { | |||
| auto arg_value = input_types[index]; | |||
| if (arg_value->isa<TensorType>()) { | |||
| has_tensor = true; | |||
| break; | |||
| } | |||
| @@ -197,7 +182,7 @@ MaxTypeMap GetMaxDtype(const std::vector<SignatureEnumDType> &dtypes, | |||
| (void)dst_type.insert(std::make_pair(type, kTypeUnknown)); | |||
| continue; | |||
| } | |||
| (void)dst_type.insert(std::make_pair(type, GetMaxTypeId(args_spec_list, indices, write_indices))); | |||
| (void)dst_type.insert(std::make_pair(type, GetMaxTypeId(input_types, indices, write_indices))); | |||
| } | |||
| return dst_type; | |||
| } | |||
| @@ -211,7 +196,7 @@ AnfNodePtr DoCast(const AnfNodePtr ¶m, const TypeId &type_id, const FuncGrap | |||
| } | |||
| void DoAutoCast(const std::string &func_name, const std::vector<Signature> &signature, | |||
| const abstract::AbstractBasePtrList &args_spec_list, const FuncGraphPtr &graph, | |||
| const std::vector<TypePtr> &input_types, const FuncGraphPtr &graph, | |||
| std::vector<AnfNodePtr> *const op_inputs, const std::set<size_t> &write_indices) { | |||
| std::vector<SignatureEnumDType> dtypes; | |||
| (void)std::transform(signature.begin(), signature.end(), std::back_inserter(dtypes), | |||
| @@ -221,9 +206,9 @@ void DoAutoCast(const std::string &func_name, const std::vector<Signature> &sign | |||
| return; | |||
| } | |||
| // Stat the index of the arguments with the largest type in the same SignatureEnumDType. | |||
| std::map<SignatureEnumDType, TypeId> dst_type = GetMaxDtype(dtypes, args_spec_list, write_indices); | |||
| std::map<SignatureEnumDType, TypeId> dst_type = GetMaxDtype(dtypes, input_types, write_indices); | |||
| // Identify which arg requires auto cast | |||
| for (size_t i = 0; i < args_spec_list.size(); ++i) { | |||
| for (size_t i = 0; i < input_types.size(); ++i) { | |||
| auto it = dst_type.find(dtypes[i]); | |||
| if (it == dst_type.end() || it->second == kTypeUnknown) { | |||
| continue; | |||
| @@ -232,7 +217,7 @@ void DoAutoCast(const std::string &func_name, const std::vector<Signature> &sign | |||
| auto is_write = (rw_it != write_indices.end()); | |||
| TypeId arg_type_id = kTypeUnknown; | |||
| AbstractBasePtr arg_value = args_spec_list[i]; | |||
| auto arg_value = input_types[i]; | |||
| (void)GetTensorOrScalarTypeInfo(arg_value, is_write, &arg_type_id); | |||
| auto it_map = type_name_map.find(arg_type_id); | |||
| if (it_map == type_name_map.end()) { | |||
| @@ -248,7 +233,7 @@ void DoAutoCast(const std::string &func_name, const std::vector<Signature> &sign | |||
| } | |||
| continue; | |||
| } | |||
| if (arg_value->isa<abstract::AbstractTensor>() && arg_type_id == it->second) { | |||
| if ((arg_value->isa<TensorType>()) && arg_type_id == it->second) { | |||
| continue; | |||
| } | |||
| MS_LOG(DEBUG) << "do cast for inputs " << i << " " << (*op_inputs)[i + 1]->ToString() << " " << arg_type_id | |||
| @@ -275,6 +260,7 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func | |||
| } | |||
| std::vector<AnfNodePtr> op_inputs; | |||
| std::set<size_t> write_indices; | |||
| std::vector<TypePtr> input_types; | |||
| op_inputs.push_back(NewValueNode(function)); | |||
| // Assume, the write input of op is always the first input. We check if any write op, | |||
| // and add cast op on other inputs to keep the same type with assigned parameter. | |||
| @@ -292,30 +278,36 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func | |||
| sig = signature[sig_size - 1].rw; | |||
| } | |||
| TypePtr type = args_spec_list[i]->GetTypeTrack(); | |||
| if (type && type->type_id() == kObjectTypeRef) { | |||
| auto ref_abs = args_spec_list[i]->cast<abstract::AbstractRefPtr>(); | |||
| TypePtr type = args_spec_list[i]->BuildType(); | |||
| if (type && type->isa<RefType>()) { | |||
| auto cast_type = parse::GetMixedPrecisionTargetType(func_graph); | |||
| if (sig == SignatureEnumRW::kRWRead) { | |||
| param = NewCNode({NewValueNode(prim::kPrimGetRefValue), param}, func_graph); | |||
| if (ref_abs && ref_abs->need_cast()) { | |||
| auto cast = prim::GetPythonOps("cast", "mindspore.ops.functional"); | |||
| param = NewCNode({NewValueNode(cast), param, NewValueNode(ref_abs->target_type())}, func_graph); | |||
| auto source_tensor_type = type->cast<TensorTypePtr>(); | |||
| if (source_tensor_type != nullptr) { | |||
| auto source_element = source_tensor_type->element(); | |||
| if (cast_type != nullptr && IsSubType(source_element, kFloat) && *source_element != *cast_type) { | |||
| auto cast = prim::GetPythonOps("cast", "mindspore.ops.functional"); | |||
| param = NewCNode({NewValueNode(cast), param, NewValueNode(cast_type)}, func_graph); | |||
| type = cast_type->type_id() == kNumberTypeFloat16 ? kTensorTypeFP16 : kTensorTypeFP32; | |||
| } | |||
| } | |||
| } else if (sig == SignatureEnumRW::kRWWrite) { | |||
| param = NewCNode({NewValueNode(prim::kPrimGetRefValue), param}, func_graph); | |||
| write_indices.insert(i); | |||
| } | |||
| // If sig is SignatureEnumRW::kRWRef, not do anything. | |||
| } else if (sig == SignatureEnumRW::kRWWrite && type->type_id() != kObjectTypeRefKey) { | |||
| MS_EXCEPTION(TypeError) << "Function " << func_name << "'s input " << i << " should be a Parameter."; | |||
| } else if (sig == SignatureEnumRW::kRWWrite && | |||
| !((type->type_id() == kObjectTypeRef) || (type->type_id() == kObjectTypeRefKey))) { | |||
| MS_EXCEPTION(TypeError) << "Function " << func_name << "'s input " << i << " should be a Parameter, but " | |||
| << type->ToString(); | |||
| } | |||
| MS_LOG(DEBUG) << "Function " << func_name << "'s input " << i << " " << param->DebugString(2) << " type " | |||
| << args_spec_list[i]->ToString(); | |||
| input_types.push_back(type); | |||
| op_inputs.push_back(param); | |||
| } | |||
| // process default | |||
| ProcessDefault(func_name, args_spec_list, signature, has_var, &op_inputs); | |||
| DoAutoCast(func_name, signature, args_spec_list, func_graph, &op_inputs, write_indices); | |||
| ProcessDefault(func_name, args_spec_list.size(), signature, has_var, &op_inputs); | |||
| DoAutoCast(func_name, signature, input_types, func_graph, &op_inputs, write_indices); | |||
| return func_graph->NewCNode(op_inputs); | |||
| } | |||
| } // namespace | |||
| @@ -81,12 +81,6 @@ void MultitypeFuncGraph::PyRegister(const py::tuple &tuple, const py::function & | |||
| } | |||
| Register(types_name, py_fn); | |||
| } | |||
| static TypePtr UnwrapRef(const TypePtr &type) { | |||
| if (type->isa<RefType>()) { | |||
| return type->cast<RefTypePtr>()->subtype(); | |||
| } | |||
| return type; | |||
| } | |||
| // Return Exact match if exists, else return non ambiguous sub class match | |||
| // Return py::none() if matching is ambiguous | |||
| @@ -99,7 +93,7 @@ const py::function MultitypeFuncGraph::SignMatch(const TypePtrList &types) { | |||
| } | |||
| auto match = true; | |||
| for (size_t i = 0; i < sign.size(); ++i) { | |||
| if (!IsIdentidityOrSubclass(UnwrapRef(types[i]), sign[i])) { | |||
| if (!IsIdentidityOrSubclass(types[i], sign[i])) { | |||
| match = false; | |||
| break; | |||
| } | |||
| @@ -627,6 +627,16 @@ AbstractBasePtr InferImplMakeRecord(const AnalysisEnginePtr &, const PrimitivePt | |||
| return std::make_shared<AbstractClass>(cls->tag(), abs_attributes, cls->methods()); | |||
| } | |||
| AbstractBasePtr InferImplAssign(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| // Inputs: a tensor | |||
| CheckArgsSize(primitive->name(), args_spec_list, 2); | |||
| MS_LOG(DEBUG) << "InferImplAssign " << args_spec_list[0]; | |||
| return args_spec_list[0]; | |||
| } | |||
| REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(TypeOf, prim::kPrimTypeOf, InferImplTypeof); | |||
| REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(HasType, prim::kPrimHasType, InferImplHasType); | |||
| REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(MakeRecord, prim::kPrimMakeRecord, InferImplMakeRecord); | |||
| @@ -648,5 +658,7 @@ REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(FakeBprop, prim::kPrimFakeBprop, InferImpl | |||
| REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(J, prim::kPrimJ, InferImplJ); | |||
| REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(BroadcastGradientArgs, prim::kPrimBroadcastGradientArgs, | |||
| InferImplBroadcastGradientArgs); | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(Assign, prim::kPrimAssign, InferImplAssign); | |||
| } // namespace abstract | |||
| } // namespace mindspore | |||
| @@ -20,6 +20,7 @@ | |||
| #include "ir/anf.h" | |||
| #include "ir/param_info.h" | |||
| #include "ir/meta_tensor.h" | |||
| #include "pipeline/jit/parse/python_adapter.h" | |||
| namespace mindspore { | |||
| @@ -38,8 +39,7 @@ bool ParameterRequireGrad(const AnfNodePtr &node_ptr) { | |||
| if (!para_ptr->has_default()) { | |||
| return false; | |||
| } | |||
| auto obj = py::cast(para_ptr->default_param()); | |||
| auto param_value = py::cast<ParamValuePtr>(obj.attr("_value")); | |||
| auto param_value = para_ptr->param_info(); | |||
| if (param_value == nullptr) { | |||
| return false; | |||
| } | |||
| @@ -1356,8 +1356,7 @@ bool ParameterIsCloned(const AnfNodePtr ¶meter_node) { | |||
| if (!cloned_parameter->has_default()) { | |||
| return false; | |||
| } | |||
| auto obj = py::cast(cloned_parameter->default_param()); | |||
| auto param_value = py::cast<ParamValuePtr>(obj.attr("_value")); | |||
| auto param_value = cloned_parameter->param_info(); | |||
| if (param_value == nullptr) { | |||
| return false; | |||
| } | |||
| @@ -1380,8 +1379,7 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) { | |||
| if (!ParameterIsCloned(cloned_parameter_node)) { | |||
| continue; | |||
| } | |||
| auto obj = py::cast(cloned_parameter->default_param()); | |||
| auto param_value = py::cast<ParamValuePtr>(obj.attr("_value")); | |||
| auto param_value = cloned_parameter->param_info(); | |||
| if (param_value == nullptr) { | |||
| continue; | |||
| } | |||
| @@ -1400,10 +1398,7 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) { | |||
| continue; | |||
| } | |||
| const auto ¶m_value_cloned = be_cloned_parameter->default_param(); | |||
| auto obj_in = py::cast(param_value_cloned); | |||
| auto param_value_in = py::cast<ParamValuePtr>(obj_in.attr("_value")); | |||
| auto param_value_in = be_cloned_parameter->param_info(); | |||
| if (param_value_in == nullptr) { | |||
| continue; | |||
| } | |||
| @@ -233,13 +233,14 @@ bool AbstractSpecializeAction(const ResourcePtr &res) { | |||
| for (const auto ¶m : func_graph->parameters()) { | |||
| auto param_node = std::static_pointer_cast<Parameter>(param); | |||
| if (param_node->has_default()) { | |||
| ValuePtr value = param_node->default_param(); | |||
| constexpr bool broaden = true; | |||
| AbstractBasePtr ptr = abstract::FromValue(value, broaden); | |||
| parallel::ParallelParameterContextRestoreInNoTraining(func_graph, param_node, ptr); | |||
| args_spec.push_back(ptr); | |||
| parallel::ParallelParameterContextCkptInTraining(func_graph, param_node, ptr); | |||
| auto value = param_node->default_param(); | |||
| auto abs_value = value->ToAbstract()->cast<abstract::AbstractTensorPtr>(); | |||
| auto ref_key = std::make_shared<RefKey>(param_node->name()); | |||
| auto abs_ref_key = ref_key->ToAbstract(); | |||
| auto abs_ref = std::make_shared<abstract::AbstractRef>(abs_ref_key, abs_value); | |||
| parallel::ParallelParameterContextRestoreInNoTraining(func_graph, param_node, abs_ref); | |||
| args_spec.push_back(abs_ref); | |||
| parallel::ParallelParameterContextCkptInTraining(func_graph, param_node, abs_ref); | |||
| } | |||
| } | |||
| // Analyze | |||
| @@ -425,9 +425,6 @@ bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature | |||
| converted = env; | |||
| } else if (py::hasattr(obj, PYTHON_CLASS_MEMBER_NAMESPACE)) { | |||
| converted = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, obj); | |||
| } else if (py::hasattr(obj, "__parameter__")) { | |||
| auto to_convert = py::cast<py::object>(python_adapter::GetPyObjAttr(obj, "default_input")); | |||
| ret = ConvertData(to_convert, &converted); | |||
| } else { | |||
| ret = ConvertOtherObj(obj, &converted); | |||
| } | |||
| @@ -555,9 +552,6 @@ void MakeProperNameToFuncGraph(const FuncGraphPtr &func_graph, std::string name) | |||
| ValuePtr PyDataToValue(const py::object &obj) { | |||
| py::object to_convert = obj; | |||
| if (py::hasattr(obj, "__parameter__")) { | |||
| to_convert = py::cast<py::object>(python_adapter::GetPyObjAttr(obj, "default_input")); | |||
| } | |||
| ValuePtr value = nullptr; | |||
| (void)ConvertData(to_convert, &value); | |||
| return value; | |||
| @@ -306,7 +306,14 @@ void FunctionBlock::ConditionalJump(AnfNodePtr condNode, const FunctionBlockPtr | |||
| } | |||
| void FunctionBlock::SetStateAssgin(const AnfNodePtr &target, const std::string &readid) { | |||
| state_assign_[target] = readid; | |||
| const std::string primitive_name("assign"); | |||
| const std::string module_name("mindspore.ops.functional"); | |||
| ValueNodePtr assign_op = NewValueNode(prim::GetPythonOps(primitive_name, module_name, true)); | |||
| auto source = ReadVariable(readid); | |||
| auto assign = func_graph()->NewCNode({assign_op, target, source}); | |||
| WriteVariable(readid, assign); | |||
| MS_LOG(INFO) << "SetState read " << target->DebugString() << ", " << readid; | |||
| AddAutoDepend(assign); | |||
| } | |||
| void FunctionBlock::AddAutoDepend(const AnfNodePtr &target) { auto_depends_.push_back(target); } | |||
| @@ -321,21 +328,13 @@ void FunctionBlock::InsertDependItemsBeforeReturn() { | |||
| ValueNodePtr make_tuple_op = NewValueNode(prim::kPrimMakeTuple); | |||
| ValueNodePtr depend_op = NewValueNode(prim::kPrimDepend); | |||
| ValueNodePtr stop_gradient_op = NewValueNode(prim::kPrimStopGradient); | |||
| const std::string primitive_name("assign"); | |||
| const std::string module_name("mindspore.ops.functional"); | |||
| ValueNodePtr assign_op = NewValueNode(prim::GetPythonOps(primitive_name, module_name, true)); | |||
| if (state_assign_.size() == 0 && auto_depends_.size() == 0) { | |||
| if (auto_depends_.size() == 0) { | |||
| return; | |||
| } | |||
| AnfNodePtr state = nullptr; | |||
| std::vector<AnfNodePtr> vec_states; | |||
| vec_states.emplace_back(make_tuple_op); | |||
| for (auto &item : state_assign_) { | |||
| auto source = ReadVariable(item.second); | |||
| auto assign = func_graph()->NewCNode({assign_op, item.first, source}); | |||
| MS_LOG(INFO) << "SetState read " << item.first->ToString() << ", " << item.second; | |||
| vec_states.emplace_back(assign); | |||
| } | |||
| for (auto &item : auto_depends_) { | |||
| MS_LOG(DEBUG) << "auto_depends " << item->ToString(); | |||
| vec_states.emplace_back(item); | |||
| @@ -361,7 +360,6 @@ void FunctionBlock::InsertDependItemsBeforeReturn() { | |||
| AnfNodePtr stopped = func_graph()->NewCNode({stop_gradient_op, state}); | |||
| AnfNodePtr ret = func_graph()->NewCNode({depend_op, old_ret, stopped}); | |||
| func_graph()->set_output(ret, true); | |||
| state_assign_.clear(); | |||
| } | |||
| } // namespace parse | |||
| } // namespace mindspore | |||
| @@ -101,9 +101,6 @@ class FunctionBlock : public std::enable_shared_from_this<FunctionBlock> { | |||
| // keeps all removable phis which will be removed in one pass. | |||
| std::unordered_map<ParameterPtr, AnfNodePtr> removable_phis_; | |||
| // set state nodes need to insert before function return nodes. | |||
| OrderedMap<AnfNodePtr, std::string> state_assign_; | |||
| // hold declared global variables in function | |||
| std::set<std::string> global_vars_; | |||
| @@ -59,14 +59,13 @@ FuncGraphPtr ParsePythonCode(const py::object &obj, const std::string &python_mo | |||
| return func_graph; | |||
| } | |||
| ValuePtr GetMixedPrecisionTargetType(const FuncGraphPtr &func_graph, const AnfNodePtr ¶m) { | |||
| TypePtr dst_type; | |||
| TypePtr GetMixedPrecisionTargetType(const FuncGraphPtr &func_graph) { | |||
| if (func_graph->has_flag(GRAPH_FLAG_MIX_PRECISION_FP32)) { | |||
| return kFloat32; | |||
| } else if (func_graph->has_flag(GRAPH_FLAG_MIX_PRECISION_FP16)) { | |||
| return kFloat16; | |||
| } else { | |||
| return kNone; | |||
| return nullptr; | |||
| } | |||
| } | |||
| @@ -364,7 +364,7 @@ class ParseAst { | |||
| bool UpdateFuncGraphFlags(py::object obj, const FuncGraphPtr &func_graph); | |||
| AnfNodePtr GetMixedPrecisionCastHelp(const FuncGraphPtr &func_graph, const AnfNodePtr ¶m); | |||
| ValuePtr GetMixedPrecisionTargetType(const FuncGraphPtr &func_graph, const AnfNodePtr ¶m); | |||
| TypePtr GetMixedPrecisionTargetType(const FuncGraphPtr &func_graph); | |||
| } // namespace parse | |||
| } // namespace mindspore | |||
| @@ -105,24 +105,12 @@ AnfNodePtr ResolveParameterObj(const FuncGraphPtr &func_graph, const py::object | |||
| auto value = py::cast<tensor::MetaTensorPtr>(obj); | |||
| node->set_default_param(value); | |||
| // set_abstract for parameter | |||
| constexpr bool broaden = true; | |||
| node->set_abstract(abstract::FromValue(value, broaden)); | |||
| auto abs = value->ToAbstract(); | |||
| node->set_abstract(abs); | |||
| para_node = node; | |||
| } | |||
| auto iter = func_graph->make_ref_params().find(para_node); | |||
| if (iter == func_graph->make_ref_params().end()) { | |||
| ValuePtr target_type = GetMixedPrecisionTargetType(func_graph, para_node); | |||
| AnfNodePtr make_ref = NewValueNode(prim::kPrimMakeRef); | |||
| AnfNodePtr ref_key = NewValueNode(std::make_shared<RefKey>(param_name)); | |||
| AnfNodePtr target_type_node = NewValueNode(target_type); | |||
| AnfNodePtr ref_node = func_graph->NewCNode({make_ref, ref_key, para_node, target_type_node}); | |||
| func_graph->make_ref_params()[para_node] = ref_node; | |||
| func_graph->add_parameter_obj_node(ref_node); | |||
| return ref_node; | |||
| } else { | |||
| return iter->second; | |||
| } | |||
| return para_node; | |||
| } | |||
| bool ResolveObjectToNode(const FuncGraphPtr &func_graph, const py::object &obj, AnfNodePtr *const node) { | |||
| @@ -640,7 +640,14 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::v | |||
| size_t size = op_exec_info->op_inputs.size(); | |||
| for (size_t i = 0; i < size; i++) { | |||
| auto obj = op_exec_info->op_inputs[i]; | |||
| bool op_mask = py::hasattr(obj, "__parameter__"); | |||
| bool op_mask = false; | |||
| if (py::isinstance<tensor::MetaTensor>(obj)) { | |||
| auto meta_tensor = obj.cast<tensor::MetaTensorPtr>(); | |||
| if (meta_tensor) { | |||
| op_mask = meta_tensor->is_parameter(); | |||
| } | |||
| } | |||
| (*op_masks).push_back(op_mask); | |||
| MS_LOG(DEBUG) << "gen " << op_exec_info->op_name << " arg " << i << ": op mask " << op_mask << " grad_flag_ " | |||
| << grad_flag_; | |||
| @@ -990,8 +997,9 @@ AnfNodePtr PynativeExecutor::GetInput(const py::object &obj, bool op_mask) { | |||
| if (graph_info_map_[df_builder_].param_map.count(obj_id) == 0) { | |||
| auto free_param = df_builder_->add_parameter(); | |||
| free_param->set_name(param_name); | |||
| free_param->set_default_param(py::cast<tensor::TensorPtr>(obj)); | |||
| free_param->debug_info()->set_name(param_name); | |||
| auto value = py::cast<tensor::TensorPtr>(obj); | |||
| free_param->set_default_param(value); | |||
| MS_LOG(DEBUG) << "Top graph set free parameter " << obj_id; | |||
| graph_info_map_[df_builder_].param_map[obj_id] = free_param; | |||
| return free_param; | |||
| @@ -1159,17 +1167,12 @@ std::vector<AnfNodePtr> PynativeExecutor::GetWeightsArgs(const py::object &weigh | |||
| auto param_name = py::cast<std::string>(name_attr); | |||
| auto free_param = df_builder_->add_parameter(); | |||
| free_param->set_name(param_name); | |||
| free_param->set_default_param(py::cast<tensor::TensorPtr>(param)); | |||
| auto value = py::cast<tensor::TensorPtr>(param); | |||
| free_param->set_default_param(value); | |||
| free_param->debug_info()->set_name(param_name); | |||
| para_node = free_param; | |||
| } | |||
| ValuePtr target_type = parse::GetMixedPrecisionTargetType(df_builder_, para_node); | |||
| AnfNodePtr make_ref = NewValueNode(prim::kPrimMakeRef); | |||
| auto refkey = std::make_shared<RefKey>(para_node->cast<ParameterPtr>()->name()); | |||
| AnfNodePtr ref_key_node = NewValueNode(refkey); | |||
| AnfNodePtr target_type_node = NewValueNode(target_type); | |||
| AnfNodePtr ref_node = df_builder_->NewCNode({make_ref, ref_key_node, para_node, target_type_node}); | |||
| w_args.push_back(ref_node); | |||
| w_args.push_back(para_node); | |||
| } | |||
| } else { | |||
| MS_LOG(DEBUG) << "training not paramter_tuple"; | |||
| @@ -1197,7 +1200,7 @@ abstract::AbstractBasePtrList PynativeExecutor::GetArgsSpec(const py::args &args | |||
| auto param_node = std::static_pointer_cast<Parameter>(param); | |||
| if (param_node->has_default()) { | |||
| ValuePtr value = param_node->default_param(); | |||
| AbstractBasePtr ptr = abstract::FromValue(value, true); | |||
| auto ptr = value->ToAbstract(); | |||
| if (ptr == nullptr) { | |||
| MS_LOG(EXCEPTION) << "Args convert error"; | |||
| } | |||
| @@ -147,7 +147,7 @@ REGISTER_PYBIND_DEFINE( | |||
| (void)py::class_<TypeType, Type, std::shared_ptr<TypeType>>(m_sub, "TypeType").def(py::init()); | |||
| (void)py::class_<String, Type, std::shared_ptr<String>>(m_sub, "String").def(py::init()); | |||
| (void)py::class_<RefKeyType, Type, std::shared_ptr<RefKeyType>>(m_sub, "RefKeyType").def(py::init()); | |||
| (void)py::class_<RefType, Type, std::shared_ptr<RefType>>(m_sub, "RefType").def(py::init()); | |||
| (void)py::class_<RefType, TensorType, Type, std::shared_ptr<RefType>>(m_sub, "RefType").def(py::init()); | |||
| (void)py::class_<TypeAnything, Type, std::shared_ptr<TypeAnything>>(m_sub, "TypeAnything").def(py::init()); | |||
| (void)py::class_<Slice, Type, std::shared_ptr<Slice>>(m_sub, "Slice").def(py::init()); | |||
| (void)py::class_<TypeEllipsis, Type, std::shared_ptr<TypeEllipsis>>(m_sub, "TypeEllipsis").def(py::init()); | |||
| @@ -21,7 +21,7 @@ namespace mindspore { | |||
| namespace py = pybind11; | |||
| REGISTER_PYBIND_DEFINE(ParamInfo, ([](const py::module *m) { | |||
| (void)py::class_<ParamInfo, ParamValuePtr>(*m, "ParamInfo") | |||
| (void)py::class_<ParamInfo, ParamInfoPtr>(*m, "ParamInfo") | |||
| .def(py::init()) | |||
| .def("clone", &ParamInfo::Clone) | |||
| .def_property("name", &ParamInfo::name, &ParamInfo::set_name) | |||
| @@ -36,7 +36,7 @@ REGISTER_PYBIND_DEFINE(ParamInfo, ([](const py::module *m) { | |||
| if (t.size() != 6) { | |||
| std::runtime_error("Invalid state for ParamInfo!"); | |||
| } | |||
| ParamValuePtr p = std::make_shared<ParamInfo>(); | |||
| ParamInfoPtr p = std::make_shared<ParamInfo>(); | |||
| p->set_name(t[1].cast<std::string>()); | |||
| p->set_requires_grad(t[2].cast<bool>()); | |||
| p->set_layerwise_parallel(t[3].cast<bool>()); | |||
| @@ -291,6 +291,7 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) { | |||
| .def(py::init<TypePtr, const std::vector<int>>(), py::arg("dtype"), py::arg("shape")) | |||
| .def_property_readonly("dtype", &MetaTensor::Dtype, "Get the MetaTensor's dtype.") | |||
| .def_property_readonly("shape", &MetaTensor::shape, "Get the MetaTensor's shape.") | |||
| .def_property("_param_info", &MetaTensor::param_info, &MetaTensor::set_param_info) | |||
| .def(py::pickle( | |||
| [](const MetaTensor &t) { // __getstate__ | |||
| /* Return a tuple that fully encodes the state of the object */ | |||
| @@ -42,7 +42,7 @@ class Parameter(MetaTensor): | |||
| In auto_parallel mode of "semi_auto_parallel" and "auto_parallel", if init `Parameter` by | |||
| an `Initializer`, the type of Parameter will be `MetaTensor` not `Tensor`. `MetaTensor` | |||
| only saves the shape and type info of a tensor with no memory usage. The shape can be changed while | |||
| compile for auto-parallel. Call `init_data` will return a Tensor Parameter with initialized data. | |||
| compiling for auto-parallel. Call `init_data` will return a Tensor Parameter with initialized data. | |||
| Note: | |||
| Each parameter of Cell is represented by Parameter class. | |||
| @@ -108,7 +108,7 @@ class Parameter(MetaTensor): | |||
| Parameter, (data, self.name, self.requires_grad, self.layerwise_parallel)) | |||
| def __init__(self, default_input, name, requires_grad=True, layerwise_parallel=False): | |||
| self._value = ParamInfo() | |||
| self._param_info = ParamInfo() | |||
| self.name = name | |||
| self.requires_grad = requires_grad | |||
| self.layerwise_parallel = layerwise_parallel | |||
| @@ -156,13 +156,13 @@ class Parameter(MetaTensor): | |||
| value_str = MetaTensor.__str__(self) | |||
| if isinstance(self, Tensor): | |||
| value_str = Tensor.__str__(self) | |||
| return f'Parameter (name={self._value.name}, value={value_str})' | |||
| return f'Parameter (name={self._param_info.name}, value={value_str})' | |||
| def __repr__(self): | |||
| value_str = MetaTensor.__repr__(self) | |||
| if isinstance(self, Tensor): | |||
| value_str = Tensor.__repr__(self) | |||
| return f'Parameter (name={self._value.name}, value={value_str})' | |||
| return f'Parameter (name={self._param_info.name}, value={value_str})' | |||
| def __parameter__(self): | |||
| """For parse check.""" | |||
| @@ -181,7 +181,7 @@ class Parameter(MetaTensor): | |||
| @property | |||
| def name(self): | |||
| """Get the name of the parameter.""" | |||
| return self._value.name | |||
| return self._param_info.name | |||
| @name.setter | |||
| def name(self, name_): | |||
| @@ -203,7 +203,7 @@ class Parameter(MetaTensor): | |||
| format(name_, PARAMETER_NAME_PREFIX_MAX_LEN)) | |||
| else: | |||
| raise ValueError("The type of the name should be `str` or `None`.") | |||
| self._value.name = name_ | |||
| self._param_info.name = name_ | |||
| @property | |||
| def cast_type(self): | |||
| @@ -254,8 +254,8 @@ class Parameter(MetaTensor): | |||
| _check_str_by_regular(prefix) | |||
| x = copy(self) | |||
| # pylint: disable=protected-access | |||
| x._value = self._value.clone() | |||
| x._value.name = prefix + '.' + self._value.name | |||
| x._param_info = self._param_info.clone() | |||
| x._param_info.name = prefix + '.' + self._param_info.name | |||
| x.is_init = False | |||
| if init != 'same': | |||
| shape = self.shape | |||
| @@ -265,24 +265,24 @@ class Parameter(MetaTensor): | |||
| @property | |||
| def layerwise_parallel(self): | |||
| return self._value.layerwise_parallel | |||
| return self._param_info.layerwise_parallel | |||
| @layerwise_parallel.setter | |||
| def layerwise_parallel(self, value=True): | |||
| if not isinstance(value, bool): | |||
| raise TypeError("`layerwise_parallel` parameter must be bool type") | |||
| self._value.layerwise_parallel = value | |||
| self._param_info.layerwise_parallel = value | |||
| @property | |||
| def requires_grad(self): | |||
| """Return whether the parameter requires gradient.""" | |||
| return self._value.requires_grad | |||
| return self._param_info.requires_grad | |||
| @requires_grad.setter | |||
| def requires_grad(self, value=True): | |||
| if not isinstance(value, bool): | |||
| raise TypeError("`requires_grad` parameter must be bool type") | |||
| self._value.requires_grad = value | |||
| self._param_info.requires_grad = value | |||
| @property | |||
| def data(self): | |||
| @@ -459,10 +459,6 @@ AbstractBasePtr AbstractTensor::Join(const AbstractBasePtr &other) { | |||
| } | |||
| auto other_tensor = dyn_cast<AbstractTensor>(other); | |||
| if (other_tensor == nullptr) { | |||
| auto ref_tensor = dyn_cast<AbstractRef>(other); | |||
| if (ref_tensor != nullptr) { | |||
| return this->Join(ref_tensor->ref()); | |||
| } | |||
| MS_LOG(EXCEPTION) << "Join failed as type mismatch, this: " << ToString() << ", other: " << other->ToString(); | |||
| } | |||
| if (*this == *other) { | |||
| @@ -473,7 +469,7 @@ AbstractBasePtr AbstractTensor::Join(const AbstractBasePtr &other) { | |||
| return std::make_shared<AbstractTensor>(element, shape); | |||
| } | |||
| bool AbstractTensor::operator==(const AbstractTensor &other) const { | |||
| bool AbstractTensor::equal_to(const AbstractTensor &other) const { | |||
| if (&other == this) { | |||
| return true; | |||
| } | |||
| @@ -491,12 +487,14 @@ bool AbstractTensor::operator==(const AbstractTensor &other) const { | |||
| return (*element_ == *other.element_) && (*shape() == *other.shape()) && is_value_equal; | |||
| } | |||
| bool AbstractTensor::operator==(const AbstractTensor &other) const { return equal_to(other); } | |||
| bool AbstractTensor::operator==(const AbstractBase &other) const { | |||
| if (&other == this) { | |||
| return true; | |||
| } | |||
| if (other.isa<AbstractTensor>()) { | |||
| if (other.tid() == tid()) { | |||
| auto other_tensor = static_cast<const AbstractTensor *>(&other); | |||
| return *this == *other_tensor; | |||
| } else { | |||
| @@ -822,39 +820,21 @@ std::string AbstractJTagged::ToString() const { | |||
| return buffer.str(); | |||
| } | |||
| AbstractRef::AbstractRef(const AbstractBasePtr &ref_key, const AbstractBasePtr &ref_value, bool need_cast, | |||
| TypePtr cast_target) | |||
| : ref_key_(ref_key), ref_(ref_value), need_cast_(false), target_type_(nullptr), ref_key_value_(nullptr) { | |||
| AbstractRef::AbstractRef(const AbstractBasePtr &ref_key, const AbstractTensorPtr &ref_value) | |||
| : AbstractTensor(*ref_value), ref_key_(ref_key), ref_key_value_(nullptr) { | |||
| set_type(std::make_shared<RefType>()); | |||
| auto origin_type = ref_value->BuildType(); | |||
| if (need_cast && cast_target && origin_type && origin_type->isa<TensorType>()) { | |||
| auto tensor_dtype = origin_type->cast<TensorTypePtr>()->element(); | |||
| if (tensor_dtype && IsSubType(tensor_dtype, kFloat)) { | |||
| if (cast_target != tensor_dtype) { | |||
| need_cast_ = true; | |||
| target_type_ = cast_target; | |||
| } | |||
| } | |||
| } | |||
| if (ref_key && ref_key->isa<AbstractRefKey>()) { | |||
| ref_key_value_ = ref_key->cast<AbstractRefKeyPtr>()->ref_key_value(); | |||
| } | |||
| } | |||
| BaseShapePtr AbstractRef::BuildShape() const { return ref_->BuildShape(); } | |||
| TypePtr AbstractRef::BuildType() const { | |||
| TypePtr subtype = ref_->BuildType(); | |||
| TypePtr subtype_origin = subtype; | |||
| if (need_cast_) { | |||
| subtype_origin = std::make_shared<TensorType>(target_type_); | |||
| } | |||
| return std::make_shared<RefType>(subtype, subtype_origin); | |||
| auto subtype = AbstractTensor::BuildType()->cast<TensorTypePtr>(); | |||
| return std::make_shared<RefType>(subtype); | |||
| } | |||
| bool AbstractRef::operator==(const AbstractRef &other) const { | |||
| return (*ref_ == *other.ref_) && (need_cast_ == other.need_cast_) && (*ref_key_ == *other.ref_key_) && | |||
| (!need_cast_ || (*target_type_ == *other.target_type_)); | |||
| return AbstractTensor::equal_to(other) && (*ref_key_ == *other.ref_key_); | |||
| } | |||
| bool AbstractRef::operator==(const AbstractBase &other) const { | |||
| @@ -886,24 +866,20 @@ AbstractBasePtr AbstractRefKey::Join(const AbstractBasePtr &other) { | |||
| AbstractBasePtr AbstractRef::Join(const AbstractBasePtr &other) { | |||
| auto other_ref = other->cast<AbstractRefPtr>(); | |||
| if (other_ref == nullptr) { | |||
| auto new_ref = ref_->Join(other); | |||
| return std::make_shared<AbstractRef>(ref_key_, new_ref); | |||
| return AbstractTensor::Join(other)->cast<AbstractTensorPtr>(); | |||
| } | |||
| if ((*this == *other) && (*ref_key_ == *other_ref->ref_key_)) { | |||
| return shared_from_base<AbstractBase>(); | |||
| } | |||
| auto ref_key = ref_key_->Join(other_ref->ref_key_); | |||
| auto ref = ref_->Join(other_ref->ref()); | |||
| auto ref = AbstractTensor::Join(other_ref->ref())->cast<AbstractTensorPtr>(); | |||
| return std::make_shared<AbstractRef>(ref_key, ref); | |||
| } | |||
| std::string AbstractRef::ToString() const { | |||
| std::ostringstream buffer; | |||
| buffer << type_name() << "(" | |||
| << "key: " << ref_key_->ToString() << " ref_value: " << ref_->ToString(); | |||
| if (need_cast_) { | |||
| buffer << " cast to: " << target_type_->ToString(); | |||
| } | |||
| << "key: " << ref_key_->ToString() << " ref_value: " << AbstractTensor::ToString(); | |||
| auto value = GetValueTrack(); | |||
| if (value) { | |||
| buffer << ", value: " << value->ToString(); | |||
| @@ -284,11 +284,9 @@ class AbstractTensor : public AbstractUndetermined { | |||
| AbstractBasePtr Clone() const override; | |||
| AbstractBasePtr Broaden(uint8_t config = 0) const override; | |||
| AbstractBasePtr BroadenWithShape() const; | |||
| AbstractBasePtr Join(const AbstractBasePtr &other) final; | |||
| AbstractBasePtr Join(const AbstractBasePtr &other); | |||
| bool operator==(const AbstractTensor &other) const; | |||
| bool operator==(const AbstractBase &other) const override; | |||
| std::string ToString() const override; | |||
| std::size_t hash() const override { | |||
| auto value = GetValueTrack(); | |||
| @@ -301,6 +299,9 @@ class AbstractTensor : public AbstractUndetermined { | |||
| } | |||
| return hash_sum; | |||
| } | |||
| protected: | |||
| bool equal_to(const AbstractTensor &other) const; | |||
| }; | |||
| using AbstractTensorPtr = std::shared_ptr<AbstractTensor>; | |||
| using AbstractTensorPtrList = std::vector<AbstractTensorPtr>; | |||
| @@ -575,42 +576,42 @@ class AbstractRefKey : public AbstractBase { | |||
| }; | |||
| using AbstractRefKeyPtr = std::shared_ptr<AbstractRefKey>; | |||
| class AbstractRef : public AbstractBase { | |||
| class AbstractRef : public AbstractTensor { | |||
| public: | |||
| AbstractRef(const AbstractBasePtr &ref_key, const AbstractBasePtr &ref_value, bool need_cast = false, | |||
| TypePtr cast_target = nullptr); | |||
| AbstractRef(const AbstractBasePtr &ref_key, const AbstractTensorPtr &ref_value); | |||
| ~AbstractRef() override = default; | |||
| MS_DECLARE_PARENT(AbstractRef, AbstractBase) | |||
| MS_DECLARE_PARENT(AbstractRef, AbstractTensor) | |||
| TypePtr BuildType() const override; | |||
| BaseShapePtr BuildShape() const override; | |||
| bool operator==(const AbstractRef &other) const; | |||
| bool operator==(const AbstractBase &other) const override; | |||
| AbstractBasePtr Clone() const override { | |||
| return std::make_shared<AbstractRef>(ref_key_->Clone(), ref_->Clone(), need_cast_, target_type_); | |||
| auto abs_tensor = AbstractTensor::Clone()->cast<AbstractTensorPtr>(); | |||
| if (abs_tensor == nullptr) { | |||
| return nullptr; | |||
| } | |||
| return std::make_shared<AbstractRef>(ref_key_->Clone(), abs_tensor); | |||
| } | |||
| std::string ToString() const override; | |||
| inline AbstractBasePtr ref() const { return ref_; } | |||
| inline AbstractTensorPtr ref() { return shared_from_base<AbstractTensor>(); } | |||
| inline AbstractBasePtr ref_key() const { return ref_key_; } | |||
| inline RefKeyPtr ref_key_value() const { return ref_key_value_; } | |||
| inline TypePtr target_type() const { return target_type_; } | |||
| inline bool need_cast() const { return need_cast_; } | |||
| AbstractBasePtr Broaden(uint8_t config = 0) const override { | |||
| // always broaden for ref | |||
| return std::make_shared<AbstractRef>(ref_key_->Broaden(config), ref_->Broaden(), need_cast_, target_type_); | |||
| auto abs_tensor = AbstractTensor::Broaden()->cast<AbstractTensorPtr>(); | |||
| if (abs_tensor == nullptr) { | |||
| return nullptr; | |||
| } | |||
| return std::make_shared<AbstractRef>(ref_key_->Broaden(config), abs_tensor); | |||
| } | |||
| AbstractBasePtr Join(const AbstractBasePtr &other) override; | |||
| std::size_t hash() const override { | |||
| return ref_->hash() ^ (std::hash<uint32_t>{}(this->tid()) << 1); // ref_key_->hash() ^ | |||
| return AbstractTensor::hash() ^ (std::hash<uint32_t>{}(this->tid()) << 1); // ref_key_->hash() ^ | |||
| } | |||
| private: | |||
| AbstractBasePtr ref_key_; | |||
| AbstractBasePtr ref_; | |||
| // For mix presicion, only float type need to cast to float16 of float32 | |||
| bool need_cast_; | |||
| TypePtr target_type_; | |||
| // cache for ref_key after build value, when value is null, return nullptr. | |||
| RefKeyPtr ref_key_value_; | |||
| }; | |||
| @@ -113,17 +113,8 @@ AbstractBasePtr InferImplMakeRef(const AnalysisEnginePtr &, const PrimitivePtr & | |||
| MS_LOG(EXCEPTION) << "make_ref evaluator requires 3 parameters, while the input size is " << args_spec_list.size() | |||
| << "."; | |||
| } | |||
| TypePtr type = args_spec_list[0]->GetTypeTrack(); | |||
| ValuePtr tensor_target_v = args_spec_list[2]->BuildValue(); | |||
| if (type->type_id() != kObjectTypeRefKey) { | |||
| MS_LOG(EXCEPTION) << "First input of make_ref should be a RefKey but a " << type->ToString(); | |||
| } | |||
| auto need_cast = !tensor_target_v->isa<None>(); | |||
| if (need_cast && !tensor_target_v->isa<Type>()) { | |||
| MS_LOG(EXCEPTION) << "Third input of make_ref should be a Type but a " << tensor_target_v->ToString(); | |||
| } | |||
| TypePtr cast_target = tensor_target_v->cast<TypePtr>(); | |||
| return std::make_shared<AbstractRef>(args_spec_list[0], args_spec_list[1], need_cast, cast_target); | |||
| auto tensor = args_spec_list[1]->cast<abstract::AbstractTensorPtr>(); | |||
| return std::make_shared<AbstractRef>(args_spec_list[0], tensor); | |||
| } | |||
| AbstractBasePtr InferImplGetRefKey(const AnalysisEnginePtr &, const PrimitivePtr &, | |||
| @@ -88,6 +88,17 @@ std::string Parameter::DebugString(int recursive_level) const { | |||
| return buffer.str(); | |||
| } | |||
| ParamInfoPtr Parameter::param_info() const { | |||
| if (!has_default()) { | |||
| return nullptr; | |||
| } | |||
| auto tensor = default_param()->cast<tensor::MetaTensorPtr>(); | |||
| if (tensor == nullptr || !tensor->is_parameter()) { | |||
| return nullptr; | |||
| } | |||
| return tensor->param_info(); | |||
| } | |||
| std::string ValueNode::ToString() const { | |||
| MS_EXCEPTION_IF_NULL(value_); | |||
| if (value_->isa<FuncGraph>()) { | |||
| @@ -75,7 +75,7 @@ using VarPtr = std::shared_ptr<Var>; | |||
| class AnfIrVisitor; | |||
| class ParamInfo; | |||
| using ParamValuePtr = std::shared_ptr<ParamInfo>; | |||
| using ParamInfoPtr = std::shared_ptr<ParamInfo>; | |||
| // AnfNode is the basic class of the IR definition derived from Base. | |||
| // Only two types of nodes are derived: CNode and ANode. | |||
| @@ -288,6 +288,7 @@ class Parameter : public ANode { | |||
| has_default_ = true; | |||
| } | |||
| ValuePtr default_param() const { return default_param_; } | |||
| ParamInfoPtr param_info() const; | |||
| bool operator==(const AnfNode &other) const override { | |||
| if (!other.isa<Parameter>()) { | |||
| @@ -94,175 +94,6 @@ bool Slice::operator==(const Type &other) const { | |||
| std::string Slice::DumpText() const { return ToString(); } | |||
| TypePtr UndeterminedType::DeepCopy() const { | |||
| MS_EXCEPTION_IF_NULL(element_type_); | |||
| if (IsGeneric()) { | |||
| return std::make_shared<UndeterminedType>(); | |||
| } | |||
| return std::make_shared<UndeterminedType>(element_type_->DeepCopy()); | |||
| } | |||
| std::string UndeterminedType::ToReprString() const { | |||
| if (element_type_ == nullptr) { | |||
| return "Undetermined"; | |||
| } | |||
| return "Undetermined[" + element_type_->ToReprString() + "]"; | |||
| } | |||
| std::string UndeterminedType::ToString() const { | |||
| if (element_type_ == nullptr) { | |||
| return "Undetermined"; | |||
| } | |||
| return "Undetermined[" + element_type_->ToString() + "]"; | |||
| } | |||
| std::string UndeterminedType::DumpText() const { | |||
| if (element_type_ == nullptr) { | |||
| return "Undetermined"; | |||
| } | |||
| return "Undetermined[" + element_type_->DumpText() + "]"; | |||
| } | |||
| bool UndeterminedType::operator==(const Type &other) const { | |||
| if (!IsSameObjectType(*this, other)) { | |||
| return false; | |||
| } | |||
| auto other_elem_type = static_cast<const UndeterminedType &>(other).element_type_; | |||
| if (element_type_ == nullptr && other_elem_type == nullptr) { | |||
| return true; | |||
| } else if (element_type_ == nullptr || other_elem_type == nullptr) { | |||
| return false; | |||
| } | |||
| return *element_type_ == *other_elem_type; | |||
| } | |||
| TypePtr TensorType::DeepCopy() const { | |||
| MS_EXCEPTION_IF_NULL(element_type_); | |||
| if (IsGeneric()) { | |||
| return std::make_shared<TensorType>(); | |||
| } | |||
| return std::make_shared<TensorType>(element_type_->DeepCopy()); | |||
| } | |||
| std::string TensorType::ToReprString() const { | |||
| if (element_type_ == nullptr) { | |||
| return "tensor"; | |||
| } | |||
| return "tensor[" + element_type_->ToReprString() + "]"; | |||
| } | |||
| std::string TensorType::ToString() const { | |||
| if (element_type_ == nullptr) { | |||
| return "Tensor"; | |||
| } | |||
| return "Tensor[" + element_type_->ToString() + "]"; | |||
| } | |||
| std::string TensorType::DumpText() const { | |||
| if (element_type_ == nullptr) { | |||
| return "Tensor"; | |||
| } | |||
| return "Tensor(" + element_type_->DumpText() + ")"; | |||
| } | |||
| bool TensorType::operator==(const Type &other) const { | |||
| if (!IsSameObjectType(*this, other)) { | |||
| return false; | |||
| } | |||
| auto other_elem_type = static_cast<const TensorType &>(other).element_type_; | |||
| // When element_type_ = nullptr, which means any type of Array. | |||
| if (element_type_ == nullptr && other_elem_type == nullptr) { | |||
| return true; | |||
| } else if (element_type_ == nullptr || other_elem_type == nullptr) { | |||
| return false; | |||
| } | |||
| return *element_type_ == *other_elem_type; | |||
| } | |||
| TypePtr RowTensorType::DeepCopy() const { | |||
| MS_EXCEPTION_IF_NULL(element_type_); | |||
| if (IsGeneric()) { | |||
| return std::make_shared<RowTensorType>(); | |||
| } | |||
| return std::make_shared<RowTensorType>(element_type_->DeepCopy()); | |||
| } | |||
| std::string RowTensorType::ToReprString() const { | |||
| if (element_type_ == nullptr) { | |||
| return "RowTensor"; | |||
| } | |||
| return "RowTensor[" + element_type_->ToReprString() + "]"; | |||
| } | |||
| std::string RowTensorType::ToString() const { | |||
| if (element_type_ == nullptr) { | |||
| return "RowTensor"; | |||
| } | |||
| return "RowTensor[" + element_type_->ToString() + "]"; | |||
| } | |||
| std::string RowTensorType::DumpText() const { | |||
| if (element_type_ == nullptr) { | |||
| return "RowTensor"; | |||
| } | |||
| return "RowTensor[" + element_type_->DumpText() + "]"; | |||
| } | |||
| bool RowTensorType::operator==(const Type &other) const { | |||
| if (!IsSameObjectType(*this, other)) { | |||
| return false; | |||
| } | |||
| auto other_elem_type = static_cast<const RowTensorType &>(other).element_type_; | |||
| if (element_type_ == nullptr && other_elem_type == nullptr) { | |||
| return true; | |||
| } else if (element_type_ == nullptr || other_elem_type == nullptr) { | |||
| return false; | |||
| } | |||
| return *element_type_ == *other_elem_type; | |||
| } | |||
| TypePtr SparseTensorType::DeepCopy() const { | |||
| MS_EXCEPTION_IF_NULL(element_type_); | |||
| if (IsGeneric()) { | |||
| return std::make_shared<SparseTensorType>(); | |||
| } | |||
| return std::make_shared<SparseTensorType>(element_type_->DeepCopy()); | |||
| } | |||
| std::string SparseTensorType::ToReprString() const { | |||
| if (element_type_ == nullptr) { | |||
| return "SparseTensor"; | |||
| } | |||
| return "SparseTensor[" + element_type_->ToReprString() + "]"; | |||
| } | |||
| std::string SparseTensorType::ToString() const { | |||
| if (element_type_ == nullptr) { | |||
| return "SparseTensor"; | |||
| } | |||
| return "SparseTensor[" + element_type_->ToString() + "]"; | |||
| } | |||
| std::string SparseTensorType::DumpText() const { | |||
| if (element_type_ == nullptr) { | |||
| return "SparseTensor"; | |||
| } | |||
| return "SparseTensor[" + element_type_->DumpText() + "]"; | |||
| } | |||
| bool SparseTensorType::operator==(const Type &other) const { | |||
| if (!IsSameObjectType(*this, other)) { | |||
| return false; | |||
| } | |||
| auto other_elem_type = static_cast<const SparseTensorType &>(other).element_type_; | |||
| if (element_type_ == nullptr && other_elem_type == nullptr) { | |||
| return true; | |||
| } else if (element_type_ == nullptr || other_elem_type == nullptr) { | |||
| return false; | |||
| } | |||
| return *element_type_ == *other_elem_type; | |||
| } | |||
| Function::Function() : Object(kObjectTypeFunction) { | |||
| args_ = std::vector<TypePtr>(); | |||
| retval_ = nullptr; | |||
| @@ -372,4 +203,8 @@ std::ostream &operator<<(std::ostream &os, const std::shared_ptr<Problem> proble | |||
| os << problem->ToString(); | |||
| return os; | |||
| } | |||
| const TypePtr kTensorTypeFP16 = std::make_shared<TensorType>(std::make_shared<Float>(16)); | |||
| const TypePtr kTensorTypeFP32 = std::make_shared<TensorType>(std::make_shared<Float>(32)); | |||
| } // namespace mindspore | |||
| @@ -32,10 +32,11 @@ | |||
| #include "ir/named.h" | |||
| #include "ir/dtype/type.h" | |||
| #include "ir/dtype/ref.h" | |||
| #include "ir/dtype/number.h" | |||
| #include "ir/dtype/container.h" | |||
| #include "ir/dtype/empty.h" | |||
| #include "ir/dtype/tensor_type.h" | |||
| #include "ir/dtype/ref.h" | |||
| /* namespace to support intermediate representation definition */ | |||
| namespace mindspore { | |||
| @@ -108,98 +109,6 @@ class Slice : public Object { | |||
| }; | |||
| using SlicePtr = std::shared_ptr<Slice>; | |||
| class UndeterminedType : public Object { | |||
| public: | |||
| UndeterminedType() : Object(kObjectTypeUndeterminedType) {} | |||
| explicit UndeterminedType(const TypePtr &ele) | |||
| : Object(kObjectTypeUndeterminedType, kMetaTypeObject, false), element_type_(ele) {} | |||
| ~UndeterminedType() override = default; | |||
| MS_DECLARE_PARENT(UndeterminedType, Object) | |||
| TypeId generic_type_id() const override { return kObjectTypeUndeterminedType; } | |||
| const TypePtr element() const { return element_type_; } | |||
| void set_element(const TypePtr &element_type) { element_type_ = element_type; } | |||
| TypePtr DeepCopy() const override; | |||
| std::string ToString() const override; | |||
| std::string ToReprString() const override; | |||
| std::string DumpText() const override; | |||
| bool operator==(const Type &other) const override; | |||
| protected: | |||
| TypePtr element_type_; | |||
| }; | |||
| using MetaTensorTypePtr = std::shared_ptr<UndeterminedType>; | |||
| class TensorType : public Object { | |||
| public: | |||
| TensorType() : Object(kObjectTypeTensorType, kObjectTypeUndeterminedType) {} | |||
| explicit TensorType(const TypePtr &ele) | |||
| : Object(kObjectTypeTensorType, kObjectTypeUndeterminedType, false), element_type_(ele) {} | |||
| ~TensorType() override = default; | |||
| MS_DECLARE_PARENT(TensorType, Object) | |||
| TypeId generic_type_id() const override { return kObjectTypeTensorType; } | |||
| const TypePtr element() const { return element_type_; } | |||
| void set_element(const TypePtr &element_type) { element_type_ = element_type; } | |||
| TypePtr DeepCopy() const override; | |||
| std::string ToString() const override; | |||
| std::string ToReprString() const override; | |||
| std::string DumpText() const override; | |||
| bool operator==(const Type &other) const override; | |||
| private: | |||
| TypePtr element_type_; | |||
| }; | |||
| using TensorTypePtr = std::shared_ptr<TensorType>; | |||
| class RowTensorType : public Object { | |||
| public: | |||
| RowTensorType() : Object(kObjectTypeRowTensorType, kObjectTypeUndeterminedType) {} | |||
| explicit RowTensorType(const TypePtr &ele) | |||
| : Object(kObjectTypeRowTensorType, kObjectTypeUndeterminedType, false), element_type_(ele) {} | |||
| ~RowTensorType() override = default; | |||
| MS_DECLARE_PARENT(RowTensorType, Object) | |||
| TypeId generic_type_id() const override { return kObjectTypeRowTensorType; } | |||
| const TypePtr element() const { return element_type_; } | |||
| void set_element(const TypePtr &element_type) { element_type_ = element_type; } | |||
| TypePtr DeepCopy() const override; | |||
| std::string ToString() const override; | |||
| std::string ToReprString() const override; | |||
| std::string DumpText() const override; | |||
| bool operator==(const Type &other) const override; | |||
| private: | |||
| TypePtr element_type_; | |||
| }; | |||
| using RowTensorTypePtr = std::shared_ptr<RowTensorType>; | |||
| class SparseTensorType : public Object { | |||
| public: | |||
| SparseTensorType() : Object(kObjectTypeSparseTensorType, kObjectTypeUndeterminedType) {} | |||
| explicit SparseTensorType(const TypePtr &ele) | |||
| : Object(kObjectTypeSparseTensorType, kObjectTypeUndeterminedType, false), element_type_(ele) {} | |||
| ~SparseTensorType() override = default; | |||
| MS_DECLARE_PARENT(SparseTensorType, Object) | |||
| TypeId generic_type_id() const override { return kObjectTypeSparseTensorType; } | |||
| const TypePtr element() const { return element_type_; } | |||
| void set_element(const TypePtr &element_type) { element_type_ = element_type; } | |||
| TypePtr DeepCopy() const override; | |||
| std::string ToString() const override; | |||
| std::string ToReprString() const override; | |||
| std::string DumpText() const override; | |||
| bool operator==(const Type &other) const override; | |||
| private: | |||
| TypePtr element_type_; | |||
| }; | |||
| using SparseTensorTypePtr = std::shared_ptr<SparseTensorType>; | |||
| class Function : public Object { | |||
| public: | |||
| Function(); | |||
| @@ -353,6 +262,9 @@ extern const TypePtr kDict; | |||
| extern const TypePtr kSlice; | |||
| extern const TypePtr kKeyword; | |||
| extern const TypePtr kTensorType; | |||
| extern const TypePtr kTensorTypeFP16; | |||
| extern const TypePtr kTensorTypeFP32; | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CORE_IR_DTYPE_H_ | |||
| @@ -68,6 +68,8 @@ class Number : public Object { | |||
| const int nbits_; | |||
| }; | |||
| using NumberPtr = std::shared_ptr<Number>; | |||
| // Bool | |||
| class Bool : public Number { | |||
| public: | |||
| @@ -19,15 +19,15 @@ | |||
| #include <cstdlib> | |||
| #include <algorithm> | |||
| #include "utils/log_adapter.h" | |||
| #include "ir/dtype/tensor_type.h" | |||
| namespace mindspore { | |||
| TypePtr RefType::DeepCopy() const { | |||
| if (IsGeneric()) { | |||
| return std::make_shared<RefType>(); | |||
| } else { | |||
| auto subtype = subtype_->DeepCopy(); | |||
| auto subtype_origin = subtype_origin_->DeepCopy(); | |||
| return std::make_shared<RefType>(subtype, subtype_origin); | |||
| auto subtype = TensorType::DeepCopy()->cast<TensorTypePtr>(); | |||
| return std::make_shared<RefType>(subtype); | |||
| } | |||
| } | |||
| @@ -39,7 +39,7 @@ std::string RefType::DumpText() const { | |||
| buffer << "Ref"; | |||
| } else { | |||
| buffer << "Ref["; | |||
| buffer << subtype_->DumpText() << "]"; | |||
| buffer << TensorType::DumpText() << "]"; | |||
| } | |||
| return buffer.str(); | |||
| } | |||
| @@ -17,21 +17,13 @@ | |||
| #ifndef MINDSPORE_CORE_IR_DTYPE_REF_H_ | |||
| #define MINDSPORE_CORE_IR_DTYPE_REF_H_ | |||
| #include <cstddef> | |||
| #include <iostream> | |||
| #include <initializer_list> | |||
| #include <map> | |||
| #include <memory> | |||
| #include <utility> | |||
| #include <sstream> | |||
| #include <string> | |||
| #include <vector> | |||
| #include <type_traits> | |||
| #include <unordered_map> | |||
| #include <algorithm> | |||
| #include "base/base.h" | |||
| #include "ir/named.h" | |||
| #include "ir/dtype/type.h" | |||
| #include "ir/dtype/tensor_type.h" | |||
| namespace mindspore { | |||
| // TypeRefKey type | |||
| @@ -48,23 +40,16 @@ class RefKeyType : public Object { | |||
| }; | |||
| // TypeRef type | |||
| class RefType : public Object { | |||
| class RefType : public TensorType { | |||
| public: | |||
| RefType() : Object(kObjectTypeRef) {} | |||
| RefType(const TypePtr &subtype, const TypePtr &subtype_origin) | |||
| : Object(kObjectTypeRef, false), subtype_(subtype), subtype_origin_(subtype_origin) {} | |||
| RefType() : TensorType() {} | |||
| explicit RefType(const TensorTypePtr &subtype) : TensorType(subtype->element()) {} | |||
| ~RefType() override {} | |||
| MS_DECLARE_PARENT(RefType, Object) | |||
| MS_DECLARE_PARENT(RefType, TensorType) | |||
| TypePtr subtype() const { return subtype_; } | |||
| TypeId generic_type_id() const override { return kObjectTypeRef; } | |||
| TypePtr DeepCopy() const override; | |||
| std::string ToString() const override; | |||
| std::string DumpText() const override; | |||
| private: | |||
| TypePtr subtype_; | |||
| TypePtr subtype_origin_; | |||
| }; | |||
| using RefTypePtr = std::shared_ptr<RefType>; | |||
| @@ -0,0 +1,194 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ir/dtype/tensor_type.h" | |||
| #include <string> | |||
| #include <cstdlib> | |||
| #include <algorithm> | |||
| #include "utils/log_adapter.h" | |||
| namespace mindspore { | |||
| TypePtr UndeterminedType::DeepCopy() const { | |||
| MS_EXCEPTION_IF_NULL(element_type_); | |||
| if (IsGeneric()) { | |||
| return std::make_shared<UndeterminedType>(); | |||
| } | |||
| return std::make_shared<UndeterminedType>(element_type_->DeepCopy()); | |||
| } | |||
| std::string UndeterminedType::ToReprString() const { | |||
| if (element_type_ == nullptr) { | |||
| return "Undetermined"; | |||
| } | |||
| return "Undetermined[" + element_type_->ToReprString() + "]"; | |||
| } | |||
| std::string UndeterminedType::ToString() const { | |||
| if (element_type_ == nullptr) { | |||
| return "Undetermined"; | |||
| } | |||
| return "Undetermined[" + element_type_->ToString() + "]"; | |||
| } | |||
| std::string UndeterminedType::DumpText() const { | |||
| if (element_type_ == nullptr) { | |||
| return "Undetermined"; | |||
| } | |||
| return "Undetermined[" + element_type_->DumpText() + "]"; | |||
| } | |||
| bool UndeterminedType::operator==(const Type &other) const { | |||
| if (!IsSameObjectType(*this, other)) { | |||
| return false; | |||
| } | |||
| auto other_elem_type = static_cast<const UndeterminedType &>(other).element_type_; | |||
| if (element_type_ == nullptr && other_elem_type == nullptr) { | |||
| return true; | |||
| } else if (element_type_ == nullptr || other_elem_type == nullptr) { | |||
| return false; | |||
| } | |||
| return *element_type_ == *other_elem_type; | |||
| } | |||
| TypePtr TensorType::DeepCopy() const { | |||
| MS_EXCEPTION_IF_NULL(element_type_); | |||
| if (IsGeneric()) { | |||
| return std::make_shared<TensorType>(); | |||
| } | |||
| return std::make_shared<TensorType>(element_type_->DeepCopy()); | |||
| } | |||
| std::string TensorType::ToReprString() const { | |||
| if (element_type_ == nullptr) { | |||
| return "tensor"; | |||
| } | |||
| return "tensor[" + element_type_->ToReprString() + "]"; | |||
| } | |||
| std::string TensorType::ToString() const { | |||
| if (element_type_ == nullptr) { | |||
| return "Tensor"; | |||
| } | |||
| return "Tensor[" + element_type_->ToString() + "]"; | |||
| } | |||
| std::string TensorType::DumpText() const { | |||
| if (element_type_ == nullptr) { | |||
| return "Tensor"; | |||
| } | |||
| return "Tensor(" + element_type_->DumpText() + ")"; | |||
| } | |||
| bool TensorType::operator==(const Type &other) const { | |||
| if (!IsSameObjectType(*this, other)) { | |||
| return false; | |||
| } | |||
| auto other_elem_type = static_cast<const TensorType &>(other).element_type_; | |||
| // When element_type_ = nullptr, which means any type of Array. | |||
| if (element_type_ == nullptr && other_elem_type == nullptr) { | |||
| return true; | |||
| } else if (element_type_ == nullptr || other_elem_type == nullptr) { | |||
| return false; | |||
| } | |||
| return *element_type_ == *other_elem_type; | |||
| } | |||
| TypePtr RowTensorType::DeepCopy() const { | |||
| MS_EXCEPTION_IF_NULL(element_type_); | |||
| if (IsGeneric()) { | |||
| return std::make_shared<RowTensorType>(); | |||
| } | |||
| return std::make_shared<RowTensorType>(element_type_->DeepCopy()); | |||
| } | |||
| std::string RowTensorType::ToReprString() const { | |||
| if (element_type_ == nullptr) { | |||
| return "RowTensor"; | |||
| } | |||
| return "RowTensor[" + element_type_->ToReprString() + "]"; | |||
| } | |||
| std::string RowTensorType::ToString() const { | |||
| if (element_type_ == nullptr) { | |||
| return "RowTensor"; | |||
| } | |||
| return "RowTensor[" + element_type_->ToString() + "]"; | |||
| } | |||
| std::string RowTensorType::DumpText() const { | |||
| if (element_type_ == nullptr) { | |||
| return "RowTensor"; | |||
| } | |||
| return "RowTensor[" + element_type_->DumpText() + "]"; | |||
| } | |||
| bool RowTensorType::operator==(const Type &other) const { | |||
| if (!IsSameObjectType(*this, other)) { | |||
| return false; | |||
| } | |||
| auto other_elem_type = static_cast<const RowTensorType &>(other).element_type_; | |||
| if (element_type_ == nullptr && other_elem_type == nullptr) { | |||
| return true; | |||
| } else if (element_type_ == nullptr || other_elem_type == nullptr) { | |||
| return false; | |||
| } | |||
| return *element_type_ == *other_elem_type; | |||
| } | |||
| TypePtr SparseTensorType::DeepCopy() const { | |||
| MS_EXCEPTION_IF_NULL(element_type_); | |||
| if (IsGeneric()) { | |||
| return std::make_shared<SparseTensorType>(); | |||
| } | |||
| return std::make_shared<SparseTensorType>(element_type_->DeepCopy()); | |||
| } | |||
| std::string SparseTensorType::ToReprString() const { | |||
| if (element_type_ == nullptr) { | |||
| return "SparseTensor"; | |||
| } | |||
| return "SparseTensor[" + element_type_->ToReprString() + "]"; | |||
| } | |||
| std::string SparseTensorType::ToString() const { | |||
| if (element_type_ == nullptr) { | |||
| return "SparseTensor"; | |||
| } | |||
| return "SparseTensor[" + element_type_->ToString() + "]"; | |||
| } | |||
| std::string SparseTensorType::DumpText() const { | |||
| if (element_type_ == nullptr) { | |||
| return "SparseTensor"; | |||
| } | |||
| return "SparseTensor[" + element_type_->DumpText() + "]"; | |||
| } | |||
| bool SparseTensorType::operator==(const Type &other) const { | |||
| if (!IsSameObjectType(*this, other)) { | |||
| return false; | |||
| } | |||
| auto other_elem_type = static_cast<const SparseTensorType &>(other).element_type_; | |||
| if (element_type_ == nullptr && other_elem_type == nullptr) { | |||
| return true; | |||
| } else if (element_type_ == nullptr || other_elem_type == nullptr) { | |||
| return false; | |||
| } | |||
| return *element_type_ == *other_elem_type; | |||
| } | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,132 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CORE_IR_DTYPE_TENSORTYPE_H_ | |||
| #define MINDSPORE_CORE_IR_DTYPE_TENSORTYPE_H_ | |||
| #include <cstddef> | |||
| #include <iostream> | |||
| #include <initializer_list> | |||
| #include <map> | |||
| #include <memory> | |||
| #include <utility> | |||
| #include <sstream> | |||
| #include <string> | |||
| #include <vector> | |||
| #include <type_traits> | |||
| #include <unordered_map> | |||
| #include <algorithm> | |||
| #include "base/base.h" | |||
| #include "ir/named.h" | |||
| #include "ir/dtype/type.h" | |||
| namespace mindspore { | |||
| class UndeterminedType : public Object { | |||
| public: | |||
| UndeterminedType() : Object(kObjectTypeUndeterminedType) {} | |||
| explicit UndeterminedType(const TypePtr &ele) | |||
| : Object(kObjectTypeUndeterminedType, kMetaTypeObject, false), element_type_(ele) {} | |||
| ~UndeterminedType() override = default; | |||
| MS_DECLARE_PARENT(UndeterminedType, Object) | |||
| TypeId generic_type_id() const override { return kObjectTypeUndeterminedType; } | |||
| const TypePtr element() const { return element_type_; } | |||
| void set_element(const TypePtr &element_type) { element_type_ = element_type; } | |||
| TypePtr DeepCopy() const override; | |||
| std::string ToString() const override; | |||
| std::string ToReprString() const override; | |||
| std::string DumpText() const override; | |||
| bool operator==(const Type &other) const override; | |||
| protected: | |||
| TypePtr element_type_; | |||
| }; | |||
| using MetaTensorTypePtr = std::shared_ptr<UndeterminedType>; | |||
| class TensorType : public Object { | |||
| public: | |||
| TensorType() : Object(kObjectTypeTensorType, kObjectTypeUndeterminedType) {} | |||
| explicit TensorType(const TypePtr &ele) | |||
| : Object(kObjectTypeTensorType, kObjectTypeUndeterminedType, false), element_type_(ele) {} | |||
| ~TensorType() override = default; | |||
| MS_DECLARE_PARENT(TensorType, Object) | |||
| TypeId generic_type_id() const override { return kObjectTypeTensorType; } | |||
| const TypePtr element() const { return element_type_; } | |||
| void set_element(const TypePtr &element_type) { element_type_ = element_type; } | |||
| TypePtr DeepCopy() const override; | |||
| std::string ToString() const override; | |||
| std::string ToReprString() const override; | |||
| std::string DumpText() const override; | |||
| bool operator==(const Type &other) const override; | |||
| private: | |||
| TypePtr element_type_; | |||
| }; | |||
| using TensorTypePtr = std::shared_ptr<TensorType>; | |||
| class RowTensorType : public Object { | |||
| public: | |||
| RowTensorType() : Object(kObjectTypeRowTensorType, kObjectTypeUndeterminedType) {} | |||
| explicit RowTensorType(const TypePtr &ele) | |||
| : Object(kObjectTypeRowTensorType, kObjectTypeUndeterminedType, false), element_type_(ele) {} | |||
| ~RowTensorType() override = default; | |||
| MS_DECLARE_PARENT(RowTensorType, Object) | |||
| TypeId generic_type_id() const override { return kObjectTypeRowTensorType; } | |||
| const TypePtr element() const { return element_type_; } | |||
| void set_element(const TypePtr &element_type) { element_type_ = element_type; } | |||
| TypePtr DeepCopy() const override; | |||
| std::string ToString() const override; | |||
| std::string ToReprString() const override; | |||
| std::string DumpText() const override; | |||
| bool operator==(const Type &other) const override; | |||
| private: | |||
| TypePtr element_type_; | |||
| }; | |||
| using RowTensorTypePtr = std::shared_ptr<RowTensorType>; | |||
| class SparseTensorType : public Object { | |||
| public: | |||
| SparseTensorType() : Object(kObjectTypeSparseTensorType, kObjectTypeUndeterminedType) {} | |||
| explicit SparseTensorType(const TypePtr &ele) | |||
| : Object(kObjectTypeSparseTensorType, kObjectTypeUndeterminedType, false), element_type_(ele) {} | |||
| ~SparseTensorType() override = default; | |||
| MS_DECLARE_PARENT(SparseTensorType, Object) | |||
| TypeId generic_type_id() const override { return kObjectTypeSparseTensorType; } | |||
| const TypePtr element() const { return element_type_; } | |||
| void set_element(const TypePtr &element_type) { element_type_ = element_type; } | |||
| TypePtr DeepCopy() const override; | |||
| std::string ToString() const override; | |||
| std::string ToReprString() const override; | |||
| std::string DumpText() const override; | |||
| bool operator==(const Type &other) const override; | |||
| private: | |||
| TypePtr element_type_; | |||
| }; | |||
| using SparseTensorTypePtr = std::shared_ptr<SparseTensorType>; | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CORE_IR_DTYPE_TENSORTYPE_H_ | |||
| @@ -332,14 +332,11 @@ class FuncGraph : public FuncGraphBase { | |||
| const std::vector<AnfNodePtr> ¶mter_obj_nodes() const { return paramter_obj_nodes_; } | |||
| void add_parameter_obj_node(const AnfNodePtr &p); | |||
| std::unordered_map<AnfNodePtr, AnfNodePtr> &make_ref_params() { return make_ref_params_; } | |||
| std::unordered_map<std::string, ValuePtr> attrs_; | |||
| std::vector<BaseShapePtr> joined_shapes_; | |||
| std::unordered_map<std::string, FuncGraphTransform> transforms_; | |||
| // parameter default value | |||
| std::map<std::string, AnfNodePtr> parameter_default_value_; | |||
| std::unordered_map<AnfNodePtr, AnfNodePtr> make_ref_params_; | |||
| size_t seen_; | |||
| std::list<CNodePtr> GetOrderedCnodes(); | |||
| @@ -23,6 +23,7 @@ | |||
| #include <string> | |||
| #include "base/base.h" | |||
| #include "ir/param_info.h" | |||
| #include "ir/dtype.h" | |||
| #include "utils/convert_utils_base.h" | |||
| #include "utils/hashing.h" | |||
| @@ -163,6 +164,15 @@ class MetaTensor : public Value { | |||
| return false; | |||
| } | |||
| } | |||
| // Get tensor's param_info info. | |||
| ParamInfoPtr param_info() const { return param_info_; } | |||
| bool is_parameter() const { return is_parameter_; } | |||
| // Set tensor's param_info info. | |||
| void set_param_info(const ParamInfoPtr ¶m_info) { | |||
| is_parameter_ = true; | |||
| param_info_ = param_info; | |||
| } | |||
| protected: | |||
| // brief Data type of the tensor. | |||
| @@ -184,6 +194,9 @@ class MetaTensor : public Value { | |||
| // | |||
| // Includes the format and data type of a tensor on device. | |||
| DeviceInfo device_info_; | |||
| bool is_parameter_{false}; | |||
| ParamInfoPtr param_info_{nullptr}; | |||
| }; | |||
| using MetaTensorPtr = std::shared_ptr<MetaTensor>; | |||
| @@ -34,7 +34,16 @@ abstract::AbstractBasePtr MetaTensor::ToAbstract() { | |||
| } | |||
| auto tensor_shape = tens->shape(); | |||
| auto abs_tensor = std::make_shared<abstract::AbstractTensor>(dtype, tensor_shape); | |||
| abs_tensor->set_value(shared_from_base<MetaTensor>()); | |||
| // if is parameter always no value. | |||
| if (is_parameter()) { | |||
| auto param_name = param_info()->name(); | |||
| auto ref_key = std::make_shared<RefKey>(param_name); | |||
| auto abs_ref_key = ref_key->ToAbstract(); | |||
| abs_tensor = std::make_shared<abstract::AbstractRef>(abs_ref_key, abs_tensor); | |||
| } else { | |||
| abs_tensor->set_value(shared_from_base<MetaTensor>()); | |||
| } | |||
| return abs_tensor; | |||
| } | |||
| @@ -62,6 +62,21 @@ class Named : public Value { | |||
| }; | |||
| using NamedPtr = std::shared_ptr<Named>; | |||
| struct NamedHasher { | |||
| std::size_t operator()(NamedPtr const &name) const { | |||
| std::size_t hash = name->Hash(); | |||
| return hash; | |||
| } | |||
| }; | |||
| struct NamedEqual { | |||
| bool operator()(NamedPtr const &t1, NamedPtr const &t2) const { | |||
| MS_EXCEPTION_IF_NULL(t1); | |||
| MS_EXCEPTION_IF_NULL(t2); | |||
| return *t1 == *t2; | |||
| } | |||
| }; | |||
| class None : public Named { | |||
| public: | |||
| None() : Named("None") {} | |||
| @@ -21,10 +21,13 @@ | |||
| #include <memory> | |||
| #include <string> | |||
| #include <vector> | |||
| #include "ir/anf.h" | |||
| #include "ir/tensor.h" | |||
| #include "ir/dtype.h" | |||
| namespace mindspore { | |||
| class ParamInfo; | |||
| using ParamInfoPtr = std::shared_ptr<ParamInfo>; | |||
| class ParamInfo { | |||
| public: | |||
| ParamInfo() {} | |||
| @@ -55,7 +58,7 @@ class ParamInfo { | |||
| int32_t cloned_index() const { return cloned_index_; } | |||
| // Make a cloned parameter and update clone info. | |||
| ParamValuePtr Clone() { | |||
| ParamInfoPtr Clone() { | |||
| static std::atomic<int32_t> parameter_cloned_index{1}; | |||
| int32_t index = parameter_cloned_index.fetch_add(1, std::memory_order_relaxed); | |||
| auto clone = std::make_shared<ParamInfo>(*this); | |||
| @@ -467,6 +467,7 @@ Tensor &Tensor::AssignValue(const Tensor &tensor) { | |||
| } | |||
| return *this; | |||
| } | |||
| abstract::AbstractBasePtr Tensor::ToAbstract() { | |||
| auto tens = shared_from_base<Tensor>(); | |||
| auto dtype = tens->Dtype(); | |||
| @@ -475,7 +476,15 @@ abstract::AbstractBasePtr Tensor::ToAbstract() { | |||
| } | |||
| auto tensor_shape = tens->shape(); | |||
| auto abs_tensor = std::make_shared<abstract::AbstractTensor>(dtype, tensor_shape); | |||
| abs_tensor->set_value(shared_from_base<Tensor>()); | |||
| // if is parameter always no value. | |||
| if (is_parameter()) { | |||
| auto param_name = param_info()->name(); | |||
| auto ref_key = std::make_shared<RefKey>(param_name); | |||
| auto abs_ref_key = ref_key->ToAbstract(); | |||
| abs_tensor = std::make_shared<abstract::AbstractRef>(abs_ref_key, abs_tensor); | |||
| } else { | |||
| abs_tensor->set_value(shared_from_base<Tensor>()); | |||
| } | |||
| return abs_tensor; | |||
| } | |||
| @@ -200,16 +200,6 @@ bool StringImm::operator==(const Value &other) const { | |||
| } | |||
| bool StringImm::operator==(const StringImm &other) const { return str_ == other.str_; } | |||
| bool RefKey::operator==(const Value &other) const { | |||
| if (other.isa<RefKey>()) { | |||
| auto other_ = static_cast<const RefKey &>(other); | |||
| return *this == other_; | |||
| } else { | |||
| return false; | |||
| } | |||
| } | |||
| bool RefKey::operator==(const RefKey &other) const { return tag_ == other.tag_; } | |||
| bool AnyValue::operator==(const Value &other) const { | |||
| if (other.isa<AnyValue>()) { | |||
| return true; | |||
| @@ -224,28 +224,21 @@ using StringImmPtr = std::shared_ptr<StringImm>; | |||
| IMM_TRAITS(StringImmPtr, std::string) | |||
| IMM_TRAITS(StringImmPtr, const char *) | |||
| class RefKey : public Value { | |||
| class RefKey : public Named { | |||
| public: | |||
| explicit RefKey(const std::string &tag) : Value(kRefKeyType), tag_(tag), hash_(std::hash<std::string>{}(tag)) {} | |||
| explicit RefKey(const std::string &tag) : Named(tag) {} | |||
| ~RefKey() override = default; | |||
| MS_DECLARE_PARENT(RefKey, Value) | |||
| std::size_t hash() const override { return hash_; } | |||
| const std::string &tag() const { return tag_; } | |||
| bool operator==(const Value &other) const override; | |||
| bool operator==(const RefKey &other) const; | |||
| MS_DECLARE_PARENT(RefKey, Named) | |||
| const std::string &tag() const { return name(); } | |||
| abstract::AbstractBasePtr ToAbstract() override; | |||
| std::string ToString() const override { return "RefKey[" + tag_ + "]"; } | |||
| std::string ToString() const override { return "RefKey[" + name() + "]"; } | |||
| std::string DumpText() const override { | |||
| std::ostringstream oss; | |||
| oss << "RefKey[\"" << tag_ << "\"]"; | |||
| oss << "RefKey[\"" << name() << "\"]"; | |||
| return oss.str(); | |||
| } | |||
| private: | |||
| std::string tag_; | |||
| std::size_t hash_ = 0; | |||
| }; | |||
| using RefKeyPtr = std::shared_ptr<RefKey>; | |||
| @@ -43,6 +43,8 @@ if(BUILD_CONVERTER) | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/scope.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/value.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/value_extends.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/dtype/ref.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/dtype/tensor_type.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/dtype/container.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/dtype/empty.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/dtype/number.cc | |||
| @@ -29,6 +29,8 @@ set(ANF_SRC | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/scope.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/value.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/value_extends.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/dtype/ref.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/dtype/tensor_type.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/dtype/container.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/dtype/empty.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/dtype/number.cc | |||
| @@ -23,7 +23,7 @@ from ...common import dtype as mstype | |||
| from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register | |||
| class Assign(PrimitiveWithInfer): | |||
| class Assign(Primitive): | |||
| """ | |||
| Assign `Parameter` with a value. | |||
| @@ -18,7 +18,6 @@ | |||
| import inspect | |||
| import copy | |||
| from mindspore.common.api import _wrap_func | |||
| from mindspore.common import Parameter | |||
| from mindspore.common._register_for_tensor import tensor_operator_registry | |||
| from mindspore import context | |||
| from .._c_expression import Primitive_, real_run_op, prim_type | |||
| @@ -410,16 +409,12 @@ def _run_op(obj, op_name, args): | |||
| if op_name == "Cast" or obj.update_parameter: | |||
| cast_args = args | |||
| else: | |||
| cast_args = list() | |||
| for arg in args: | |||
| if isinstance(arg, Parameter): | |||
| if arg.cast_type: | |||
| cast_args.append(cast(arg, arg.cast_type)) | |||
| else: | |||
| cast_args.append(arg) | |||
| else: | |||
| cast_args.append(arg) | |||
| output = real_run_op(obj, op_name, tuple(cast_args)) | |||
| cast_args = args | |||
| for idx, arg in enumerate(args): | |||
| cast_type = getattr(arg, "cast_type", None) | |||
| if cast_type: | |||
| cast_args[idx] = cast(arg, cast_type) | |||
| output = real_run_op(obj, op_name, cast_args) | |||
| if not output: | |||
| raise RuntimeError("Pynative run op %s failed!" % op_name) | |||
| if len(output) == 1: | |||
| @@ -118,26 +118,31 @@ class ControlMixedWhileIf(nn.Cell): | |||
| self.var = Parameter(initializer(1, (1), mstype.float32), name="var") | |||
| def construct(self, x, y, z, c2, c4): | |||
| out = self.assign(self.var, c4) | |||
| out = c4 | |||
| self.assign(self.var, c4) | |||
| while x < c2: | |||
| y = self.assign(self.var, c4) | |||
| y = c4 | |||
| self.assign(self.var, c4) | |||
| while y < c2 and x < c2: | |||
| if 2 * y < c2: | |||
| y = y + 2 | |||
| else: | |||
| y = y + 1 | |||
| out = out + y | |||
| z = self.assign(self.var, c4) | |||
| z = c4 | |||
| self.assign(self.var, c4) | |||
| while z < c2: | |||
| z = z + 1 | |||
| out = out + z | |||
| x = x + 1 | |||
| out = out + x | |||
| while x < 2 * c2: | |||
| y = self.assign(self.var, c4) | |||
| y = c4 | |||
| self.assign(self.var, c4) | |||
| x = x + 1 | |||
| while y < c2: | |||
| z = self.assign(self.var, c4) | |||
| z = c4 | |||
| self.assign(self.var, c4) | |||
| while z < c2: | |||
| z = z + 1 | |||
| if x < c2: | |||
| @@ -27,6 +27,7 @@ import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore import context | |||
| from mindspore.ops import composite as C | |||
| from mindspore.ops import operations as P | |||
| from mindspore.common.api import ms_function, _executor | |||
| from mindspore.ops._grad.grad_base import bprop_getters | |||
| from mindspore.ops.primitive import prim_attr_register, PrimitiveWithInfer | |||
| @@ -254,3 +255,60 @@ def test_bprop_with_wrong_output_shape(): | |||
| net = BpropWithWrongOutputShapeCell() | |||
| net.set_grad() | |||
| grad_all(net)(Tensor(np.ones([64, 10]).astype(np.int32))) | |||
| class AssignWhenInsertGrad(nn.Cell): | |||
| """ NetWithNDarray definition """ | |||
| def __init__(self): | |||
| super(AssignWhenInsertGrad, self).__init__() | |||
| self.gather = P.GatherV2() | |||
| self.damping = Tensor(np.array([0.03, 0.03]).astype(np.float32)) | |||
| self.cov_step = ms.Parameter(0, name="cov_step", requires_grad=False) | |||
| self.freq = Tensor(278, ms.int32) | |||
| self.getG = P.InsertGradientOf(self.save_gradient) | |||
| def save_gradient(self, dout): | |||
| self.cov_step = self.cov_step + self.freq | |||
| return dout | |||
| def construct(self, x): | |||
| self.gather(self.damping, self.cov_step, 0) | |||
| out = P.ReLU()(x) | |||
| out = self.getG(out) | |||
| return out | |||
| grad_all = C.GradOperation('get_all', get_all=True) | |||
| class GradNet(nn.Cell): | |||
| def __init__(self, net): | |||
| super(GradNet, self).__init__() | |||
| self.net = net | |||
| def construct(self, *inputs): | |||
| out = self.net(*inputs) | |||
| return out, grad_all(self.net)(*inputs) | |||
| def test_assign_in_insert_grad(): | |||
| context.set_context(mode=context.GRAPH_MODE) | |||
| net = AssignWhenInsertGrad().to_float(ms.float16) | |||
| input_data = np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32') | |||
| net_back = GradNet(net) | |||
| net_back(ms.Tensor(input_data)) | |||
| class Assign(nn.Cell): | |||
| """ NetWithNDarray definition """ | |||
| def __init__(self): | |||
| super(Assign, self).__init__() | |||
| self.cov_step = ms.Parameter(0.0, name="cov_step", requires_grad=False) | |||
| def construct(self, x): | |||
| self.cov_step = self.cov_step + x | |||
| return self.cov_step | |||
| def test_assign(): | |||
| context.set_context(mode=context.GRAPH_MODE) | |||
| net = Assign() | |||
| input_data = ms.Tensor(np.array(1).astype(np.int32)) | |||
| net_back = GradNet(net) | |||
| net_back(input_data) | |||
| @@ -0,0 +1,144 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ test_cont_break """ | |||
| import numpy as np | |||
| import mindspore as ms | |||
| from mindspore import Tensor, context, nn, ms_function | |||
| from mindspore.nn import Cell | |||
| from mindspore.ops import operations as P | |||
| class WhileSubGraphParam(Cell): | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.update = ms.Parameter(Tensor(1, ms.float32), "update") | |||
| def construct(self, x, y, z): | |||
| out1 = z | |||
| while x < y: | |||
| self.update = self.update + 1 | |||
| out1 = out1 + 1 | |||
| x = x + 1 | |||
| return out1, self.update | |||
| def test_while_loop_phi(): | |||
| context.set_context(mode=context.GRAPH_MODE, save_graphs=True) | |||
| x = Tensor(0, ms.float32) | |||
| y = Tensor(10, ms.float32) | |||
| z = Tensor(100, ms.float32) | |||
| net = WhileSubGraphParam() | |||
| net(x, y, z) | |||
| class WhileSubGraphParam2(Cell): | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.update = ms.Parameter(Tensor(1, ms.float32), "update") | |||
| def construct(self, x, y, z): | |||
| out1 = z | |||
| i = self.update | |||
| while x < y: | |||
| i = i + 1 | |||
| out1 = out1 + 1 | |||
| x = x + 1 | |||
| return out1, self.update | |||
| def test_while_loop_phi_2(): | |||
| context.set_context(mode=context.GRAPH_MODE, save_graphs=True) | |||
| x = Tensor(0, ms.float32) | |||
| y = Tensor(10, ms.float32) | |||
| z = Tensor(100, ms.float32) | |||
| net = WhileSubGraphParam2() | |||
| net(x, y, z) | |||
| class WhileSubGraphParam3(Cell): | |||
| def __init__(self, initial_input_x): | |||
| super().__init__() | |||
| self.initial_input_x = initial_input_x | |||
| self.X = ms.Parameter(initial_input_x, name="parameter_x") | |||
| self.Y = ms.Parameter(self.initial_input_x, name="parameter_y") | |||
| def construct(self): | |||
| a = 0 | |||
| while a < 3: | |||
| self.X = self.X + self.Y | |||
| a += 1 | |||
| return self.X | |||
| def test_while_loop_phi_3(): | |||
| context.set_context(mode=context.GRAPH_MODE, save_graphs=True) | |||
| x = Tensor(0, ms.float32) | |||
| net = WhileSubGraphParam3(x) | |||
| net() | |||
| class ControlMixedWhileIf(nn.Cell): | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.assign = P.Assign() | |||
| self.var = ms.Parameter(ms.Tensor([1], ms.float32), name="var") | |||
| @ms_function | |||
| def construct(self, x, y, z, c2, c4): | |||
| out = self.assign(self.var, c4) | |||
| while x < c2: | |||
| y = self.assign(self.var, c4) | |||
| while y < c2 and x < c2: | |||
| if 2 * y < c2: | |||
| y = y + 2 | |||
| else: | |||
| y = y + 1 | |||
| out = out + y | |||
| z = self.assign(self.var, c4) | |||
| while z < c2: | |||
| z = z + 1 | |||
| out = out + z | |||
| x = x + 1 | |||
| out = out + x | |||
| while x < 2 * c2: | |||
| y = self.assign(self.var, c4) | |||
| x = x + 1 | |||
| while y < c2: | |||
| z = self.assign(self.var, c4) | |||
| while z < c2: | |||
| z = z + 1 | |||
| if x < c2: | |||
| y = y - 1 | |||
| else: | |||
| y = y + 1 | |||
| out = out + z | |||
| out = out + y | |||
| out = out + x | |||
| return out | |||
| def test_mixed_while_if(): | |||
| context.set_context(mode=context.PYNATIVE_MODE) | |||
| x = np.array(2).astype(np.int32) | |||
| y = np.array(14).astype(np.int32) | |||
| z = np.array(1).astype(np.int32) | |||
| c2 = Tensor([14], ms.int32) | |||
| c4 = Tensor([0], ms.int32) | |||
| net = ControlMixedWhileIf() | |||
| output = net(Tensor(x), Tensor(y), Tensor(z), c2, c4) | |||
| expect = np.array(3318).astype(np.int32) | |||
| assert np.allclose(expect, output.asnumpy(), 0.0001, 0.0001) | |||
| context.set_context(mode=context.GRAPH_MODE) | |||
| @@ -22,7 +22,13 @@ from mindspore.ops.vm_impl_registry import vm_impl_registry as vm_impl_getters | |||
| from .vm_interface import vm | |||
| # pylint: disable=unused-argument | |||
| @vm_impl_getters.register(P.Assign) | |||
| def vm_impl_assign(self): | |||
| """Generate vm_impl function for Assign""" | |||
| def vm_impl(x, value): | |||
| x.assign_value(value) | |||
| return x | |||
| return vm_impl | |||
| @vm_impl_getters.register(P.ExpandDims) | |||
| def vm_impl_expand_dims(self): | |||