Merge pull request !3271 from vlne-v1/ref_demotags/v1.0.0
| @@ -185,14 +185,23 @@ class Validator: | |||||
| raise TypeError(f"{msg_prefix} `{arg_name}` must be float.") | raise TypeError(f"{msg_prefix} `{arg_name}` must be float.") | ||||
| @staticmethod | @staticmethod | ||||
| def check_subclass(arg_name, type_, template_type, prim_name): | |||||
| def check_subclass(arg_name, type_, template_types, prim_name): | |||||
| """Checks whether some type is subclass of another type""" | """Checks whether some type is subclass of another type""" | ||||
| if not isinstance(template_type, Iterable): | |||||
| template_type = (template_type,) | |||||
| if not any([mstype.issubclass_(type_, x) for x in template_type]): | |||||
| if not isinstance(template_types, Iterable): | |||||
| template_types = (template_types,) | |||||
| hit = False | |||||
| for template_type in template_types: | |||||
| if isinstance(template_type, mstype.Type): | |||||
| if mstype.issubclass_(type_, template_type): | |||||
| hit = True | |||||
| break | |||||
| elif type_ is template_type: | |||||
| hit = True | |||||
| break | |||||
| if not hit: | |||||
| type_str = (type(type_).__name__ if isinstance(type_, (tuple, list)) else "") + str(type_) | type_str = (type(type_).__name__ if isinstance(type_, (tuple, list)) else "") + str(type_) | ||||
| raise TypeError(f'For \'{prim_name}\' the type of `{arg_name}` should be subclass' | raise TypeError(f'For \'{prim_name}\' the type of `{arg_name}` should be subclass' | ||||
| f' of {",".join((str(x) for x in template_type))}, but got {type_str}.') | |||||
| f' of {",".join((str(x) for x in template_types))}, but got {type_str}.') | |||||
| @staticmethod | @staticmethod | ||||
| def check_const_input(arg_name, arg_value, prim_name): | def check_const_input(arg_name, arg_value, prim_name): | ||||
| @@ -206,13 +215,7 @@ class Validator: | |||||
| def _check_tensor_type(arg): | def _check_tensor_type(arg): | ||||
| arg_key, arg_val = arg | arg_key, arg_val = arg | ||||
| elem_type = arg_val | elem_type = arg_val | ||||
| if not elem_type in valid_values: | |||||
| type_names = [] | |||||
| for t in valid_values: | |||||
| type_names.append(str(t)) | |||||
| types_info = '[' + ', '.join(type_names) + ']' | |||||
| raise TypeError(f'For \'{prim_name}\' type of `{arg_key}` should be in {types_info},' | |||||
| f' but got {elem_type}.') | |||||
| Validator.check_subclass(arg_key, elem_type, valid_values, prim_name) | |||||
| return (arg_key, elem_type) | return (arg_key, elem_type) | ||||
| def _check_types_same(arg1, arg2): | def _check_types_same(arg1, arg2): | ||||
| @@ -335,12 +338,6 @@ class Validator: | |||||
| class ParamValidator: | class ParamValidator: | ||||
| """Parameter validator. NOTICE: this class will be replaced by `class Validator`""" | """Parameter validator. NOTICE: this class will be replaced by `class Validator`""" | ||||
| @staticmethod | |||||
| def equal(arg_name, arg_value, cond_str, cond): | |||||
| """Judging valid value.""" | |||||
| if not cond: | |||||
| raise ValueError(f'The `{arg_name}` must be {cond_str}, but got {arg_value}.') | |||||
| @staticmethod | @staticmethod | ||||
| def check(arg_name, arg_value, value_name, value, rel=Rel.EQ): | def check(arg_name, arg_value, value_name, value, rel=Rel.EQ): | ||||
| """This method is only used for check int values, since when compare float values, | """This method is only used for check int values, since when compare float values, | ||||
| @@ -360,27 +357,6 @@ class ParamValidator: | |||||
| raise ValueError(f'The `{arg_name}` should be an int and must {rel_str}, but got {arg_value}.') | raise ValueError(f'The `{arg_name}` should be an int and must {rel_str}, but got {arg_value}.') | ||||
| return arg_value | return arg_value | ||||
| @staticmethod | |||||
| def check_shape_length(arg_name, arg_value, value, rel): | |||||
| """Shape length judgment.""" | |||||
| rel_fn = Rel.get_fns(rel) | |||||
| type_mismatch = not isinstance(arg_value, int) | |||||
| if type_mismatch or not rel_fn(arg_value, value): | |||||
| rel_str = Rel.get_strs(rel).format(value) | |||||
| raise ValueError(f'The length of `{arg_name}` should be an int and must {rel_str}, but got {arg_value}') | |||||
| return arg_value | |||||
| @staticmethod | |||||
| def check_int_range(arg_name, arg_value, lower_limit, upper_limit, rel): | |||||
| """This method is only used for check int values, | |||||
| since when compare float values, we need consider float error.""" | |||||
| rel_fn = Rel.get_fns(rel) | |||||
| type_mismatch = not isinstance(arg_value, int) | |||||
| if type_mismatch or not rel_fn(arg_value, lower_limit, upper_limit): | |||||
| rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit) | |||||
| raise ValueError(f'The `{arg_name}` should be an int in range {rel_str}, but got {arg_value}.') | |||||
| return arg_value | |||||
| @staticmethod | @staticmethod | ||||
| def check_isinstance(arg_name, arg_value, classes): | def check_isinstance(arg_name, arg_value, classes): | ||||
| """Check arg isinstance of classes""" | """Check arg isinstance of classes""" | ||||
| @@ -388,33 +364,6 @@ class ParamValidator: | |||||
| raise ValueError(f'The `{arg_name}` should be isinstance of {classes}, but got {arg_value}.') | raise ValueError(f'The `{arg_name}` should be isinstance of {classes}, but got {arg_value}.') | ||||
| return arg_value | return arg_value | ||||
| @staticmethod | |||||
| def check_number_range(arg_name, arg_value, lower_limit, upper_limit, rel): | |||||
| """Is it necessary to consider error when comparing float values.""" | |||||
| rel_fn = Rel.get_fns(rel) | |||||
| if not rel_fn(arg_value, lower_limit, upper_limit): | |||||
| rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit) | |||||
| raise ValueError(f'The `{arg_name}` should be in range {rel_str}, but got {arg_value}.') | |||||
| return arg_value | |||||
| @staticmethod | |||||
| def check_subclass(arg_name, type_, template_type, with_type_of=True): | |||||
| """Check whether some type is subclass of another type""" | |||||
| if not isinstance(template_type, Iterable): | |||||
| template_type = (template_type,) | |||||
| if not any([mstype.issubclass_(type_, x) for x in template_type]): | |||||
| type_str = (type(type_).__name__ if isinstance(type_, (tuple, list)) else "") + str(type_) | |||||
| raise TypeError(f'The {"type of" if with_type_of else ""} `{arg_name}` should be subclass' | |||||
| f' of {",".join((str(x) for x in template_type))}, but got {type_str}.') | |||||
| @staticmethod | |||||
| def check_args_tensor(args): | |||||
| """Check whether args are all tensor.""" | |||||
| if not isinstance(args, dict): | |||||
| raise TypeError("The args should be a dict.") | |||||
| for arg, value in args.items(): | |||||
| ParamValidator.check_subclass(arg, value, mstype.tensor) | |||||
| @staticmethod | @staticmethod | ||||
| def check_bool(arg_name, arg_value): | def check_bool(arg_name, arg_value): | ||||
| """Check arg isinstance of bool""" | """Check arg isinstance of bool""" | ||||
| @@ -442,113 +391,6 @@ class ParamValidator: | |||||
| return arg_value | return arg_value | ||||
| raise_error_msg() | raise_error_msg() | ||||
| @staticmethod | |||||
| def check_typename(arg_name, arg_type, valid_types): | |||||
| """Does it contain the _name_ attribute.""" | |||||
| def get_typename(t): | |||||
| return t.__name__ if hasattr(t, '__name__') else str(t) | |||||
| if isinstance(arg_type, type(mstype.tensor)): | |||||
| arg_type = arg_type.element_type() | |||||
| if arg_type in valid_types: | |||||
| return arg_type | |||||
| type_names = [get_typename(t) for t in valid_types] | |||||
| if len(valid_types) == 1: | |||||
| raise ValueError(f'The type of `{arg_name}` should be {type_names[0]},' | |||||
| f' but got {get_typename(arg_type)}.') | |||||
| raise ValueError(f'The type of `{arg_name}` should be one of {type_names},' | |||||
| f' but got {get_typename(arg_type)}.') | |||||
| @staticmethod | |||||
| def check_string(arg_name, arg_value, valid_values): | |||||
| """String type judgment.""" | |||||
| if isinstance(arg_value, str) and arg_value in valid_values: | |||||
| return arg_value | |||||
| if len(valid_values) == 1: | |||||
| raise ValueError(f'The `{arg_name}` should be str and must be {valid_values[0]},' | |||||
| f' but got {arg_value}.') | |||||
| raise ValueError(f'The `{arg_name}` should be str and must be one of {valid_values},' | |||||
| f' but got {arg_value}.') | |||||
| @staticmethod | |||||
| def check_type_same(args, valid_values): | |||||
| """Determine whether the types are the same.""" | |||||
| name = list(args.keys())[0] | |||||
| value = list(args.values())[0] | |||||
| if isinstance(value, type(mstype.tensor)): | |||||
| value = value.element_type() | |||||
| for arg_name, arg_value in args.items(): | |||||
| if isinstance(arg_value, type(mstype.tensor)): | |||||
| arg_value = arg_value.element_type() | |||||
| if arg_value not in valid_values: | |||||
| raise TypeError(f'The `{arg_name}` should be in {valid_values},' | |||||
| f' but `{arg_name}` is {arg_value}.') | |||||
| if arg_value != value: | |||||
| raise TypeError(f'`{arg_name}` should be same as `{name}`,' | |||||
| f' but `{arg_name}` is {arg_value}, `{name}` is {value}.') | |||||
| @staticmethod | |||||
| def check_two_types_same(arg1_name, arg1_type, arg2_name, arg2_type): | |||||
| """Determine whether the types of two variables are the same.""" | |||||
| if arg1_type != arg2_type: | |||||
| raise TypeError(f'The type of `{arg1_name}` and `{arg2_name}` should be same.') | |||||
| @staticmethod | |||||
| def check_value_on_integer(arg_name, arg_value, value, rel): | |||||
| """Judging integer type.""" | |||||
| rel_fn = Rel.get_fns(rel) | |||||
| type_match = isinstance(arg_value, int) | |||||
| if type_match and (not rel_fn(arg_value, value)): | |||||
| rel_str = Rel.get_strs(rel).format(value) | |||||
| raise ValueError(f'The `{arg_name}` should be an int and must {rel_str}, but got {arg_value}.') | |||||
| return arg_value | |||||
| @staticmethod | |||||
| def check_param_equal(param1_name, param1_value, param2_name, param2_value): | |||||
| """Judging the equality of parameters.""" | |||||
| if param1_value != param2_value: | |||||
| raise ValueError(f"`{param1_name}` must equal `{param2_name}`," | |||||
| f" but got `{param1_name}` = {param1_value}," | |||||
| f" `{param2_name}` = {param2_value}.") | |||||
| @staticmethod | |||||
| def check_const_input(arg_name, arg_value): | |||||
| """Check valid value.""" | |||||
| if arg_value is None: | |||||
| raise ValueError(f'The `{arg_name}` must be a const input, but got {arg_value}.') | |||||
| @staticmethod | |||||
| def check_float_positive(arg_name, arg_value): | |||||
| """Float type judgment.""" | |||||
| if isinstance(arg_value, float): | |||||
| if arg_value > 0: | |||||
| return arg_value | |||||
| raise ValueError(f"The `{arg_name}` must be positive, but got {arg_value}.") | |||||
| raise TypeError(f"`{arg_name}` must be float!") | |||||
| @staticmethod | |||||
| def check_pad_value_by_mode(op_name, pad_mode, padding): | |||||
| """Validate value of padding according to pad_mode""" | |||||
| if pad_mode != 'pad' and padding != 0: | |||||
| raise ValueError(f"For op '{op_name}', padding must be zero when pad_mode is '{pad_mode}'.") | |||||
| return padding | |||||
| @staticmethod | |||||
| def check_empty_shape_input(arg_name, arg_value): | |||||
| """Check zeros value.""" | |||||
| if 0 in arg_value: | |||||
| raise ValueError(f"Input `{arg_name}` cannot be empty.") | |||||
| @staticmethod | |||||
| def check_scalar_shape_input(arg_name, arg_value): | |||||
| """Check scalar shape input.""" | |||||
| if arg_value != []: | |||||
| raise ValueError(f"Input `{arg_name}` shape should be (). got {arg_value}") | |||||
| def check_int(input_param): | def check_int(input_param): | ||||
| """Int type judgment.""" | """Int type judgment.""" | ||||
| @@ -592,11 +592,6 @@ TypeId AnfRuntimeAlgorithm::GetOutputInferDataType(const AnfNodePtr &node, size_ | |||||
| return get_single_type((*tuple_ptr)[output_idx]); | return get_single_type((*tuple_ptr)[output_idx]); | ||||
| }; | }; | ||||
| TypePtr type_ptr = node->Type(); | TypePtr type_ptr = node->Type(); | ||||
| if (type_ptr->isa<RefType>()) { | |||||
| auto ref_type_ptr = type_ptr->cast<RefTypePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(ref_type_ptr); | |||||
| return get_tuple_type(ref_type_ptr->subtype(), output_idx); | |||||
| } | |||||
| return get_tuple_type(type_ptr, output_idx); | return get_tuple_type(type_ptr, output_idx); | ||||
| } | } | ||||
| @@ -20,6 +20,7 @@ | |||||
| #include "abstract/abstract_value.h" | #include "abstract/abstract_value.h" | ||||
| #include "ir/anf.h" | #include "ir/anf.h" | ||||
| #include "ir/dtype.h" | |||||
| #include "abstract/dshape.h" | #include "abstract/dshape.h" | ||||
| #include "abstract/param_validator.h" | #include "abstract/param_validator.h" | ||||
| #include "frontend/operator/cc_implementations.h" | #include "frontend/operator/cc_implementations.h" | ||||
| @@ -43,15 +44,15 @@ const std::vector<Signature> &GetSignature(const ValuePtr &function) { | |||||
| return empty; | return empty; | ||||
| } | } | ||||
| void ProcessDefault(const std::string &func_name, const AbstractBasePtrList &args_spec_list, | |||||
| const std::vector<Signature> &signature, bool has_var, std::vector<AnfNodePtr> *const op_inputs) { | |||||
| void ProcessDefault(const std::string &func_name, size_t actual_param_number, const std::vector<Signature> &signature, | |||||
| bool has_var, std::vector<AnfNodePtr> *const op_inputs) { | |||||
| std::size_t sig_size = signature.size(); | std::size_t sig_size = signature.size(); | ||||
| auto positional_size = sig_size; | auto positional_size = sig_size; | ||||
| if (has_var) { | if (has_var) { | ||||
| positional_size = sig_size - 1; | positional_size = sig_size - 1; | ||||
| } | } | ||||
| if (args_spec_list.size() < positional_size) { | |||||
| for (size_t i = args_spec_list.size(); i < sig_size; ++i) { | |||||
| if (actual_param_number < positional_size) { | |||||
| for (size_t i = actual_param_number; i < sig_size; ++i) { | |||||
| auto default_value = signature[i].default_value; | auto default_value = signature[i].default_value; | ||||
| if (default_value == nullptr) { | if (default_value == nullptr) { | ||||
| MS_LOG(EXCEPTION) << "Function " << func_name << "'s input length is not equal to Signature length."; | MS_LOG(EXCEPTION) << "Function " << func_name << "'s input length is not equal to Signature length."; | ||||
| @@ -67,23 +68,11 @@ void SetMaxType(TypeId *max_type_id, size_t *max_type_number, const TypeId type_ | |||||
| *max_type_number = type_number; | *max_type_number = type_number; | ||||
| } | } | ||||
| bool GetTensorOrScalarTypeInfo(AbstractBasePtr arg_value, bool is_write, TypeId *arg_type_id, | |||||
| bool GetTensorOrScalarTypeInfo(TypePtr arg_type_origin, bool is_write, TypeId *arg_type_id, | |||||
| TypeId *arg_type = nullptr) { | TypeId *arg_type = nullptr) { | ||||
| if (arg_value->isa<abstract::AbstractRef>()) { | |||||
| auto ref = arg_value->cast<abstract::AbstractRefPtr>(); | |||||
| arg_value = ref->ref(); | |||||
| if (!is_write && ref->need_cast()) { | |||||
| auto tensor_type = ref->target_type(); | |||||
| *arg_type_id = tensor_type->type_id(); | |||||
| if (arg_type != nullptr) { | |||||
| *arg_type = kObjectTypeTensorType; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| } | |||||
| if (arg_value->isa<abstract::AbstractTensor>()) { | |||||
| auto tensor = arg_value->cast<abstract::AbstractTensorPtr>(); | |||||
| auto tensor_type = tensor->element()->BuildType(); | |||||
| if (arg_type_origin->isa<TensorType>()) { | |||||
| auto tensor = arg_type_origin->cast<TensorTypePtr>(); | |||||
| auto tensor_type = tensor->element(); | |||||
| MS_EXCEPTION_IF_NULL(tensor_type); | MS_EXCEPTION_IF_NULL(tensor_type); | ||||
| *arg_type_id = tensor_type->type_id(); | *arg_type_id = tensor_type->type_id(); | ||||
| if (arg_type != nullptr) { | if (arg_type != nullptr) { | ||||
| @@ -91,9 +80,8 @@ bool GetTensorOrScalarTypeInfo(AbstractBasePtr arg_value, bool is_write, TypeId | |||||
| } | } | ||||
| return true; | return true; | ||||
| } | } | ||||
| if (arg_value->isa<abstract::AbstractScalar>()) { | |||||
| auto scalar = arg_value->cast<abstract::AbstractScalarPtr>(); | |||||
| auto scalar_type = scalar->BuildType(); | |||||
| if (arg_type_origin->isa<Number>()) { | |||||
| auto scalar_type = arg_type_origin->cast<NumberPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(scalar_type); | MS_EXCEPTION_IF_NULL(scalar_type); | ||||
| *arg_type_id = scalar_type->type_id(); | *arg_type_id = scalar_type->type_id(); | ||||
| if (arg_type != nullptr) { | if (arg_type != nullptr) { | ||||
| @@ -104,7 +92,7 @@ bool GetTensorOrScalarTypeInfo(AbstractBasePtr arg_value, bool is_write, TypeId | |||||
| return false; | return false; | ||||
| } | } | ||||
| TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::vector<size_t> indices, | |||||
| TypeId GetMaxTypeId(const std::vector<TypePtr> &input_types, std::vector<size_t> indices, | |||||
| const std::set<size_t> &write_indices) { | const std::set<size_t> &write_indices) { | ||||
| TypeId max_type_id = kTypeUnknown; | TypeId max_type_id = kTypeUnknown; | ||||
| size_t max_type_number = 0; | size_t max_type_number = 0; | ||||
| @@ -115,7 +103,7 @@ TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::ve | |||||
| TypeId arg_type_id = kTypeUnknown; | TypeId arg_type_id = kTypeUnknown; | ||||
| TypeId arg_type = kTypeUnknown; | TypeId arg_type = kTypeUnknown; | ||||
| auto is_write = (write_indices.find(index) != write_indices.end()); | auto is_write = (write_indices.find(index) != write_indices.end()); | ||||
| if (!GetTensorOrScalarTypeInfo(args_spec_list[index], is_write, &arg_type_id, &arg_type)) { | |||||
| if (!GetTensorOrScalarTypeInfo(input_types[index], is_write, &arg_type_id, &arg_type)) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| if (arg_type != kObjectTypeTensorType) { | if (arg_type != kObjectTypeTensorType) { | ||||
| @@ -161,8 +149,8 @@ TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::ve | |||||
| // Get the largest type of index in the same SignatureEnumDType of arguments. | // Get the largest type of index in the same SignatureEnumDType of arguments. | ||||
| using MaxTypeMap = std::map<SignatureEnumDType, TypeId>; | using MaxTypeMap = std::map<SignatureEnumDType, TypeId>; | ||||
| MaxTypeMap GetMaxDtype(const std::vector<SignatureEnumDType> &dtypes, | |||||
| const abstract::AbstractBasePtrList &args_spec_list, const std::set<size_t> &write_indices) { | |||||
| MaxTypeMap GetMaxDtype(const std::vector<SignatureEnumDType> &dtypes, const std::vector<TypePtr> &input_types, | |||||
| const std::set<size_t> &write_indices) { | |||||
| // record index for signature.dtypes of the same type | // record index for signature.dtypes of the same type | ||||
| // eg. [T, T1, T, T2, T, T1, T3] -> {{T:(0,2,4)}, {T1:(1,5)}, {T2:(3)}, {T3:(6)}} | // eg. [T, T1, T, T2, T, T1, T3] -> {{T:(0,2,4)}, {T1:(1,5)}, {T2:(3)}, {T3:(6)}} | ||||
| std::map<SignatureEnumDType, std::vector<size_t>> type_indices; | std::map<SignatureEnumDType, std::vector<size_t>> type_indices; | ||||
| @@ -184,11 +172,8 @@ MaxTypeMap GetMaxDtype(const std::vector<SignatureEnumDType> &dtypes, | |||||
| } | } | ||||
| bool has_tensor = false; | bool has_tensor = false; | ||||
| for (const auto &index : indices) { | for (const auto &index : indices) { | ||||
| AbstractBasePtr arg_value = args_spec_list[index]; | |||||
| if (arg_value->isa<abstract::AbstractRef>()) { | |||||
| arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref(); | |||||
| } | |||||
| if (arg_value->isa<abstract::AbstractTensor>()) { | |||||
| auto arg_value = input_types[index]; | |||||
| if (arg_value->isa<TensorType>()) { | |||||
| has_tensor = true; | has_tensor = true; | ||||
| break; | break; | ||||
| } | } | ||||
| @@ -197,7 +182,7 @@ MaxTypeMap GetMaxDtype(const std::vector<SignatureEnumDType> &dtypes, | |||||
| (void)dst_type.insert(std::make_pair(type, kTypeUnknown)); | (void)dst_type.insert(std::make_pair(type, kTypeUnknown)); | ||||
| continue; | continue; | ||||
| } | } | ||||
| (void)dst_type.insert(std::make_pair(type, GetMaxTypeId(args_spec_list, indices, write_indices))); | |||||
| (void)dst_type.insert(std::make_pair(type, GetMaxTypeId(input_types, indices, write_indices))); | |||||
| } | } | ||||
| return dst_type; | return dst_type; | ||||
| } | } | ||||
| @@ -211,7 +196,7 @@ AnfNodePtr DoCast(const AnfNodePtr ¶m, const TypeId &type_id, const FuncGrap | |||||
| } | } | ||||
| void DoAutoCast(const std::string &func_name, const std::vector<Signature> &signature, | void DoAutoCast(const std::string &func_name, const std::vector<Signature> &signature, | ||||
| const abstract::AbstractBasePtrList &args_spec_list, const FuncGraphPtr &graph, | |||||
| const std::vector<TypePtr> &input_types, const FuncGraphPtr &graph, | |||||
| std::vector<AnfNodePtr> *const op_inputs, const std::set<size_t> &write_indices) { | std::vector<AnfNodePtr> *const op_inputs, const std::set<size_t> &write_indices) { | ||||
| std::vector<SignatureEnumDType> dtypes; | std::vector<SignatureEnumDType> dtypes; | ||||
| (void)std::transform(signature.begin(), signature.end(), std::back_inserter(dtypes), | (void)std::transform(signature.begin(), signature.end(), std::back_inserter(dtypes), | ||||
| @@ -221,9 +206,9 @@ void DoAutoCast(const std::string &func_name, const std::vector<Signature> &sign | |||||
| return; | return; | ||||
| } | } | ||||
| // Stat the index of the arguments with the largest type in the same SignatureEnumDType. | // Stat the index of the arguments with the largest type in the same SignatureEnumDType. | ||||
| std::map<SignatureEnumDType, TypeId> dst_type = GetMaxDtype(dtypes, args_spec_list, write_indices); | |||||
| std::map<SignatureEnumDType, TypeId> dst_type = GetMaxDtype(dtypes, input_types, write_indices); | |||||
| // Identify which arg requires auto cast | // Identify which arg requires auto cast | ||||
| for (size_t i = 0; i < args_spec_list.size(); ++i) { | |||||
| for (size_t i = 0; i < input_types.size(); ++i) { | |||||
| auto it = dst_type.find(dtypes[i]); | auto it = dst_type.find(dtypes[i]); | ||||
| if (it == dst_type.end() || it->second == kTypeUnknown) { | if (it == dst_type.end() || it->second == kTypeUnknown) { | ||||
| continue; | continue; | ||||
| @@ -232,7 +217,7 @@ void DoAutoCast(const std::string &func_name, const std::vector<Signature> &sign | |||||
| auto is_write = (rw_it != write_indices.end()); | auto is_write = (rw_it != write_indices.end()); | ||||
| TypeId arg_type_id = kTypeUnknown; | TypeId arg_type_id = kTypeUnknown; | ||||
| AbstractBasePtr arg_value = args_spec_list[i]; | |||||
| auto arg_value = input_types[i]; | |||||
| (void)GetTensorOrScalarTypeInfo(arg_value, is_write, &arg_type_id); | (void)GetTensorOrScalarTypeInfo(arg_value, is_write, &arg_type_id); | ||||
| auto it_map = type_name_map.find(arg_type_id); | auto it_map = type_name_map.find(arg_type_id); | ||||
| if (it_map == type_name_map.end()) { | if (it_map == type_name_map.end()) { | ||||
| @@ -248,7 +233,7 @@ void DoAutoCast(const std::string &func_name, const std::vector<Signature> &sign | |||||
| } | } | ||||
| continue; | continue; | ||||
| } | } | ||||
| if (arg_value->isa<abstract::AbstractTensor>() && arg_type_id == it->second) { | |||||
| if ((arg_value->isa<TensorType>()) && arg_type_id == it->second) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| MS_LOG(DEBUG) << "do cast for inputs " << i << " " << (*op_inputs)[i + 1]->ToString() << " " << arg_type_id | MS_LOG(DEBUG) << "do cast for inputs " << i << " " << (*op_inputs)[i + 1]->ToString() << " " << arg_type_id | ||||
| @@ -275,6 +260,7 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func | |||||
| } | } | ||||
| std::vector<AnfNodePtr> op_inputs; | std::vector<AnfNodePtr> op_inputs; | ||||
| std::set<size_t> write_indices; | std::set<size_t> write_indices; | ||||
| std::vector<TypePtr> input_types; | |||||
| op_inputs.push_back(NewValueNode(function)); | op_inputs.push_back(NewValueNode(function)); | ||||
| // Assume, the write input of op is always the first input. We check if any write op, | // Assume, the write input of op is always the first input. We check if any write op, | ||||
| // and add cast op on other inputs to keep the same type with assigned parameter. | // and add cast op on other inputs to keep the same type with assigned parameter. | ||||
| @@ -292,30 +278,36 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func | |||||
| sig = signature[sig_size - 1].rw; | sig = signature[sig_size - 1].rw; | ||||
| } | } | ||||
| TypePtr type = args_spec_list[i]->GetTypeTrack(); | |||||
| if (type && type->type_id() == kObjectTypeRef) { | |||||
| auto ref_abs = args_spec_list[i]->cast<abstract::AbstractRefPtr>(); | |||||
| TypePtr type = args_spec_list[i]->BuildType(); | |||||
| if (type && type->isa<RefType>()) { | |||||
| auto cast_type = parse::GetMixedPrecisionTargetType(func_graph); | |||||
| if (sig == SignatureEnumRW::kRWRead) { | if (sig == SignatureEnumRW::kRWRead) { | ||||
| param = NewCNode({NewValueNode(prim::kPrimGetRefValue), param}, func_graph); | |||||
| if (ref_abs && ref_abs->need_cast()) { | |||||
| auto cast = prim::GetPythonOps("cast", "mindspore.ops.functional"); | |||||
| param = NewCNode({NewValueNode(cast), param, NewValueNode(ref_abs->target_type())}, func_graph); | |||||
| auto source_tensor_type = type->cast<TensorTypePtr>(); | |||||
| if (source_tensor_type != nullptr) { | |||||
| auto source_element = source_tensor_type->element(); | |||||
| if (cast_type != nullptr && IsSubType(source_element, kFloat) && *source_element != *cast_type) { | |||||
| auto cast = prim::GetPythonOps("cast", "mindspore.ops.functional"); | |||||
| param = NewCNode({NewValueNode(cast), param, NewValueNode(cast_type)}, func_graph); | |||||
| type = cast_type->type_id() == kNumberTypeFloat16 ? kTensorTypeFP16 : kTensorTypeFP32; | |||||
| } | |||||
| } | } | ||||
| } else if (sig == SignatureEnumRW::kRWWrite) { | } else if (sig == SignatureEnumRW::kRWWrite) { | ||||
| param = NewCNode({NewValueNode(prim::kPrimGetRefValue), param}, func_graph); | |||||
| write_indices.insert(i); | write_indices.insert(i); | ||||
| } | } | ||||
| // If sig is SignatureEnumRW::kRWRef, not do anything. | // If sig is SignatureEnumRW::kRWRef, not do anything. | ||||
| } else if (sig == SignatureEnumRW::kRWWrite && type->type_id() != kObjectTypeRefKey) { | |||||
| MS_EXCEPTION(TypeError) << "Function " << func_name << "'s input " << i << " should be a Parameter."; | |||||
| } else if (sig == SignatureEnumRW::kRWWrite && | |||||
| !((type->type_id() == kObjectTypeRef) || (type->type_id() == kObjectTypeRefKey))) { | |||||
| MS_EXCEPTION(TypeError) << "Function " << func_name << "'s input " << i << " should be a Parameter, but " | |||||
| << type->ToString(); | |||||
| } | } | ||||
| MS_LOG(DEBUG) << "Function " << func_name << "'s input " << i << " " << param->DebugString(2) << " type " | MS_LOG(DEBUG) << "Function " << func_name << "'s input " << i << " " << param->DebugString(2) << " type " | ||||
| << args_spec_list[i]->ToString(); | << args_spec_list[i]->ToString(); | ||||
| input_types.push_back(type); | |||||
| op_inputs.push_back(param); | op_inputs.push_back(param); | ||||
| } | } | ||||
| // process default | // process default | ||||
| ProcessDefault(func_name, args_spec_list, signature, has_var, &op_inputs); | |||||
| DoAutoCast(func_name, signature, args_spec_list, func_graph, &op_inputs, write_indices); | |||||
| ProcessDefault(func_name, args_spec_list.size(), signature, has_var, &op_inputs); | |||||
| DoAutoCast(func_name, signature, input_types, func_graph, &op_inputs, write_indices); | |||||
| return func_graph->NewCNode(op_inputs); | return func_graph->NewCNode(op_inputs); | ||||
| } | } | ||||
| } // namespace | } // namespace | ||||
| @@ -81,12 +81,6 @@ void MultitypeFuncGraph::PyRegister(const py::tuple &tuple, const py::function & | |||||
| } | } | ||||
| Register(types_name, py_fn); | Register(types_name, py_fn); | ||||
| } | } | ||||
| static TypePtr UnwrapRef(const TypePtr &type) { | |||||
| if (type->isa<RefType>()) { | |||||
| return type->cast<RefTypePtr>()->subtype(); | |||||
| } | |||||
| return type; | |||||
| } | |||||
| // Return Exact match if exists, else return non ambiguous sub class match | // Return Exact match if exists, else return non ambiguous sub class match | ||||
| // Return py::none() if matching is ambiguous | // Return py::none() if matching is ambiguous | ||||
| @@ -99,7 +93,7 @@ const py::function MultitypeFuncGraph::SignMatch(const TypePtrList &types) { | |||||
| } | } | ||||
| auto match = true; | auto match = true; | ||||
| for (size_t i = 0; i < sign.size(); ++i) { | for (size_t i = 0; i < sign.size(); ++i) { | ||||
| if (!IsIdentidityOrSubclass(UnwrapRef(types[i]), sign[i])) { | |||||
| if (!IsIdentidityOrSubclass(types[i], sign[i])) { | |||||
| match = false; | match = false; | ||||
| break; | break; | ||||
| } | } | ||||
| @@ -627,6 +627,16 @@ AbstractBasePtr InferImplMakeRecord(const AnalysisEnginePtr &, const PrimitivePt | |||||
| return std::make_shared<AbstractClass>(cls->tag(), abs_attributes, cls->methods()); | return std::make_shared<AbstractClass>(cls->tag(), abs_attributes, cls->methods()); | ||||
| } | } | ||||
| AbstractBasePtr InferImplAssign(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list) { | |||||
| // Inputs: a tensor | |||||
| CheckArgsSize(primitive->name(), args_spec_list, 2); | |||||
| MS_LOG(DEBUG) << "InferImplAssign " << args_spec_list[0]; | |||||
| return args_spec_list[0]; | |||||
| } | |||||
| REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(TypeOf, prim::kPrimTypeOf, InferImplTypeof); | REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(TypeOf, prim::kPrimTypeOf, InferImplTypeof); | ||||
| REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(HasType, prim::kPrimHasType, InferImplHasType); | REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(HasType, prim::kPrimHasType, InferImplHasType); | ||||
| REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(MakeRecord, prim::kPrimMakeRecord, InferImplMakeRecord); | REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(MakeRecord, prim::kPrimMakeRecord, InferImplMakeRecord); | ||||
| @@ -648,5 +658,7 @@ REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(FakeBprop, prim::kPrimFakeBprop, InferImpl | |||||
| REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(J, prim::kPrimJ, InferImplJ); | REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(J, prim::kPrimJ, InferImplJ); | ||||
| REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(BroadcastGradientArgs, prim::kPrimBroadcastGradientArgs, | REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(BroadcastGradientArgs, prim::kPrimBroadcastGradientArgs, | ||||
| InferImplBroadcastGradientArgs); | InferImplBroadcastGradientArgs); | ||||
| REGISTER_PRIMITIVE_EVAL_IMPL(Assign, prim::kPrimAssign, InferImplAssign); | |||||
| } // namespace abstract | } // namespace abstract | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -20,6 +20,7 @@ | |||||
| #include "ir/anf.h" | #include "ir/anf.h" | ||||
| #include "ir/param_info.h" | #include "ir/param_info.h" | ||||
| #include "ir/meta_tensor.h" | |||||
| #include "pipeline/jit/parse/python_adapter.h" | #include "pipeline/jit/parse/python_adapter.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| @@ -38,8 +39,7 @@ bool ParameterRequireGrad(const AnfNodePtr &node_ptr) { | |||||
| if (!para_ptr->has_default()) { | if (!para_ptr->has_default()) { | ||||
| return false; | return false; | ||||
| } | } | ||||
| auto obj = py::cast(para_ptr->default_param()); | |||||
| auto param_value = py::cast<ParamValuePtr>(obj.attr("_value")); | |||||
| auto param_value = para_ptr->param_info(); | |||||
| if (param_value == nullptr) { | if (param_value == nullptr) { | ||||
| return false; | return false; | ||||
| } | } | ||||
| @@ -1356,8 +1356,7 @@ bool ParameterIsCloned(const AnfNodePtr ¶meter_node) { | |||||
| if (!cloned_parameter->has_default()) { | if (!cloned_parameter->has_default()) { | ||||
| return false; | return false; | ||||
| } | } | ||||
| auto obj = py::cast(cloned_parameter->default_param()); | |||||
| auto param_value = py::cast<ParamValuePtr>(obj.attr("_value")); | |||||
| auto param_value = cloned_parameter->param_info(); | |||||
| if (param_value == nullptr) { | if (param_value == nullptr) { | ||||
| return false; | return false; | ||||
| } | } | ||||
| @@ -1380,8 +1379,7 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) { | |||||
| if (!ParameterIsCloned(cloned_parameter_node)) { | if (!ParameterIsCloned(cloned_parameter_node)) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| auto obj = py::cast(cloned_parameter->default_param()); | |||||
| auto param_value = py::cast<ParamValuePtr>(obj.attr("_value")); | |||||
| auto param_value = cloned_parameter->param_info(); | |||||
| if (param_value == nullptr) { | if (param_value == nullptr) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| @@ -1400,10 +1398,7 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| const auto ¶m_value_cloned = be_cloned_parameter->default_param(); | |||||
| auto obj_in = py::cast(param_value_cloned); | |||||
| auto param_value_in = py::cast<ParamValuePtr>(obj_in.attr("_value")); | |||||
| auto param_value_in = be_cloned_parameter->param_info(); | |||||
| if (param_value_in == nullptr) { | if (param_value_in == nullptr) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| @@ -233,13 +233,14 @@ bool AbstractSpecializeAction(const ResourcePtr &res) { | |||||
| for (const auto ¶m : func_graph->parameters()) { | for (const auto ¶m : func_graph->parameters()) { | ||||
| auto param_node = std::static_pointer_cast<Parameter>(param); | auto param_node = std::static_pointer_cast<Parameter>(param); | ||||
| if (param_node->has_default()) { | if (param_node->has_default()) { | ||||
| ValuePtr value = param_node->default_param(); | |||||
| constexpr bool broaden = true; | |||||
| AbstractBasePtr ptr = abstract::FromValue(value, broaden); | |||||
| parallel::ParallelParameterContextRestoreInNoTraining(func_graph, param_node, ptr); | |||||
| args_spec.push_back(ptr); | |||||
| parallel::ParallelParameterContextCkptInTraining(func_graph, param_node, ptr); | |||||
| auto value = param_node->default_param(); | |||||
| auto abs_value = value->ToAbstract()->cast<abstract::AbstractTensorPtr>(); | |||||
| auto ref_key = std::make_shared<RefKey>(param_node->name()); | |||||
| auto abs_ref_key = ref_key->ToAbstract(); | |||||
| auto abs_ref = std::make_shared<abstract::AbstractRef>(abs_ref_key, abs_value); | |||||
| parallel::ParallelParameterContextRestoreInNoTraining(func_graph, param_node, abs_ref); | |||||
| args_spec.push_back(abs_ref); | |||||
| parallel::ParallelParameterContextCkptInTraining(func_graph, param_node, abs_ref); | |||||
| } | } | ||||
| } | } | ||||
| // Analyze | // Analyze | ||||
| @@ -425,9 +425,6 @@ bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature | |||||
| converted = env; | converted = env; | ||||
| } else if (py::hasattr(obj, PYTHON_CLASS_MEMBER_NAMESPACE)) { | } else if (py::hasattr(obj, PYTHON_CLASS_MEMBER_NAMESPACE)) { | ||||
| converted = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, obj); | converted = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, obj); | ||||
| } else if (py::hasattr(obj, "__parameter__")) { | |||||
| auto to_convert = py::cast<py::object>(python_adapter::GetPyObjAttr(obj, "default_input")); | |||||
| ret = ConvertData(to_convert, &converted); | |||||
| } else { | } else { | ||||
| ret = ConvertOtherObj(obj, &converted); | ret = ConvertOtherObj(obj, &converted); | ||||
| } | } | ||||
| @@ -555,9 +552,6 @@ void MakeProperNameToFuncGraph(const FuncGraphPtr &func_graph, std::string name) | |||||
| ValuePtr PyDataToValue(const py::object &obj) { | ValuePtr PyDataToValue(const py::object &obj) { | ||||
| py::object to_convert = obj; | py::object to_convert = obj; | ||||
| if (py::hasattr(obj, "__parameter__")) { | |||||
| to_convert = py::cast<py::object>(python_adapter::GetPyObjAttr(obj, "default_input")); | |||||
| } | |||||
| ValuePtr value = nullptr; | ValuePtr value = nullptr; | ||||
| (void)ConvertData(to_convert, &value); | (void)ConvertData(to_convert, &value); | ||||
| return value; | return value; | ||||
| @@ -306,7 +306,14 @@ void FunctionBlock::ConditionalJump(AnfNodePtr condNode, const FunctionBlockPtr | |||||
| } | } | ||||
| void FunctionBlock::SetStateAssgin(const AnfNodePtr &target, const std::string &readid) { | void FunctionBlock::SetStateAssgin(const AnfNodePtr &target, const std::string &readid) { | ||||
| state_assign_[target] = readid; | |||||
| const std::string primitive_name("assign"); | |||||
| const std::string module_name("mindspore.ops.functional"); | |||||
| ValueNodePtr assign_op = NewValueNode(prim::GetPythonOps(primitive_name, module_name, true)); | |||||
| auto source = ReadVariable(readid); | |||||
| auto assign = func_graph()->NewCNode({assign_op, target, source}); | |||||
| WriteVariable(readid, assign); | |||||
| MS_LOG(INFO) << "SetState read " << target->DebugString() << ", " << readid; | |||||
| AddAutoDepend(assign); | |||||
| } | } | ||||
| void FunctionBlock::AddAutoDepend(const AnfNodePtr &target) { auto_depends_.push_back(target); } | void FunctionBlock::AddAutoDepend(const AnfNodePtr &target) { auto_depends_.push_back(target); } | ||||
| @@ -321,21 +328,13 @@ void FunctionBlock::InsertDependItemsBeforeReturn() { | |||||
| ValueNodePtr make_tuple_op = NewValueNode(prim::kPrimMakeTuple); | ValueNodePtr make_tuple_op = NewValueNode(prim::kPrimMakeTuple); | ||||
| ValueNodePtr depend_op = NewValueNode(prim::kPrimDepend); | ValueNodePtr depend_op = NewValueNode(prim::kPrimDepend); | ||||
| ValueNodePtr stop_gradient_op = NewValueNode(prim::kPrimStopGradient); | ValueNodePtr stop_gradient_op = NewValueNode(prim::kPrimStopGradient); | ||||
| const std::string primitive_name("assign"); | |||||
| const std::string module_name("mindspore.ops.functional"); | |||||
| ValueNodePtr assign_op = NewValueNode(prim::GetPythonOps(primitive_name, module_name, true)); | |||||
| if (state_assign_.size() == 0 && auto_depends_.size() == 0) { | |||||
| if (auto_depends_.size() == 0) { | |||||
| return; | return; | ||||
| } | } | ||||
| AnfNodePtr state = nullptr; | AnfNodePtr state = nullptr; | ||||
| std::vector<AnfNodePtr> vec_states; | std::vector<AnfNodePtr> vec_states; | ||||
| vec_states.emplace_back(make_tuple_op); | vec_states.emplace_back(make_tuple_op); | ||||
| for (auto &item : state_assign_) { | |||||
| auto source = ReadVariable(item.second); | |||||
| auto assign = func_graph()->NewCNode({assign_op, item.first, source}); | |||||
| MS_LOG(INFO) << "SetState read " << item.first->ToString() << ", " << item.second; | |||||
| vec_states.emplace_back(assign); | |||||
| } | |||||
| for (auto &item : auto_depends_) { | for (auto &item : auto_depends_) { | ||||
| MS_LOG(DEBUG) << "auto_depends " << item->ToString(); | MS_LOG(DEBUG) << "auto_depends " << item->ToString(); | ||||
| vec_states.emplace_back(item); | vec_states.emplace_back(item); | ||||
| @@ -361,7 +360,6 @@ void FunctionBlock::InsertDependItemsBeforeReturn() { | |||||
| AnfNodePtr stopped = func_graph()->NewCNode({stop_gradient_op, state}); | AnfNodePtr stopped = func_graph()->NewCNode({stop_gradient_op, state}); | ||||
| AnfNodePtr ret = func_graph()->NewCNode({depend_op, old_ret, stopped}); | AnfNodePtr ret = func_graph()->NewCNode({depend_op, old_ret, stopped}); | ||||
| func_graph()->set_output(ret, true); | func_graph()->set_output(ret, true); | ||||
| state_assign_.clear(); | |||||
| } | } | ||||
| } // namespace parse | } // namespace parse | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -101,9 +101,6 @@ class FunctionBlock : public std::enable_shared_from_this<FunctionBlock> { | |||||
| // keeps all removable phis which will be removed in one pass. | // keeps all removable phis which will be removed in one pass. | ||||
| std::unordered_map<ParameterPtr, AnfNodePtr> removable_phis_; | std::unordered_map<ParameterPtr, AnfNodePtr> removable_phis_; | ||||
| // set state nodes need to insert before function return nodes. | |||||
| OrderedMap<AnfNodePtr, std::string> state_assign_; | |||||
| // hold declared global variables in function | // hold declared global variables in function | ||||
| std::set<std::string> global_vars_; | std::set<std::string> global_vars_; | ||||
| @@ -59,14 +59,13 @@ FuncGraphPtr ParsePythonCode(const py::object &obj, const std::string &python_mo | |||||
| return func_graph; | return func_graph; | ||||
| } | } | ||||
| ValuePtr GetMixedPrecisionTargetType(const FuncGraphPtr &func_graph, const AnfNodePtr ¶m) { | |||||
| TypePtr dst_type; | |||||
| TypePtr GetMixedPrecisionTargetType(const FuncGraphPtr &func_graph) { | |||||
| if (func_graph->has_flag(GRAPH_FLAG_MIX_PRECISION_FP32)) { | if (func_graph->has_flag(GRAPH_FLAG_MIX_PRECISION_FP32)) { | ||||
| return kFloat32; | return kFloat32; | ||||
| } else if (func_graph->has_flag(GRAPH_FLAG_MIX_PRECISION_FP16)) { | } else if (func_graph->has_flag(GRAPH_FLAG_MIX_PRECISION_FP16)) { | ||||
| return kFloat16; | return kFloat16; | ||||
| } else { | } else { | ||||
| return kNone; | |||||
| return nullptr; | |||||
| } | } | ||||
| } | } | ||||
| @@ -364,7 +364,7 @@ class ParseAst { | |||||
| bool UpdateFuncGraphFlags(py::object obj, const FuncGraphPtr &func_graph); | bool UpdateFuncGraphFlags(py::object obj, const FuncGraphPtr &func_graph); | ||||
| AnfNodePtr GetMixedPrecisionCastHelp(const FuncGraphPtr &func_graph, const AnfNodePtr ¶m); | AnfNodePtr GetMixedPrecisionCastHelp(const FuncGraphPtr &func_graph, const AnfNodePtr ¶m); | ||||
| ValuePtr GetMixedPrecisionTargetType(const FuncGraphPtr &func_graph, const AnfNodePtr ¶m); | |||||
| TypePtr GetMixedPrecisionTargetType(const FuncGraphPtr &func_graph); | |||||
| } // namespace parse | } // namespace parse | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -105,24 +105,12 @@ AnfNodePtr ResolveParameterObj(const FuncGraphPtr &func_graph, const py::object | |||||
| auto value = py::cast<tensor::MetaTensorPtr>(obj); | auto value = py::cast<tensor::MetaTensorPtr>(obj); | ||||
| node->set_default_param(value); | node->set_default_param(value); | ||||
| // set_abstract for parameter | // set_abstract for parameter | ||||
| constexpr bool broaden = true; | |||||
| node->set_abstract(abstract::FromValue(value, broaden)); | |||||
| auto abs = value->ToAbstract(); | |||||
| node->set_abstract(abs); | |||||
| para_node = node; | para_node = node; | ||||
| } | } | ||||
| auto iter = func_graph->make_ref_params().find(para_node); | |||||
| if (iter == func_graph->make_ref_params().end()) { | |||||
| ValuePtr target_type = GetMixedPrecisionTargetType(func_graph, para_node); | |||||
| AnfNodePtr make_ref = NewValueNode(prim::kPrimMakeRef); | |||||
| AnfNodePtr ref_key = NewValueNode(std::make_shared<RefKey>(param_name)); | |||||
| AnfNodePtr target_type_node = NewValueNode(target_type); | |||||
| AnfNodePtr ref_node = func_graph->NewCNode({make_ref, ref_key, para_node, target_type_node}); | |||||
| func_graph->make_ref_params()[para_node] = ref_node; | |||||
| func_graph->add_parameter_obj_node(ref_node); | |||||
| return ref_node; | |||||
| } else { | |||||
| return iter->second; | |||||
| } | |||||
| return para_node; | |||||
| } | } | ||||
| bool ResolveObjectToNode(const FuncGraphPtr &func_graph, const py::object &obj, AnfNodePtr *const node) { | bool ResolveObjectToNode(const FuncGraphPtr &func_graph, const py::object &obj, AnfNodePtr *const node) { | ||||
| @@ -640,7 +640,14 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::v | |||||
| size_t size = op_exec_info->op_inputs.size(); | size_t size = op_exec_info->op_inputs.size(); | ||||
| for (size_t i = 0; i < size; i++) { | for (size_t i = 0; i < size; i++) { | ||||
| auto obj = op_exec_info->op_inputs[i]; | auto obj = op_exec_info->op_inputs[i]; | ||||
| bool op_mask = py::hasattr(obj, "__parameter__"); | |||||
| bool op_mask = false; | |||||
| if (py::isinstance<tensor::MetaTensor>(obj)) { | |||||
| auto meta_tensor = obj.cast<tensor::MetaTensorPtr>(); | |||||
| if (meta_tensor) { | |||||
| op_mask = meta_tensor->is_parameter(); | |||||
| } | |||||
| } | |||||
| (*op_masks).push_back(op_mask); | (*op_masks).push_back(op_mask); | ||||
| MS_LOG(DEBUG) << "gen " << op_exec_info->op_name << " arg " << i << ": op mask " << op_mask << " grad_flag_ " | MS_LOG(DEBUG) << "gen " << op_exec_info->op_name << " arg " << i << ": op mask " << op_mask << " grad_flag_ " | ||||
| << grad_flag_; | << grad_flag_; | ||||
| @@ -990,8 +997,9 @@ AnfNodePtr PynativeExecutor::GetInput(const py::object &obj, bool op_mask) { | |||||
| if (graph_info_map_[df_builder_].param_map.count(obj_id) == 0) { | if (graph_info_map_[df_builder_].param_map.count(obj_id) == 0) { | ||||
| auto free_param = df_builder_->add_parameter(); | auto free_param = df_builder_->add_parameter(); | ||||
| free_param->set_name(param_name); | free_param->set_name(param_name); | ||||
| free_param->set_default_param(py::cast<tensor::TensorPtr>(obj)); | |||||
| free_param->debug_info()->set_name(param_name); | free_param->debug_info()->set_name(param_name); | ||||
| auto value = py::cast<tensor::TensorPtr>(obj); | |||||
| free_param->set_default_param(value); | |||||
| MS_LOG(DEBUG) << "Top graph set free parameter " << obj_id; | MS_LOG(DEBUG) << "Top graph set free parameter " << obj_id; | ||||
| graph_info_map_[df_builder_].param_map[obj_id] = free_param; | graph_info_map_[df_builder_].param_map[obj_id] = free_param; | ||||
| return free_param; | return free_param; | ||||
| @@ -1159,17 +1167,12 @@ std::vector<AnfNodePtr> PynativeExecutor::GetWeightsArgs(const py::object &weigh | |||||
| auto param_name = py::cast<std::string>(name_attr); | auto param_name = py::cast<std::string>(name_attr); | ||||
| auto free_param = df_builder_->add_parameter(); | auto free_param = df_builder_->add_parameter(); | ||||
| free_param->set_name(param_name); | free_param->set_name(param_name); | ||||
| free_param->set_default_param(py::cast<tensor::TensorPtr>(param)); | |||||
| auto value = py::cast<tensor::TensorPtr>(param); | |||||
| free_param->set_default_param(value); | |||||
| free_param->debug_info()->set_name(param_name); | free_param->debug_info()->set_name(param_name); | ||||
| para_node = free_param; | para_node = free_param; | ||||
| } | } | ||||
| ValuePtr target_type = parse::GetMixedPrecisionTargetType(df_builder_, para_node); | |||||
| AnfNodePtr make_ref = NewValueNode(prim::kPrimMakeRef); | |||||
| auto refkey = std::make_shared<RefKey>(para_node->cast<ParameterPtr>()->name()); | |||||
| AnfNodePtr ref_key_node = NewValueNode(refkey); | |||||
| AnfNodePtr target_type_node = NewValueNode(target_type); | |||||
| AnfNodePtr ref_node = df_builder_->NewCNode({make_ref, ref_key_node, para_node, target_type_node}); | |||||
| w_args.push_back(ref_node); | |||||
| w_args.push_back(para_node); | |||||
| } | } | ||||
| } else { | } else { | ||||
| MS_LOG(DEBUG) << "training not paramter_tuple"; | MS_LOG(DEBUG) << "training not paramter_tuple"; | ||||
| @@ -1197,7 +1200,7 @@ abstract::AbstractBasePtrList PynativeExecutor::GetArgsSpec(const py::args &args | |||||
| auto param_node = std::static_pointer_cast<Parameter>(param); | auto param_node = std::static_pointer_cast<Parameter>(param); | ||||
| if (param_node->has_default()) { | if (param_node->has_default()) { | ||||
| ValuePtr value = param_node->default_param(); | ValuePtr value = param_node->default_param(); | ||||
| AbstractBasePtr ptr = abstract::FromValue(value, true); | |||||
| auto ptr = value->ToAbstract(); | |||||
| if (ptr == nullptr) { | if (ptr == nullptr) { | ||||
| MS_LOG(EXCEPTION) << "Args convert error"; | MS_LOG(EXCEPTION) << "Args convert error"; | ||||
| } | } | ||||
| @@ -147,7 +147,7 @@ REGISTER_PYBIND_DEFINE( | |||||
| (void)py::class_<TypeType, Type, std::shared_ptr<TypeType>>(m_sub, "TypeType").def(py::init()); | (void)py::class_<TypeType, Type, std::shared_ptr<TypeType>>(m_sub, "TypeType").def(py::init()); | ||||
| (void)py::class_<String, Type, std::shared_ptr<String>>(m_sub, "String").def(py::init()); | (void)py::class_<String, Type, std::shared_ptr<String>>(m_sub, "String").def(py::init()); | ||||
| (void)py::class_<RefKeyType, Type, std::shared_ptr<RefKeyType>>(m_sub, "RefKeyType").def(py::init()); | (void)py::class_<RefKeyType, Type, std::shared_ptr<RefKeyType>>(m_sub, "RefKeyType").def(py::init()); | ||||
| (void)py::class_<RefType, Type, std::shared_ptr<RefType>>(m_sub, "RefType").def(py::init()); | |||||
| (void)py::class_<RefType, TensorType, Type, std::shared_ptr<RefType>>(m_sub, "RefType").def(py::init()); | |||||
| (void)py::class_<TypeAnything, Type, std::shared_ptr<TypeAnything>>(m_sub, "TypeAnything").def(py::init()); | (void)py::class_<TypeAnything, Type, std::shared_ptr<TypeAnything>>(m_sub, "TypeAnything").def(py::init()); | ||||
| (void)py::class_<Slice, Type, std::shared_ptr<Slice>>(m_sub, "Slice").def(py::init()); | (void)py::class_<Slice, Type, std::shared_ptr<Slice>>(m_sub, "Slice").def(py::init()); | ||||
| (void)py::class_<TypeEllipsis, Type, std::shared_ptr<TypeEllipsis>>(m_sub, "TypeEllipsis").def(py::init()); | (void)py::class_<TypeEllipsis, Type, std::shared_ptr<TypeEllipsis>>(m_sub, "TypeEllipsis").def(py::init()); | ||||
| @@ -21,7 +21,7 @@ namespace mindspore { | |||||
| namespace py = pybind11; | namespace py = pybind11; | ||||
| REGISTER_PYBIND_DEFINE(ParamInfo, ([](const py::module *m) { | REGISTER_PYBIND_DEFINE(ParamInfo, ([](const py::module *m) { | ||||
| (void)py::class_<ParamInfo, ParamValuePtr>(*m, "ParamInfo") | |||||
| (void)py::class_<ParamInfo, ParamInfoPtr>(*m, "ParamInfo") | |||||
| .def(py::init()) | .def(py::init()) | ||||
| .def("clone", &ParamInfo::Clone) | .def("clone", &ParamInfo::Clone) | ||||
| .def_property("name", &ParamInfo::name, &ParamInfo::set_name) | .def_property("name", &ParamInfo::name, &ParamInfo::set_name) | ||||
| @@ -36,7 +36,7 @@ REGISTER_PYBIND_DEFINE(ParamInfo, ([](const py::module *m) { | |||||
| if (t.size() != 6) { | if (t.size() != 6) { | ||||
| std::runtime_error("Invalid state for ParamInfo!"); | std::runtime_error("Invalid state for ParamInfo!"); | ||||
| } | } | ||||
| ParamValuePtr p = std::make_shared<ParamInfo>(); | |||||
| ParamInfoPtr p = std::make_shared<ParamInfo>(); | |||||
| p->set_name(t[1].cast<std::string>()); | p->set_name(t[1].cast<std::string>()); | ||||
| p->set_requires_grad(t[2].cast<bool>()); | p->set_requires_grad(t[2].cast<bool>()); | ||||
| p->set_layerwise_parallel(t[3].cast<bool>()); | p->set_layerwise_parallel(t[3].cast<bool>()); | ||||
| @@ -291,6 +291,7 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) { | |||||
| .def(py::init<TypePtr, const std::vector<int>>(), py::arg("dtype"), py::arg("shape")) | .def(py::init<TypePtr, const std::vector<int>>(), py::arg("dtype"), py::arg("shape")) | ||||
| .def_property_readonly("dtype", &MetaTensor::Dtype, "Get the MetaTensor's dtype.") | .def_property_readonly("dtype", &MetaTensor::Dtype, "Get the MetaTensor's dtype.") | ||||
| .def_property_readonly("shape", &MetaTensor::shape, "Get the MetaTensor's shape.") | .def_property_readonly("shape", &MetaTensor::shape, "Get the MetaTensor's shape.") | ||||
| .def_property("_param_info", &MetaTensor::param_info, &MetaTensor::set_param_info) | |||||
| .def(py::pickle( | .def(py::pickle( | ||||
| [](const MetaTensor &t) { // __getstate__ | [](const MetaTensor &t) { // __getstate__ | ||||
| /* Return a tuple that fully encodes the state of the object */ | /* Return a tuple that fully encodes the state of the object */ | ||||
| @@ -42,7 +42,7 @@ class Parameter(MetaTensor): | |||||
| In auto_parallel mode of "semi_auto_parallel" and "auto_parallel", if init `Parameter` by | In auto_parallel mode of "semi_auto_parallel" and "auto_parallel", if init `Parameter` by | ||||
| an `Initializer`, the type of Parameter will be `MetaTensor` not `Tensor`. `MetaTensor` | an `Initializer`, the type of Parameter will be `MetaTensor` not `Tensor`. `MetaTensor` | ||||
| only saves the shape and type info of a tensor with no memory usage. The shape can be changed while | only saves the shape and type info of a tensor with no memory usage. The shape can be changed while | ||||
| compile for auto-parallel. Call `init_data` will return a Tensor Parameter with initialized data. | |||||
| compiling for auto-parallel. Call `init_data` will return a Tensor Parameter with initialized data. | |||||
| Note: | Note: | ||||
| Each parameter of Cell is represented by Parameter class. | Each parameter of Cell is represented by Parameter class. | ||||
| @@ -108,7 +108,7 @@ class Parameter(MetaTensor): | |||||
| Parameter, (data, self.name, self.requires_grad, self.layerwise_parallel)) | Parameter, (data, self.name, self.requires_grad, self.layerwise_parallel)) | ||||
| def __init__(self, default_input, name, requires_grad=True, layerwise_parallel=False): | def __init__(self, default_input, name, requires_grad=True, layerwise_parallel=False): | ||||
| self._value = ParamInfo() | |||||
| self._param_info = ParamInfo() | |||||
| self.name = name | self.name = name | ||||
| self.requires_grad = requires_grad | self.requires_grad = requires_grad | ||||
| self.layerwise_parallel = layerwise_parallel | self.layerwise_parallel = layerwise_parallel | ||||
| @@ -156,13 +156,13 @@ class Parameter(MetaTensor): | |||||
| value_str = MetaTensor.__str__(self) | value_str = MetaTensor.__str__(self) | ||||
| if isinstance(self, Tensor): | if isinstance(self, Tensor): | ||||
| value_str = Tensor.__str__(self) | value_str = Tensor.__str__(self) | ||||
| return f'Parameter (name={self._value.name}, value={value_str})' | |||||
| return f'Parameter (name={self._param_info.name}, value={value_str})' | |||||
| def __repr__(self): | def __repr__(self): | ||||
| value_str = MetaTensor.__repr__(self) | value_str = MetaTensor.__repr__(self) | ||||
| if isinstance(self, Tensor): | if isinstance(self, Tensor): | ||||
| value_str = Tensor.__repr__(self) | value_str = Tensor.__repr__(self) | ||||
| return f'Parameter (name={self._value.name}, value={value_str})' | |||||
| return f'Parameter (name={self._param_info.name}, value={value_str})' | |||||
| def __parameter__(self): | def __parameter__(self): | ||||
| """For parse check.""" | """For parse check.""" | ||||
| @@ -181,7 +181,7 @@ class Parameter(MetaTensor): | |||||
| @property | @property | ||||
| def name(self): | def name(self): | ||||
| """Get the name of the parameter.""" | """Get the name of the parameter.""" | ||||
| return self._value.name | |||||
| return self._param_info.name | |||||
| @name.setter | @name.setter | ||||
| def name(self, name_): | def name(self, name_): | ||||
| @@ -203,7 +203,7 @@ class Parameter(MetaTensor): | |||||
| format(name_, PARAMETER_NAME_PREFIX_MAX_LEN)) | format(name_, PARAMETER_NAME_PREFIX_MAX_LEN)) | ||||
| else: | else: | ||||
| raise ValueError("The type of the name should be `str` or `None`.") | raise ValueError("The type of the name should be `str` or `None`.") | ||||
| self._value.name = name_ | |||||
| self._param_info.name = name_ | |||||
| @property | @property | ||||
| def cast_type(self): | def cast_type(self): | ||||
| @@ -254,8 +254,8 @@ class Parameter(MetaTensor): | |||||
| _check_str_by_regular(prefix) | _check_str_by_regular(prefix) | ||||
| x = copy(self) | x = copy(self) | ||||
| # pylint: disable=protected-access | # pylint: disable=protected-access | ||||
| x._value = self._value.clone() | |||||
| x._value.name = prefix + '.' + self._value.name | |||||
| x._param_info = self._param_info.clone() | |||||
| x._param_info.name = prefix + '.' + self._param_info.name | |||||
| x.is_init = False | x.is_init = False | ||||
| if init != 'same': | if init != 'same': | ||||
| shape = self.shape | shape = self.shape | ||||
| @@ -265,24 +265,24 @@ class Parameter(MetaTensor): | |||||
| @property | @property | ||||
| def layerwise_parallel(self): | def layerwise_parallel(self): | ||||
| return self._value.layerwise_parallel | |||||
| return self._param_info.layerwise_parallel | |||||
| @layerwise_parallel.setter | @layerwise_parallel.setter | ||||
| def layerwise_parallel(self, value=True): | def layerwise_parallel(self, value=True): | ||||
| if not isinstance(value, bool): | if not isinstance(value, bool): | ||||
| raise TypeError("`layerwise_parallel` parameter must be bool type") | raise TypeError("`layerwise_parallel` parameter must be bool type") | ||||
| self._value.layerwise_parallel = value | |||||
| self._param_info.layerwise_parallel = value | |||||
| @property | @property | ||||
| def requires_grad(self): | def requires_grad(self): | ||||
| """Return whether the parameter requires gradient.""" | """Return whether the parameter requires gradient.""" | ||||
| return self._value.requires_grad | |||||
| return self._param_info.requires_grad | |||||
| @requires_grad.setter | @requires_grad.setter | ||||
| def requires_grad(self, value=True): | def requires_grad(self, value=True): | ||||
| if not isinstance(value, bool): | if not isinstance(value, bool): | ||||
| raise TypeError("`requires_grad` parameter must be bool type") | raise TypeError("`requires_grad` parameter must be bool type") | ||||
| self._value.requires_grad = value | |||||
| self._param_info.requires_grad = value | |||||
| @property | @property | ||||
| def data(self): | def data(self): | ||||
| @@ -459,10 +459,6 @@ AbstractBasePtr AbstractTensor::Join(const AbstractBasePtr &other) { | |||||
| } | } | ||||
| auto other_tensor = dyn_cast<AbstractTensor>(other); | auto other_tensor = dyn_cast<AbstractTensor>(other); | ||||
| if (other_tensor == nullptr) { | if (other_tensor == nullptr) { | ||||
| auto ref_tensor = dyn_cast<AbstractRef>(other); | |||||
| if (ref_tensor != nullptr) { | |||||
| return this->Join(ref_tensor->ref()); | |||||
| } | |||||
| MS_LOG(EXCEPTION) << "Join failed as type mismatch, this: " << ToString() << ", other: " << other->ToString(); | MS_LOG(EXCEPTION) << "Join failed as type mismatch, this: " << ToString() << ", other: " << other->ToString(); | ||||
| } | } | ||||
| if (*this == *other) { | if (*this == *other) { | ||||
| @@ -473,7 +469,7 @@ AbstractBasePtr AbstractTensor::Join(const AbstractBasePtr &other) { | |||||
| return std::make_shared<AbstractTensor>(element, shape); | return std::make_shared<AbstractTensor>(element, shape); | ||||
| } | } | ||||
| bool AbstractTensor::operator==(const AbstractTensor &other) const { | |||||
| bool AbstractTensor::equal_to(const AbstractTensor &other) const { | |||||
| if (&other == this) { | if (&other == this) { | ||||
| return true; | return true; | ||||
| } | } | ||||
| @@ -491,12 +487,14 @@ bool AbstractTensor::operator==(const AbstractTensor &other) const { | |||||
| return (*element_ == *other.element_) && (*shape() == *other.shape()) && is_value_equal; | return (*element_ == *other.element_) && (*shape() == *other.shape()) && is_value_equal; | ||||
| } | } | ||||
| bool AbstractTensor::operator==(const AbstractTensor &other) const { return equal_to(other); } | |||||
| bool AbstractTensor::operator==(const AbstractBase &other) const { | bool AbstractTensor::operator==(const AbstractBase &other) const { | ||||
| if (&other == this) { | if (&other == this) { | ||||
| return true; | return true; | ||||
| } | } | ||||
| if (other.isa<AbstractTensor>()) { | |||||
| if (other.tid() == tid()) { | |||||
| auto other_tensor = static_cast<const AbstractTensor *>(&other); | auto other_tensor = static_cast<const AbstractTensor *>(&other); | ||||
| return *this == *other_tensor; | return *this == *other_tensor; | ||||
| } else { | } else { | ||||
| @@ -822,39 +820,21 @@ std::string AbstractJTagged::ToString() const { | |||||
| return buffer.str(); | return buffer.str(); | ||||
| } | } | ||||
| AbstractRef::AbstractRef(const AbstractBasePtr &ref_key, const AbstractBasePtr &ref_value, bool need_cast, | |||||
| TypePtr cast_target) | |||||
| : ref_key_(ref_key), ref_(ref_value), need_cast_(false), target_type_(nullptr), ref_key_value_(nullptr) { | |||||
| AbstractRef::AbstractRef(const AbstractBasePtr &ref_key, const AbstractTensorPtr &ref_value) | |||||
| : AbstractTensor(*ref_value), ref_key_(ref_key), ref_key_value_(nullptr) { | |||||
| set_type(std::make_shared<RefType>()); | set_type(std::make_shared<RefType>()); | ||||
| auto origin_type = ref_value->BuildType(); | |||||
| if (need_cast && cast_target && origin_type && origin_type->isa<TensorType>()) { | |||||
| auto tensor_dtype = origin_type->cast<TensorTypePtr>()->element(); | |||||
| if (tensor_dtype && IsSubType(tensor_dtype, kFloat)) { | |||||
| if (cast_target != tensor_dtype) { | |||||
| need_cast_ = true; | |||||
| target_type_ = cast_target; | |||||
| } | |||||
| } | |||||
| } | |||||
| if (ref_key && ref_key->isa<AbstractRefKey>()) { | if (ref_key && ref_key->isa<AbstractRefKey>()) { | ||||
| ref_key_value_ = ref_key->cast<AbstractRefKeyPtr>()->ref_key_value(); | ref_key_value_ = ref_key->cast<AbstractRefKeyPtr>()->ref_key_value(); | ||||
| } | } | ||||
| } | } | ||||
| BaseShapePtr AbstractRef::BuildShape() const { return ref_->BuildShape(); } | |||||
| TypePtr AbstractRef::BuildType() const { | TypePtr AbstractRef::BuildType() const { | ||||
| TypePtr subtype = ref_->BuildType(); | |||||
| TypePtr subtype_origin = subtype; | |||||
| if (need_cast_) { | |||||
| subtype_origin = std::make_shared<TensorType>(target_type_); | |||||
| } | |||||
| return std::make_shared<RefType>(subtype, subtype_origin); | |||||
| auto subtype = AbstractTensor::BuildType()->cast<TensorTypePtr>(); | |||||
| return std::make_shared<RefType>(subtype); | |||||
| } | } | ||||
| bool AbstractRef::operator==(const AbstractRef &other) const { | bool AbstractRef::operator==(const AbstractRef &other) const { | ||||
| return (*ref_ == *other.ref_) && (need_cast_ == other.need_cast_) && (*ref_key_ == *other.ref_key_) && | |||||
| (!need_cast_ || (*target_type_ == *other.target_type_)); | |||||
| return AbstractTensor::equal_to(other) && (*ref_key_ == *other.ref_key_); | |||||
| } | } | ||||
| bool AbstractRef::operator==(const AbstractBase &other) const { | bool AbstractRef::operator==(const AbstractBase &other) const { | ||||
| @@ -886,24 +866,20 @@ AbstractBasePtr AbstractRefKey::Join(const AbstractBasePtr &other) { | |||||
| AbstractBasePtr AbstractRef::Join(const AbstractBasePtr &other) { | AbstractBasePtr AbstractRef::Join(const AbstractBasePtr &other) { | ||||
| auto other_ref = other->cast<AbstractRefPtr>(); | auto other_ref = other->cast<AbstractRefPtr>(); | ||||
| if (other_ref == nullptr) { | if (other_ref == nullptr) { | ||||
| auto new_ref = ref_->Join(other); | |||||
| return std::make_shared<AbstractRef>(ref_key_, new_ref); | |||||
| return AbstractTensor::Join(other)->cast<AbstractTensorPtr>(); | |||||
| } | } | ||||
| if ((*this == *other) && (*ref_key_ == *other_ref->ref_key_)) { | if ((*this == *other) && (*ref_key_ == *other_ref->ref_key_)) { | ||||
| return shared_from_base<AbstractBase>(); | return shared_from_base<AbstractBase>(); | ||||
| } | } | ||||
| auto ref_key = ref_key_->Join(other_ref->ref_key_); | auto ref_key = ref_key_->Join(other_ref->ref_key_); | ||||
| auto ref = ref_->Join(other_ref->ref()); | |||||
| auto ref = AbstractTensor::Join(other_ref->ref())->cast<AbstractTensorPtr>(); | |||||
| return std::make_shared<AbstractRef>(ref_key, ref); | return std::make_shared<AbstractRef>(ref_key, ref); | ||||
| } | } | ||||
| std::string AbstractRef::ToString() const { | std::string AbstractRef::ToString() const { | ||||
| std::ostringstream buffer; | std::ostringstream buffer; | ||||
| buffer << type_name() << "(" | buffer << type_name() << "(" | ||||
| << "key: " << ref_key_->ToString() << " ref_value: " << ref_->ToString(); | |||||
| if (need_cast_) { | |||||
| buffer << " cast to: " << target_type_->ToString(); | |||||
| } | |||||
| << "key: " << ref_key_->ToString() << " ref_value: " << AbstractTensor::ToString(); | |||||
| auto value = GetValueTrack(); | auto value = GetValueTrack(); | ||||
| if (value) { | if (value) { | ||||
| buffer << ", value: " << value->ToString(); | buffer << ", value: " << value->ToString(); | ||||
| @@ -284,11 +284,9 @@ class AbstractTensor : public AbstractUndetermined { | |||||
| AbstractBasePtr Clone() const override; | AbstractBasePtr Clone() const override; | ||||
| AbstractBasePtr Broaden(uint8_t config = 0) const override; | AbstractBasePtr Broaden(uint8_t config = 0) const override; | ||||
| AbstractBasePtr BroadenWithShape() const; | AbstractBasePtr BroadenWithShape() const; | ||||
| AbstractBasePtr Join(const AbstractBasePtr &other) final; | |||||
| AbstractBasePtr Join(const AbstractBasePtr &other); | |||||
| bool operator==(const AbstractTensor &other) const; | bool operator==(const AbstractTensor &other) const; | ||||
| bool operator==(const AbstractBase &other) const override; | bool operator==(const AbstractBase &other) const override; | ||||
| std::string ToString() const override; | std::string ToString() const override; | ||||
| std::size_t hash() const override { | std::size_t hash() const override { | ||||
| auto value = GetValueTrack(); | auto value = GetValueTrack(); | ||||
| @@ -301,6 +299,9 @@ class AbstractTensor : public AbstractUndetermined { | |||||
| } | } | ||||
| return hash_sum; | return hash_sum; | ||||
| } | } | ||||
| protected: | |||||
| bool equal_to(const AbstractTensor &other) const; | |||||
| }; | }; | ||||
| using AbstractTensorPtr = std::shared_ptr<AbstractTensor>; | using AbstractTensorPtr = std::shared_ptr<AbstractTensor>; | ||||
| using AbstractTensorPtrList = std::vector<AbstractTensorPtr>; | using AbstractTensorPtrList = std::vector<AbstractTensorPtr>; | ||||
| @@ -575,42 +576,42 @@ class AbstractRefKey : public AbstractBase { | |||||
| }; | }; | ||||
| using AbstractRefKeyPtr = std::shared_ptr<AbstractRefKey>; | using AbstractRefKeyPtr = std::shared_ptr<AbstractRefKey>; | ||||
| class AbstractRef : public AbstractBase { | |||||
| class AbstractRef : public AbstractTensor { | |||||
| public: | public: | ||||
| AbstractRef(const AbstractBasePtr &ref_key, const AbstractBasePtr &ref_value, bool need_cast = false, | |||||
| TypePtr cast_target = nullptr); | |||||
| AbstractRef(const AbstractBasePtr &ref_key, const AbstractTensorPtr &ref_value); | |||||
| ~AbstractRef() override = default; | ~AbstractRef() override = default; | ||||
| MS_DECLARE_PARENT(AbstractRef, AbstractBase) | |||||
| MS_DECLARE_PARENT(AbstractRef, AbstractTensor) | |||||
| TypePtr BuildType() const override; | TypePtr BuildType() const override; | ||||
| BaseShapePtr BuildShape() const override; | |||||
| bool operator==(const AbstractRef &other) const; | bool operator==(const AbstractRef &other) const; | ||||
| bool operator==(const AbstractBase &other) const override; | bool operator==(const AbstractBase &other) const override; | ||||
| AbstractBasePtr Clone() const override { | AbstractBasePtr Clone() const override { | ||||
| return std::make_shared<AbstractRef>(ref_key_->Clone(), ref_->Clone(), need_cast_, target_type_); | |||||
| auto abs_tensor = AbstractTensor::Clone()->cast<AbstractTensorPtr>(); | |||||
| if (abs_tensor == nullptr) { | |||||
| return nullptr; | |||||
| } | |||||
| return std::make_shared<AbstractRef>(ref_key_->Clone(), abs_tensor); | |||||
| } | } | ||||
| std::string ToString() const override; | std::string ToString() const override; | ||||
| inline AbstractBasePtr ref() const { return ref_; } | |||||
| inline AbstractTensorPtr ref() { return shared_from_base<AbstractTensor>(); } | |||||
| inline AbstractBasePtr ref_key() const { return ref_key_; } | inline AbstractBasePtr ref_key() const { return ref_key_; } | ||||
| inline RefKeyPtr ref_key_value() const { return ref_key_value_; } | inline RefKeyPtr ref_key_value() const { return ref_key_value_; } | ||||
| inline TypePtr target_type() const { return target_type_; } | |||||
| inline bool need_cast() const { return need_cast_; } | |||||
| AbstractBasePtr Broaden(uint8_t config = 0) const override { | AbstractBasePtr Broaden(uint8_t config = 0) const override { | ||||
| // always broaden for ref | // always broaden for ref | ||||
| return std::make_shared<AbstractRef>(ref_key_->Broaden(config), ref_->Broaden(), need_cast_, target_type_); | |||||
| auto abs_tensor = AbstractTensor::Broaden()->cast<AbstractTensorPtr>(); | |||||
| if (abs_tensor == nullptr) { | |||||
| return nullptr; | |||||
| } | |||||
| return std::make_shared<AbstractRef>(ref_key_->Broaden(config), abs_tensor); | |||||
| } | } | ||||
| AbstractBasePtr Join(const AbstractBasePtr &other) override; | AbstractBasePtr Join(const AbstractBasePtr &other) override; | ||||
| std::size_t hash() const override { | std::size_t hash() const override { | ||||
| return ref_->hash() ^ (std::hash<uint32_t>{}(this->tid()) << 1); // ref_key_->hash() ^ | |||||
| return AbstractTensor::hash() ^ (std::hash<uint32_t>{}(this->tid()) << 1); // ref_key_->hash() ^ | |||||
| } | } | ||||
| private: | private: | ||||
| AbstractBasePtr ref_key_; | AbstractBasePtr ref_key_; | ||||
| AbstractBasePtr ref_; | |||||
| // For mix presicion, only float type need to cast to float16 of float32 | |||||
| bool need_cast_; | |||||
| TypePtr target_type_; | |||||
| // cache for ref_key after build value, when value is null, return nullptr. | // cache for ref_key after build value, when value is null, return nullptr. | ||||
| RefKeyPtr ref_key_value_; | RefKeyPtr ref_key_value_; | ||||
| }; | }; | ||||
| @@ -113,17 +113,8 @@ AbstractBasePtr InferImplMakeRef(const AnalysisEnginePtr &, const PrimitivePtr & | |||||
| MS_LOG(EXCEPTION) << "make_ref evaluator requires 3 parameters, while the input size is " << args_spec_list.size() | MS_LOG(EXCEPTION) << "make_ref evaluator requires 3 parameters, while the input size is " << args_spec_list.size() | ||||
| << "."; | << "."; | ||||
| } | } | ||||
| TypePtr type = args_spec_list[0]->GetTypeTrack(); | |||||
| ValuePtr tensor_target_v = args_spec_list[2]->BuildValue(); | |||||
| if (type->type_id() != kObjectTypeRefKey) { | |||||
| MS_LOG(EXCEPTION) << "First input of make_ref should be a RefKey but a " << type->ToString(); | |||||
| } | |||||
| auto need_cast = !tensor_target_v->isa<None>(); | |||||
| if (need_cast && !tensor_target_v->isa<Type>()) { | |||||
| MS_LOG(EXCEPTION) << "Third input of make_ref should be a Type but a " << tensor_target_v->ToString(); | |||||
| } | |||||
| TypePtr cast_target = tensor_target_v->cast<TypePtr>(); | |||||
| return std::make_shared<AbstractRef>(args_spec_list[0], args_spec_list[1], need_cast, cast_target); | |||||
| auto tensor = args_spec_list[1]->cast<abstract::AbstractTensorPtr>(); | |||||
| return std::make_shared<AbstractRef>(args_spec_list[0], tensor); | |||||
| } | } | ||||
| AbstractBasePtr InferImplGetRefKey(const AnalysisEnginePtr &, const PrimitivePtr &, | AbstractBasePtr InferImplGetRefKey(const AnalysisEnginePtr &, const PrimitivePtr &, | ||||
| @@ -88,6 +88,17 @@ std::string Parameter::DebugString(int recursive_level) const { | |||||
| return buffer.str(); | return buffer.str(); | ||||
| } | } | ||||
| ParamInfoPtr Parameter::param_info() const { | |||||
| if (!has_default()) { | |||||
| return nullptr; | |||||
| } | |||||
| auto tensor = default_param()->cast<tensor::MetaTensorPtr>(); | |||||
| if (tensor == nullptr || !tensor->is_parameter()) { | |||||
| return nullptr; | |||||
| } | |||||
| return tensor->param_info(); | |||||
| } | |||||
| std::string ValueNode::ToString() const { | std::string ValueNode::ToString() const { | ||||
| MS_EXCEPTION_IF_NULL(value_); | MS_EXCEPTION_IF_NULL(value_); | ||||
| if (value_->isa<FuncGraph>()) { | if (value_->isa<FuncGraph>()) { | ||||
| @@ -75,7 +75,7 @@ using VarPtr = std::shared_ptr<Var>; | |||||
| class AnfIrVisitor; | class AnfIrVisitor; | ||||
| class ParamInfo; | class ParamInfo; | ||||
| using ParamValuePtr = std::shared_ptr<ParamInfo>; | |||||
| using ParamInfoPtr = std::shared_ptr<ParamInfo>; | |||||
| // AnfNode is the basic class of the IR definition derived from Base. | // AnfNode is the basic class of the IR definition derived from Base. | ||||
| // Only two types of nodes are derived: CNode and ANode. | // Only two types of nodes are derived: CNode and ANode. | ||||
| @@ -288,6 +288,7 @@ class Parameter : public ANode { | |||||
| has_default_ = true; | has_default_ = true; | ||||
| } | } | ||||
| ValuePtr default_param() const { return default_param_; } | ValuePtr default_param() const { return default_param_; } | ||||
| ParamInfoPtr param_info() const; | |||||
| bool operator==(const AnfNode &other) const override { | bool operator==(const AnfNode &other) const override { | ||||
| if (!other.isa<Parameter>()) { | if (!other.isa<Parameter>()) { | ||||
| @@ -94,175 +94,6 @@ bool Slice::operator==(const Type &other) const { | |||||
| std::string Slice::DumpText() const { return ToString(); } | std::string Slice::DumpText() const { return ToString(); } | ||||
| TypePtr UndeterminedType::DeepCopy() const { | |||||
| MS_EXCEPTION_IF_NULL(element_type_); | |||||
| if (IsGeneric()) { | |||||
| return std::make_shared<UndeterminedType>(); | |||||
| } | |||||
| return std::make_shared<UndeterminedType>(element_type_->DeepCopy()); | |||||
| } | |||||
| std::string UndeterminedType::ToReprString() const { | |||||
| if (element_type_ == nullptr) { | |||||
| return "Undetermined"; | |||||
| } | |||||
| return "Undetermined[" + element_type_->ToReprString() + "]"; | |||||
| } | |||||
| std::string UndeterminedType::ToString() const { | |||||
| if (element_type_ == nullptr) { | |||||
| return "Undetermined"; | |||||
| } | |||||
| return "Undetermined[" + element_type_->ToString() + "]"; | |||||
| } | |||||
| std::string UndeterminedType::DumpText() const { | |||||
| if (element_type_ == nullptr) { | |||||
| return "Undetermined"; | |||||
| } | |||||
| return "Undetermined[" + element_type_->DumpText() + "]"; | |||||
| } | |||||
| bool UndeterminedType::operator==(const Type &other) const { | |||||
| if (!IsSameObjectType(*this, other)) { | |||||
| return false; | |||||
| } | |||||
| auto other_elem_type = static_cast<const UndeterminedType &>(other).element_type_; | |||||
| if (element_type_ == nullptr && other_elem_type == nullptr) { | |||||
| return true; | |||||
| } else if (element_type_ == nullptr || other_elem_type == nullptr) { | |||||
| return false; | |||||
| } | |||||
| return *element_type_ == *other_elem_type; | |||||
| } | |||||
| TypePtr TensorType::DeepCopy() const { | |||||
| MS_EXCEPTION_IF_NULL(element_type_); | |||||
| if (IsGeneric()) { | |||||
| return std::make_shared<TensorType>(); | |||||
| } | |||||
| return std::make_shared<TensorType>(element_type_->DeepCopy()); | |||||
| } | |||||
| std::string TensorType::ToReprString() const { | |||||
| if (element_type_ == nullptr) { | |||||
| return "tensor"; | |||||
| } | |||||
| return "tensor[" + element_type_->ToReprString() + "]"; | |||||
| } | |||||
| std::string TensorType::ToString() const { | |||||
| if (element_type_ == nullptr) { | |||||
| return "Tensor"; | |||||
| } | |||||
| return "Tensor[" + element_type_->ToString() + "]"; | |||||
| } | |||||
| std::string TensorType::DumpText() const { | |||||
| if (element_type_ == nullptr) { | |||||
| return "Tensor"; | |||||
| } | |||||
| return "Tensor(" + element_type_->DumpText() + ")"; | |||||
| } | |||||
| bool TensorType::operator==(const Type &other) const { | |||||
| if (!IsSameObjectType(*this, other)) { | |||||
| return false; | |||||
| } | |||||
| auto other_elem_type = static_cast<const TensorType &>(other).element_type_; | |||||
| // When element_type_ = nullptr, which means any type of Array. | |||||
| if (element_type_ == nullptr && other_elem_type == nullptr) { | |||||
| return true; | |||||
| } else if (element_type_ == nullptr || other_elem_type == nullptr) { | |||||
| return false; | |||||
| } | |||||
| return *element_type_ == *other_elem_type; | |||||
| } | |||||
| TypePtr RowTensorType::DeepCopy() const { | |||||
| MS_EXCEPTION_IF_NULL(element_type_); | |||||
| if (IsGeneric()) { | |||||
| return std::make_shared<RowTensorType>(); | |||||
| } | |||||
| return std::make_shared<RowTensorType>(element_type_->DeepCopy()); | |||||
| } | |||||
| std::string RowTensorType::ToReprString() const { | |||||
| if (element_type_ == nullptr) { | |||||
| return "RowTensor"; | |||||
| } | |||||
| return "RowTensor[" + element_type_->ToReprString() + "]"; | |||||
| } | |||||
| std::string RowTensorType::ToString() const { | |||||
| if (element_type_ == nullptr) { | |||||
| return "RowTensor"; | |||||
| } | |||||
| return "RowTensor[" + element_type_->ToString() + "]"; | |||||
| } | |||||
| std::string RowTensorType::DumpText() const { | |||||
| if (element_type_ == nullptr) { | |||||
| return "RowTensor"; | |||||
| } | |||||
| return "RowTensor[" + element_type_->DumpText() + "]"; | |||||
| } | |||||
| bool RowTensorType::operator==(const Type &other) const { | |||||
| if (!IsSameObjectType(*this, other)) { | |||||
| return false; | |||||
| } | |||||
| auto other_elem_type = static_cast<const RowTensorType &>(other).element_type_; | |||||
| if (element_type_ == nullptr && other_elem_type == nullptr) { | |||||
| return true; | |||||
| } else if (element_type_ == nullptr || other_elem_type == nullptr) { | |||||
| return false; | |||||
| } | |||||
| return *element_type_ == *other_elem_type; | |||||
| } | |||||
| TypePtr SparseTensorType::DeepCopy() const { | |||||
| MS_EXCEPTION_IF_NULL(element_type_); | |||||
| if (IsGeneric()) { | |||||
| return std::make_shared<SparseTensorType>(); | |||||
| } | |||||
| return std::make_shared<SparseTensorType>(element_type_->DeepCopy()); | |||||
| } | |||||
| std::string SparseTensorType::ToReprString() const { | |||||
| if (element_type_ == nullptr) { | |||||
| return "SparseTensor"; | |||||
| } | |||||
| return "SparseTensor[" + element_type_->ToReprString() + "]"; | |||||
| } | |||||
| std::string SparseTensorType::ToString() const { | |||||
| if (element_type_ == nullptr) { | |||||
| return "SparseTensor"; | |||||
| } | |||||
| return "SparseTensor[" + element_type_->ToString() + "]"; | |||||
| } | |||||
| std::string SparseTensorType::DumpText() const { | |||||
| if (element_type_ == nullptr) { | |||||
| return "SparseTensor"; | |||||
| } | |||||
| return "SparseTensor[" + element_type_->DumpText() + "]"; | |||||
| } | |||||
| bool SparseTensorType::operator==(const Type &other) const { | |||||
| if (!IsSameObjectType(*this, other)) { | |||||
| return false; | |||||
| } | |||||
| auto other_elem_type = static_cast<const SparseTensorType &>(other).element_type_; | |||||
| if (element_type_ == nullptr && other_elem_type == nullptr) { | |||||
| return true; | |||||
| } else if (element_type_ == nullptr || other_elem_type == nullptr) { | |||||
| return false; | |||||
| } | |||||
| return *element_type_ == *other_elem_type; | |||||
| } | |||||
| Function::Function() : Object(kObjectTypeFunction) { | Function::Function() : Object(kObjectTypeFunction) { | ||||
| args_ = std::vector<TypePtr>(); | args_ = std::vector<TypePtr>(); | ||||
| retval_ = nullptr; | retval_ = nullptr; | ||||
| @@ -372,4 +203,8 @@ std::ostream &operator<<(std::ostream &os, const std::shared_ptr<Problem> proble | |||||
| os << problem->ToString(); | os << problem->ToString(); | ||||
| return os; | return os; | ||||
| } | } | ||||
| const TypePtr kTensorTypeFP16 = std::make_shared<TensorType>(std::make_shared<Float>(16)); | |||||
| const TypePtr kTensorTypeFP32 = std::make_shared<TensorType>(std::make_shared<Float>(32)); | |||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -32,10 +32,11 @@ | |||||
| #include "ir/named.h" | #include "ir/named.h" | ||||
| #include "ir/dtype/type.h" | #include "ir/dtype/type.h" | ||||
| #include "ir/dtype/ref.h" | |||||
| #include "ir/dtype/number.h" | #include "ir/dtype/number.h" | ||||
| #include "ir/dtype/container.h" | #include "ir/dtype/container.h" | ||||
| #include "ir/dtype/empty.h" | #include "ir/dtype/empty.h" | ||||
| #include "ir/dtype/tensor_type.h" | |||||
| #include "ir/dtype/ref.h" | |||||
| /* namespace to support intermediate representation definition */ | /* namespace to support intermediate representation definition */ | ||||
| namespace mindspore { | namespace mindspore { | ||||
| @@ -108,98 +109,6 @@ class Slice : public Object { | |||||
| }; | }; | ||||
| using SlicePtr = std::shared_ptr<Slice>; | using SlicePtr = std::shared_ptr<Slice>; | ||||
| class UndeterminedType : public Object { | |||||
| public: | |||||
| UndeterminedType() : Object(kObjectTypeUndeterminedType) {} | |||||
| explicit UndeterminedType(const TypePtr &ele) | |||||
| : Object(kObjectTypeUndeterminedType, kMetaTypeObject, false), element_type_(ele) {} | |||||
| ~UndeterminedType() override = default; | |||||
| MS_DECLARE_PARENT(UndeterminedType, Object) | |||||
| TypeId generic_type_id() const override { return kObjectTypeUndeterminedType; } | |||||
| const TypePtr element() const { return element_type_; } | |||||
| void set_element(const TypePtr &element_type) { element_type_ = element_type; } | |||||
| TypePtr DeepCopy() const override; | |||||
| std::string ToString() const override; | |||||
| std::string ToReprString() const override; | |||||
| std::string DumpText() const override; | |||||
| bool operator==(const Type &other) const override; | |||||
| protected: | |||||
| TypePtr element_type_; | |||||
| }; | |||||
| using MetaTensorTypePtr = std::shared_ptr<UndeterminedType>; | |||||
| class TensorType : public Object { | |||||
| public: | |||||
| TensorType() : Object(kObjectTypeTensorType, kObjectTypeUndeterminedType) {} | |||||
| explicit TensorType(const TypePtr &ele) | |||||
| : Object(kObjectTypeTensorType, kObjectTypeUndeterminedType, false), element_type_(ele) {} | |||||
| ~TensorType() override = default; | |||||
| MS_DECLARE_PARENT(TensorType, Object) | |||||
| TypeId generic_type_id() const override { return kObjectTypeTensorType; } | |||||
| const TypePtr element() const { return element_type_; } | |||||
| void set_element(const TypePtr &element_type) { element_type_ = element_type; } | |||||
| TypePtr DeepCopy() const override; | |||||
| std::string ToString() const override; | |||||
| std::string ToReprString() const override; | |||||
| std::string DumpText() const override; | |||||
| bool operator==(const Type &other) const override; | |||||
| private: | |||||
| TypePtr element_type_; | |||||
| }; | |||||
| using TensorTypePtr = std::shared_ptr<TensorType>; | |||||
| class RowTensorType : public Object { | |||||
| public: | |||||
| RowTensorType() : Object(kObjectTypeRowTensorType, kObjectTypeUndeterminedType) {} | |||||
| explicit RowTensorType(const TypePtr &ele) | |||||
| : Object(kObjectTypeRowTensorType, kObjectTypeUndeterminedType, false), element_type_(ele) {} | |||||
| ~RowTensorType() override = default; | |||||
| MS_DECLARE_PARENT(RowTensorType, Object) | |||||
| TypeId generic_type_id() const override { return kObjectTypeRowTensorType; } | |||||
| const TypePtr element() const { return element_type_; } | |||||
| void set_element(const TypePtr &element_type) { element_type_ = element_type; } | |||||
| TypePtr DeepCopy() const override; | |||||
| std::string ToString() const override; | |||||
| std::string ToReprString() const override; | |||||
| std::string DumpText() const override; | |||||
| bool operator==(const Type &other) const override; | |||||
| private: | |||||
| TypePtr element_type_; | |||||
| }; | |||||
| using RowTensorTypePtr = std::shared_ptr<RowTensorType>; | |||||
| class SparseTensorType : public Object { | |||||
| public: | |||||
| SparseTensorType() : Object(kObjectTypeSparseTensorType, kObjectTypeUndeterminedType) {} | |||||
| explicit SparseTensorType(const TypePtr &ele) | |||||
| : Object(kObjectTypeSparseTensorType, kObjectTypeUndeterminedType, false), element_type_(ele) {} | |||||
| ~SparseTensorType() override = default; | |||||
| MS_DECLARE_PARENT(SparseTensorType, Object) | |||||
| TypeId generic_type_id() const override { return kObjectTypeSparseTensorType; } | |||||
| const TypePtr element() const { return element_type_; } | |||||
| void set_element(const TypePtr &element_type) { element_type_ = element_type; } | |||||
| TypePtr DeepCopy() const override; | |||||
| std::string ToString() const override; | |||||
| std::string ToReprString() const override; | |||||
| std::string DumpText() const override; | |||||
| bool operator==(const Type &other) const override; | |||||
| private: | |||||
| TypePtr element_type_; | |||||
| }; | |||||
| using SparseTensorTypePtr = std::shared_ptr<SparseTensorType>; | |||||
| class Function : public Object { | class Function : public Object { | ||||
| public: | public: | ||||
| Function(); | Function(); | ||||
| @@ -353,6 +262,9 @@ extern const TypePtr kDict; | |||||
| extern const TypePtr kSlice; | extern const TypePtr kSlice; | ||||
| extern const TypePtr kKeyword; | extern const TypePtr kKeyword; | ||||
| extern const TypePtr kTensorType; | extern const TypePtr kTensorType; | ||||
| extern const TypePtr kTensorTypeFP16; | |||||
| extern const TypePtr kTensorTypeFP32; | |||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CORE_IR_DTYPE_H_ | #endif // MINDSPORE_CORE_IR_DTYPE_H_ | ||||
| @@ -68,6 +68,8 @@ class Number : public Object { | |||||
| const int nbits_; | const int nbits_; | ||||
| }; | }; | ||||
| using NumberPtr = std::shared_ptr<Number>; | |||||
| // Bool | // Bool | ||||
| class Bool : public Number { | class Bool : public Number { | ||||
| public: | public: | ||||
| @@ -19,15 +19,15 @@ | |||||
| #include <cstdlib> | #include <cstdlib> | ||||
| #include <algorithm> | #include <algorithm> | ||||
| #include "utils/log_adapter.h" | #include "utils/log_adapter.h" | ||||
| #include "ir/dtype/tensor_type.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| TypePtr RefType::DeepCopy() const { | TypePtr RefType::DeepCopy() const { | ||||
| if (IsGeneric()) { | if (IsGeneric()) { | ||||
| return std::make_shared<RefType>(); | return std::make_shared<RefType>(); | ||||
| } else { | } else { | ||||
| auto subtype = subtype_->DeepCopy(); | |||||
| auto subtype_origin = subtype_origin_->DeepCopy(); | |||||
| return std::make_shared<RefType>(subtype, subtype_origin); | |||||
| auto subtype = TensorType::DeepCopy()->cast<TensorTypePtr>(); | |||||
| return std::make_shared<RefType>(subtype); | |||||
| } | } | ||||
| } | } | ||||
| @@ -39,7 +39,7 @@ std::string RefType::DumpText() const { | |||||
| buffer << "Ref"; | buffer << "Ref"; | ||||
| } else { | } else { | ||||
| buffer << "Ref["; | buffer << "Ref["; | ||||
| buffer << subtype_->DumpText() << "]"; | |||||
| buffer << TensorType::DumpText() << "]"; | |||||
| } | } | ||||
| return buffer.str(); | return buffer.str(); | ||||
| } | } | ||||
| @@ -17,21 +17,13 @@ | |||||
| #ifndef MINDSPORE_CORE_IR_DTYPE_REF_H_ | #ifndef MINDSPORE_CORE_IR_DTYPE_REF_H_ | ||||
| #define MINDSPORE_CORE_IR_DTYPE_REF_H_ | #define MINDSPORE_CORE_IR_DTYPE_REF_H_ | ||||
| #include <cstddef> | |||||
| #include <iostream> | |||||
| #include <initializer_list> | |||||
| #include <map> | |||||
| #include <memory> | #include <memory> | ||||
| #include <utility> | |||||
| #include <sstream> | |||||
| #include <string> | #include <string> | ||||
| #include <vector> | |||||
| #include <type_traits> | |||||
| #include <unordered_map> | |||||
| #include <algorithm> | |||||
| #include "base/base.h" | #include "base/base.h" | ||||
| #include "ir/named.h" | #include "ir/named.h" | ||||
| #include "ir/dtype/type.h" | #include "ir/dtype/type.h" | ||||
| #include "ir/dtype/tensor_type.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| // TypeRefKey type | // TypeRefKey type | ||||
| @@ -48,23 +40,16 @@ class RefKeyType : public Object { | |||||
| }; | }; | ||||
| // TypeRef type | // TypeRef type | ||||
| class RefType : public Object { | |||||
| class RefType : public TensorType { | |||||
| public: | public: | ||||
| RefType() : Object(kObjectTypeRef) {} | |||||
| RefType(const TypePtr &subtype, const TypePtr &subtype_origin) | |||||
| : Object(kObjectTypeRef, false), subtype_(subtype), subtype_origin_(subtype_origin) {} | |||||
| RefType() : TensorType() {} | |||||
| explicit RefType(const TensorTypePtr &subtype) : TensorType(subtype->element()) {} | |||||
| ~RefType() override {} | ~RefType() override {} | ||||
| MS_DECLARE_PARENT(RefType, Object) | |||||
| MS_DECLARE_PARENT(RefType, TensorType) | |||||
| TypePtr subtype() const { return subtype_; } | |||||
| TypeId generic_type_id() const override { return kObjectTypeRef; } | |||||
| TypePtr DeepCopy() const override; | TypePtr DeepCopy() const override; | ||||
| std::string ToString() const override; | std::string ToString() const override; | ||||
| std::string DumpText() const override; | std::string DumpText() const override; | ||||
| private: | |||||
| TypePtr subtype_; | |||||
| TypePtr subtype_origin_; | |||||
| }; | }; | ||||
| using RefTypePtr = std::shared_ptr<RefType>; | using RefTypePtr = std::shared_ptr<RefType>; | ||||
| @@ -0,0 +1,194 @@ | |||||
| /** | |||||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "ir/dtype/tensor_type.h" | |||||
| #include <string> | |||||
| #include <cstdlib> | |||||
| #include <algorithm> | |||||
| #include "utils/log_adapter.h" | |||||
| namespace mindspore { | |||||
| TypePtr UndeterminedType::DeepCopy() const { | |||||
| MS_EXCEPTION_IF_NULL(element_type_); | |||||
| if (IsGeneric()) { | |||||
| return std::make_shared<UndeterminedType>(); | |||||
| } | |||||
| return std::make_shared<UndeterminedType>(element_type_->DeepCopy()); | |||||
| } | |||||
| std::string UndeterminedType::ToReprString() const { | |||||
| if (element_type_ == nullptr) { | |||||
| return "Undetermined"; | |||||
| } | |||||
| return "Undetermined[" + element_type_->ToReprString() + "]"; | |||||
| } | |||||
| std::string UndeterminedType::ToString() const { | |||||
| if (element_type_ == nullptr) { | |||||
| return "Undetermined"; | |||||
| } | |||||
| return "Undetermined[" + element_type_->ToString() + "]"; | |||||
| } | |||||
| std::string UndeterminedType::DumpText() const { | |||||
| if (element_type_ == nullptr) { | |||||
| return "Undetermined"; | |||||
| } | |||||
| return "Undetermined[" + element_type_->DumpText() + "]"; | |||||
| } | |||||
| bool UndeterminedType::operator==(const Type &other) const { | |||||
| if (!IsSameObjectType(*this, other)) { | |||||
| return false; | |||||
| } | |||||
| auto other_elem_type = static_cast<const UndeterminedType &>(other).element_type_; | |||||
| if (element_type_ == nullptr && other_elem_type == nullptr) { | |||||
| return true; | |||||
| } else if (element_type_ == nullptr || other_elem_type == nullptr) { | |||||
| return false; | |||||
| } | |||||
| return *element_type_ == *other_elem_type; | |||||
| } | |||||
| TypePtr TensorType::DeepCopy() const { | |||||
| MS_EXCEPTION_IF_NULL(element_type_); | |||||
| if (IsGeneric()) { | |||||
| return std::make_shared<TensorType>(); | |||||
| } | |||||
| return std::make_shared<TensorType>(element_type_->DeepCopy()); | |||||
| } | |||||
| std::string TensorType::ToReprString() const { | |||||
| if (element_type_ == nullptr) { | |||||
| return "tensor"; | |||||
| } | |||||
| return "tensor[" + element_type_->ToReprString() + "]"; | |||||
| } | |||||
| std::string TensorType::ToString() const { | |||||
| if (element_type_ == nullptr) { | |||||
| return "Tensor"; | |||||
| } | |||||
| return "Tensor[" + element_type_->ToString() + "]"; | |||||
| } | |||||
| std::string TensorType::DumpText() const { | |||||
| if (element_type_ == nullptr) { | |||||
| return "Tensor"; | |||||
| } | |||||
| return "Tensor(" + element_type_->DumpText() + ")"; | |||||
| } | |||||
| bool TensorType::operator==(const Type &other) const { | |||||
| if (!IsSameObjectType(*this, other)) { | |||||
| return false; | |||||
| } | |||||
| auto other_elem_type = static_cast<const TensorType &>(other).element_type_; | |||||
| // When element_type_ = nullptr, which means any type of Array. | |||||
| if (element_type_ == nullptr && other_elem_type == nullptr) { | |||||
| return true; | |||||
| } else if (element_type_ == nullptr || other_elem_type == nullptr) { | |||||
| return false; | |||||
| } | |||||
| return *element_type_ == *other_elem_type; | |||||
| } | |||||
| TypePtr RowTensorType::DeepCopy() const { | |||||
| MS_EXCEPTION_IF_NULL(element_type_); | |||||
| if (IsGeneric()) { | |||||
| return std::make_shared<RowTensorType>(); | |||||
| } | |||||
| return std::make_shared<RowTensorType>(element_type_->DeepCopy()); | |||||
| } | |||||
| std::string RowTensorType::ToReprString() const { | |||||
| if (element_type_ == nullptr) { | |||||
| return "RowTensor"; | |||||
| } | |||||
| return "RowTensor[" + element_type_->ToReprString() + "]"; | |||||
| } | |||||
| std::string RowTensorType::ToString() const { | |||||
| if (element_type_ == nullptr) { | |||||
| return "RowTensor"; | |||||
| } | |||||
| return "RowTensor[" + element_type_->ToString() + "]"; | |||||
| } | |||||
| std::string RowTensorType::DumpText() const { | |||||
| if (element_type_ == nullptr) { | |||||
| return "RowTensor"; | |||||
| } | |||||
| return "RowTensor[" + element_type_->DumpText() + "]"; | |||||
| } | |||||
| bool RowTensorType::operator==(const Type &other) const { | |||||
| if (!IsSameObjectType(*this, other)) { | |||||
| return false; | |||||
| } | |||||
| auto other_elem_type = static_cast<const RowTensorType &>(other).element_type_; | |||||
| if (element_type_ == nullptr && other_elem_type == nullptr) { | |||||
| return true; | |||||
| } else if (element_type_ == nullptr || other_elem_type == nullptr) { | |||||
| return false; | |||||
| } | |||||
| return *element_type_ == *other_elem_type; | |||||
| } | |||||
| TypePtr SparseTensorType::DeepCopy() const { | |||||
| MS_EXCEPTION_IF_NULL(element_type_); | |||||
| if (IsGeneric()) { | |||||
| return std::make_shared<SparseTensorType>(); | |||||
| } | |||||
| return std::make_shared<SparseTensorType>(element_type_->DeepCopy()); | |||||
| } | |||||
| std::string SparseTensorType::ToReprString() const { | |||||
| if (element_type_ == nullptr) { | |||||
| return "SparseTensor"; | |||||
| } | |||||
| return "SparseTensor[" + element_type_->ToReprString() + "]"; | |||||
| } | |||||
| std::string SparseTensorType::ToString() const { | |||||
| if (element_type_ == nullptr) { | |||||
| return "SparseTensor"; | |||||
| } | |||||
| return "SparseTensor[" + element_type_->ToString() + "]"; | |||||
| } | |||||
| std::string SparseTensorType::DumpText() const { | |||||
| if (element_type_ == nullptr) { | |||||
| return "SparseTensor"; | |||||
| } | |||||
| return "SparseTensor[" + element_type_->DumpText() + "]"; | |||||
| } | |||||
| bool SparseTensorType::operator==(const Type &other) const { | |||||
| if (!IsSameObjectType(*this, other)) { | |||||
| return false; | |||||
| } | |||||
| auto other_elem_type = static_cast<const SparseTensorType &>(other).element_type_; | |||||
| if (element_type_ == nullptr && other_elem_type == nullptr) { | |||||
| return true; | |||||
| } else if (element_type_ == nullptr || other_elem_type == nullptr) { | |||||
| return false; | |||||
| } | |||||
| return *element_type_ == *other_elem_type; | |||||
| } | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,132 @@ | |||||
| /** | |||||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CORE_IR_DTYPE_TENSORTYPE_H_ | |||||
| #define MINDSPORE_CORE_IR_DTYPE_TENSORTYPE_H_ | |||||
| #include <cstddef> | |||||
| #include <iostream> | |||||
| #include <initializer_list> | |||||
| #include <map> | |||||
| #include <memory> | |||||
| #include <utility> | |||||
| #include <sstream> | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include <type_traits> | |||||
| #include <unordered_map> | |||||
| #include <algorithm> | |||||
| #include "base/base.h" | |||||
| #include "ir/named.h" | |||||
| #include "ir/dtype/type.h" | |||||
| namespace mindspore { | |||||
| class UndeterminedType : public Object { | |||||
| public: | |||||
| UndeterminedType() : Object(kObjectTypeUndeterminedType) {} | |||||
| explicit UndeterminedType(const TypePtr &ele) | |||||
| : Object(kObjectTypeUndeterminedType, kMetaTypeObject, false), element_type_(ele) {} | |||||
| ~UndeterminedType() override = default; | |||||
| MS_DECLARE_PARENT(UndeterminedType, Object) | |||||
| TypeId generic_type_id() const override { return kObjectTypeUndeterminedType; } | |||||
| const TypePtr element() const { return element_type_; } | |||||
| void set_element(const TypePtr &element_type) { element_type_ = element_type; } | |||||
| TypePtr DeepCopy() const override; | |||||
| std::string ToString() const override; | |||||
| std::string ToReprString() const override; | |||||
| std::string DumpText() const override; | |||||
| bool operator==(const Type &other) const override; | |||||
| protected: | |||||
| TypePtr element_type_; | |||||
| }; | |||||
| using MetaTensorTypePtr = std::shared_ptr<UndeterminedType>; | |||||
| class TensorType : public Object { | |||||
| public: | |||||
| TensorType() : Object(kObjectTypeTensorType, kObjectTypeUndeterminedType) {} | |||||
| explicit TensorType(const TypePtr &ele) | |||||
| : Object(kObjectTypeTensorType, kObjectTypeUndeterminedType, false), element_type_(ele) {} | |||||
| ~TensorType() override = default; | |||||
| MS_DECLARE_PARENT(TensorType, Object) | |||||
| TypeId generic_type_id() const override { return kObjectTypeTensorType; } | |||||
| const TypePtr element() const { return element_type_; } | |||||
| void set_element(const TypePtr &element_type) { element_type_ = element_type; } | |||||
| TypePtr DeepCopy() const override; | |||||
| std::string ToString() const override; | |||||
| std::string ToReprString() const override; | |||||
| std::string DumpText() const override; | |||||
| bool operator==(const Type &other) const override; | |||||
| private: | |||||
| TypePtr element_type_; | |||||
| }; | |||||
| using TensorTypePtr = std::shared_ptr<TensorType>; | |||||
| class RowTensorType : public Object { | |||||
| public: | |||||
| RowTensorType() : Object(kObjectTypeRowTensorType, kObjectTypeUndeterminedType) {} | |||||
| explicit RowTensorType(const TypePtr &ele) | |||||
| : Object(kObjectTypeRowTensorType, kObjectTypeUndeterminedType, false), element_type_(ele) {} | |||||
| ~RowTensorType() override = default; | |||||
| MS_DECLARE_PARENT(RowTensorType, Object) | |||||
| TypeId generic_type_id() const override { return kObjectTypeRowTensorType; } | |||||
| const TypePtr element() const { return element_type_; } | |||||
| void set_element(const TypePtr &element_type) { element_type_ = element_type; } | |||||
| TypePtr DeepCopy() const override; | |||||
| std::string ToString() const override; | |||||
| std::string ToReprString() const override; | |||||
| std::string DumpText() const override; | |||||
| bool operator==(const Type &other) const override; | |||||
| private: | |||||
| TypePtr element_type_; | |||||
| }; | |||||
| using RowTensorTypePtr = std::shared_ptr<RowTensorType>; | |||||
| class SparseTensorType : public Object { | |||||
| public: | |||||
| SparseTensorType() : Object(kObjectTypeSparseTensorType, kObjectTypeUndeterminedType) {} | |||||
| explicit SparseTensorType(const TypePtr &ele) | |||||
| : Object(kObjectTypeSparseTensorType, kObjectTypeUndeterminedType, false), element_type_(ele) {} | |||||
| ~SparseTensorType() override = default; | |||||
| MS_DECLARE_PARENT(SparseTensorType, Object) | |||||
| TypeId generic_type_id() const override { return kObjectTypeSparseTensorType; } | |||||
| const TypePtr element() const { return element_type_; } | |||||
| void set_element(const TypePtr &element_type) { element_type_ = element_type; } | |||||
| TypePtr DeepCopy() const override; | |||||
| std::string ToString() const override; | |||||
| std::string ToReprString() const override; | |||||
| std::string DumpText() const override; | |||||
| bool operator==(const Type &other) const override; | |||||
| private: | |||||
| TypePtr element_type_; | |||||
| }; | |||||
| using SparseTensorTypePtr = std::shared_ptr<SparseTensorType>; | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CORE_IR_DTYPE_TENSORTYPE_H_ | |||||
| @@ -332,14 +332,11 @@ class FuncGraph : public FuncGraphBase { | |||||
| const std::vector<AnfNodePtr> ¶mter_obj_nodes() const { return paramter_obj_nodes_; } | const std::vector<AnfNodePtr> ¶mter_obj_nodes() const { return paramter_obj_nodes_; } | ||||
| void add_parameter_obj_node(const AnfNodePtr &p); | void add_parameter_obj_node(const AnfNodePtr &p); | ||||
| std::unordered_map<AnfNodePtr, AnfNodePtr> &make_ref_params() { return make_ref_params_; } | |||||
| std::unordered_map<std::string, ValuePtr> attrs_; | std::unordered_map<std::string, ValuePtr> attrs_; | ||||
| std::vector<BaseShapePtr> joined_shapes_; | std::vector<BaseShapePtr> joined_shapes_; | ||||
| std::unordered_map<std::string, FuncGraphTransform> transforms_; | std::unordered_map<std::string, FuncGraphTransform> transforms_; | ||||
| // parameter default value | // parameter default value | ||||
| std::map<std::string, AnfNodePtr> parameter_default_value_; | std::map<std::string, AnfNodePtr> parameter_default_value_; | ||||
| std::unordered_map<AnfNodePtr, AnfNodePtr> make_ref_params_; | |||||
| size_t seen_; | size_t seen_; | ||||
| std::list<CNodePtr> GetOrderedCnodes(); | std::list<CNodePtr> GetOrderedCnodes(); | ||||
| @@ -23,6 +23,7 @@ | |||||
| #include <string> | #include <string> | ||||
| #include "base/base.h" | #include "base/base.h" | ||||
| #include "ir/param_info.h" | |||||
| #include "ir/dtype.h" | #include "ir/dtype.h" | ||||
| #include "utils/convert_utils_base.h" | #include "utils/convert_utils_base.h" | ||||
| #include "utils/hashing.h" | #include "utils/hashing.h" | ||||
| @@ -163,6 +164,15 @@ class MetaTensor : public Value { | |||||
| return false; | return false; | ||||
| } | } | ||||
| } | } | ||||
| // Get tensor's param_info info. | |||||
| ParamInfoPtr param_info() const { return param_info_; } | |||||
| bool is_parameter() const { return is_parameter_; } | |||||
| // Set tensor's param_info info. | |||||
| void set_param_info(const ParamInfoPtr ¶m_info) { | |||||
| is_parameter_ = true; | |||||
| param_info_ = param_info; | |||||
| } | |||||
| protected: | protected: | ||||
| // brief Data type of the tensor. | // brief Data type of the tensor. | ||||
| @@ -184,6 +194,9 @@ class MetaTensor : public Value { | |||||
| // | // | ||||
| // Includes the format and data type of a tensor on device. | // Includes the format and data type of a tensor on device. | ||||
| DeviceInfo device_info_; | DeviceInfo device_info_; | ||||
| bool is_parameter_{false}; | |||||
| ParamInfoPtr param_info_{nullptr}; | |||||
| }; | }; | ||||
| using MetaTensorPtr = std::shared_ptr<MetaTensor>; | using MetaTensorPtr = std::shared_ptr<MetaTensor>; | ||||
| @@ -34,7 +34,16 @@ abstract::AbstractBasePtr MetaTensor::ToAbstract() { | |||||
| } | } | ||||
| auto tensor_shape = tens->shape(); | auto tensor_shape = tens->shape(); | ||||
| auto abs_tensor = std::make_shared<abstract::AbstractTensor>(dtype, tensor_shape); | auto abs_tensor = std::make_shared<abstract::AbstractTensor>(dtype, tensor_shape); | ||||
| abs_tensor->set_value(shared_from_base<MetaTensor>()); | |||||
| // if is parameter always no value. | |||||
| if (is_parameter()) { | |||||
| auto param_name = param_info()->name(); | |||||
| auto ref_key = std::make_shared<RefKey>(param_name); | |||||
| auto abs_ref_key = ref_key->ToAbstract(); | |||||
| abs_tensor = std::make_shared<abstract::AbstractRef>(abs_ref_key, abs_tensor); | |||||
| } else { | |||||
| abs_tensor->set_value(shared_from_base<MetaTensor>()); | |||||
| } | |||||
| return abs_tensor; | return abs_tensor; | ||||
| } | } | ||||
| @@ -62,6 +62,21 @@ class Named : public Value { | |||||
| }; | }; | ||||
| using NamedPtr = std::shared_ptr<Named>; | using NamedPtr = std::shared_ptr<Named>; | ||||
| struct NamedHasher { | |||||
| std::size_t operator()(NamedPtr const &name) const { | |||||
| std::size_t hash = name->Hash(); | |||||
| return hash; | |||||
| } | |||||
| }; | |||||
| struct NamedEqual { | |||||
| bool operator()(NamedPtr const &t1, NamedPtr const &t2) const { | |||||
| MS_EXCEPTION_IF_NULL(t1); | |||||
| MS_EXCEPTION_IF_NULL(t2); | |||||
| return *t1 == *t2; | |||||
| } | |||||
| }; | |||||
| class None : public Named { | class None : public Named { | ||||
| public: | public: | ||||
| None() : Named("None") {} | None() : Named("None") {} | ||||
| @@ -21,10 +21,13 @@ | |||||
| #include <memory> | #include <memory> | ||||
| #include <string> | #include <string> | ||||
| #include <vector> | #include <vector> | ||||
| #include "ir/anf.h" | |||||
| #include "ir/tensor.h" | |||||
| #include "ir/dtype.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| class ParamInfo; | |||||
| using ParamInfoPtr = std::shared_ptr<ParamInfo>; | |||||
| class ParamInfo { | class ParamInfo { | ||||
| public: | public: | ||||
| ParamInfo() {} | ParamInfo() {} | ||||
| @@ -55,7 +58,7 @@ class ParamInfo { | |||||
| int32_t cloned_index() const { return cloned_index_; } | int32_t cloned_index() const { return cloned_index_; } | ||||
| // Make a cloned parameter and update clone info. | // Make a cloned parameter and update clone info. | ||||
| ParamValuePtr Clone() { | |||||
| ParamInfoPtr Clone() { | |||||
| static std::atomic<int32_t> parameter_cloned_index{1}; | static std::atomic<int32_t> parameter_cloned_index{1}; | ||||
| int32_t index = parameter_cloned_index.fetch_add(1, std::memory_order_relaxed); | int32_t index = parameter_cloned_index.fetch_add(1, std::memory_order_relaxed); | ||||
| auto clone = std::make_shared<ParamInfo>(*this); | auto clone = std::make_shared<ParamInfo>(*this); | ||||
| @@ -467,6 +467,7 @@ Tensor &Tensor::AssignValue(const Tensor &tensor) { | |||||
| } | } | ||||
| return *this; | return *this; | ||||
| } | } | ||||
| abstract::AbstractBasePtr Tensor::ToAbstract() { | abstract::AbstractBasePtr Tensor::ToAbstract() { | ||||
| auto tens = shared_from_base<Tensor>(); | auto tens = shared_from_base<Tensor>(); | ||||
| auto dtype = tens->Dtype(); | auto dtype = tens->Dtype(); | ||||
| @@ -475,7 +476,15 @@ abstract::AbstractBasePtr Tensor::ToAbstract() { | |||||
| } | } | ||||
| auto tensor_shape = tens->shape(); | auto tensor_shape = tens->shape(); | ||||
| auto abs_tensor = std::make_shared<abstract::AbstractTensor>(dtype, tensor_shape); | auto abs_tensor = std::make_shared<abstract::AbstractTensor>(dtype, tensor_shape); | ||||
| abs_tensor->set_value(shared_from_base<Tensor>()); | |||||
| // if is parameter always no value. | |||||
| if (is_parameter()) { | |||||
| auto param_name = param_info()->name(); | |||||
| auto ref_key = std::make_shared<RefKey>(param_name); | |||||
| auto abs_ref_key = ref_key->ToAbstract(); | |||||
| abs_tensor = std::make_shared<abstract::AbstractRef>(abs_ref_key, abs_tensor); | |||||
| } else { | |||||
| abs_tensor->set_value(shared_from_base<Tensor>()); | |||||
| } | |||||
| return abs_tensor; | return abs_tensor; | ||||
| } | } | ||||
| @@ -200,16 +200,6 @@ bool StringImm::operator==(const Value &other) const { | |||||
| } | } | ||||
| bool StringImm::operator==(const StringImm &other) const { return str_ == other.str_; } | bool StringImm::operator==(const StringImm &other) const { return str_ == other.str_; } | ||||
| bool RefKey::operator==(const Value &other) const { | |||||
| if (other.isa<RefKey>()) { | |||||
| auto other_ = static_cast<const RefKey &>(other); | |||||
| return *this == other_; | |||||
| } else { | |||||
| return false; | |||||
| } | |||||
| } | |||||
| bool RefKey::operator==(const RefKey &other) const { return tag_ == other.tag_; } | |||||
| bool AnyValue::operator==(const Value &other) const { | bool AnyValue::operator==(const Value &other) const { | ||||
| if (other.isa<AnyValue>()) { | if (other.isa<AnyValue>()) { | ||||
| return true; | return true; | ||||
| @@ -224,28 +224,21 @@ using StringImmPtr = std::shared_ptr<StringImm>; | |||||
| IMM_TRAITS(StringImmPtr, std::string) | IMM_TRAITS(StringImmPtr, std::string) | ||||
| IMM_TRAITS(StringImmPtr, const char *) | IMM_TRAITS(StringImmPtr, const char *) | ||||
| class RefKey : public Value { | |||||
| class RefKey : public Named { | |||||
| public: | public: | ||||
| explicit RefKey(const std::string &tag) : Value(kRefKeyType), tag_(tag), hash_(std::hash<std::string>{}(tag)) {} | |||||
| explicit RefKey(const std::string &tag) : Named(tag) {} | |||||
| ~RefKey() override = default; | ~RefKey() override = default; | ||||
| MS_DECLARE_PARENT(RefKey, Value) | |||||
| std::size_t hash() const override { return hash_; } | |||||
| const std::string &tag() const { return tag_; } | |||||
| bool operator==(const Value &other) const override; | |||||
| bool operator==(const RefKey &other) const; | |||||
| MS_DECLARE_PARENT(RefKey, Named) | |||||
| const std::string &tag() const { return name(); } | |||||
| abstract::AbstractBasePtr ToAbstract() override; | abstract::AbstractBasePtr ToAbstract() override; | ||||
| std::string ToString() const override { return "RefKey[" + tag_ + "]"; } | |||||
| std::string ToString() const override { return "RefKey[" + name() + "]"; } | |||||
| std::string DumpText() const override { | std::string DumpText() const override { | ||||
| std::ostringstream oss; | std::ostringstream oss; | ||||
| oss << "RefKey[\"" << tag_ << "\"]"; | |||||
| oss << "RefKey[\"" << name() << "\"]"; | |||||
| return oss.str(); | return oss.str(); | ||||
| } | } | ||||
| private: | |||||
| std::string tag_; | |||||
| std::size_t hash_ = 0; | |||||
| }; | }; | ||||
| using RefKeyPtr = std::shared_ptr<RefKey>; | using RefKeyPtr = std::shared_ptr<RefKey>; | ||||
| @@ -43,6 +43,8 @@ if(BUILD_CONVERTER) | |||||
| ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/scope.cc | ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/scope.cc | ||||
| ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/value.cc | ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/value.cc | ||||
| ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/value_extends.cc | ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/value_extends.cc | ||||
| ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/dtype/ref.cc | |||||
| ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/dtype/tensor_type.cc | |||||
| ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/dtype/container.cc | ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/dtype/container.cc | ||||
| ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/dtype/empty.cc | ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/dtype/empty.cc | ||||
| ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/dtype/number.cc | ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/dtype/number.cc | ||||
| @@ -29,6 +29,8 @@ set(ANF_SRC | |||||
| ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/scope.cc | ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/scope.cc | ||||
| ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/value.cc | ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/value.cc | ||||
| ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/value_extends.cc | ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/value_extends.cc | ||||
| ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/dtype/ref.cc | |||||
| ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/dtype/tensor_type.cc | |||||
| ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/dtype/container.cc | ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/dtype/container.cc | ||||
| ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/dtype/empty.cc | ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/dtype/empty.cc | ||||
| ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/dtype/number.cc | ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/dtype/number.cc | ||||
| @@ -23,7 +23,7 @@ from ...common import dtype as mstype | |||||
| from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register | from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register | ||||
| class Assign(PrimitiveWithInfer): | |||||
| class Assign(Primitive): | |||||
| """ | """ | ||||
| Assign `Parameter` with a value. | Assign `Parameter` with a value. | ||||
| @@ -18,7 +18,6 @@ | |||||
| import inspect | import inspect | ||||
| import copy | import copy | ||||
| from mindspore.common.api import _wrap_func | from mindspore.common.api import _wrap_func | ||||
| from mindspore.common import Parameter | |||||
| from mindspore.common._register_for_tensor import tensor_operator_registry | from mindspore.common._register_for_tensor import tensor_operator_registry | ||||
| from mindspore import context | from mindspore import context | ||||
| from .._c_expression import Primitive_, real_run_op, prim_type | from .._c_expression import Primitive_, real_run_op, prim_type | ||||
| @@ -410,16 +409,12 @@ def _run_op(obj, op_name, args): | |||||
| if op_name == "Cast" or obj.update_parameter: | if op_name == "Cast" or obj.update_parameter: | ||||
| cast_args = args | cast_args = args | ||||
| else: | else: | ||||
| cast_args = list() | |||||
| for arg in args: | |||||
| if isinstance(arg, Parameter): | |||||
| if arg.cast_type: | |||||
| cast_args.append(cast(arg, arg.cast_type)) | |||||
| else: | |||||
| cast_args.append(arg) | |||||
| else: | |||||
| cast_args.append(arg) | |||||
| output = real_run_op(obj, op_name, tuple(cast_args)) | |||||
| cast_args = args | |||||
| for idx, arg in enumerate(args): | |||||
| cast_type = getattr(arg, "cast_type", None) | |||||
| if cast_type: | |||||
| cast_args[idx] = cast(arg, cast_type) | |||||
| output = real_run_op(obj, op_name, cast_args) | |||||
| if not output: | if not output: | ||||
| raise RuntimeError("Pynative run op %s failed!" % op_name) | raise RuntimeError("Pynative run op %s failed!" % op_name) | ||||
| if len(output) == 1: | if len(output) == 1: | ||||
| @@ -118,26 +118,31 @@ class ControlMixedWhileIf(nn.Cell): | |||||
| self.var = Parameter(initializer(1, (1), mstype.float32), name="var") | self.var = Parameter(initializer(1, (1), mstype.float32), name="var") | ||||
| def construct(self, x, y, z, c2, c4): | def construct(self, x, y, z, c2, c4): | ||||
| out = self.assign(self.var, c4) | |||||
| out = c4 | |||||
| self.assign(self.var, c4) | |||||
| while x < c2: | while x < c2: | ||||
| y = self.assign(self.var, c4) | |||||
| y = c4 | |||||
| self.assign(self.var, c4) | |||||
| while y < c2 and x < c2: | while y < c2 and x < c2: | ||||
| if 2 * y < c2: | if 2 * y < c2: | ||||
| y = y + 2 | y = y + 2 | ||||
| else: | else: | ||||
| y = y + 1 | y = y + 1 | ||||
| out = out + y | out = out + y | ||||
| z = self.assign(self.var, c4) | |||||
| z = c4 | |||||
| self.assign(self.var, c4) | |||||
| while z < c2: | while z < c2: | ||||
| z = z + 1 | z = z + 1 | ||||
| out = out + z | out = out + z | ||||
| x = x + 1 | x = x + 1 | ||||
| out = out + x | out = out + x | ||||
| while x < 2 * c2: | while x < 2 * c2: | ||||
| y = self.assign(self.var, c4) | |||||
| y = c4 | |||||
| self.assign(self.var, c4) | |||||
| x = x + 1 | x = x + 1 | ||||
| while y < c2: | while y < c2: | ||||
| z = self.assign(self.var, c4) | |||||
| z = c4 | |||||
| self.assign(self.var, c4) | |||||
| while z < c2: | while z < c2: | ||||
| z = z + 1 | z = z + 1 | ||||
| if x < c2: | if x < c2: | ||||
| @@ -27,6 +27,7 @@ import mindspore.nn as nn | |||||
| from mindspore import Tensor | from mindspore import Tensor | ||||
| from mindspore import context | from mindspore import context | ||||
| from mindspore.ops import composite as C | from mindspore.ops import composite as C | ||||
| from mindspore.ops import operations as P | |||||
| from mindspore.common.api import ms_function, _executor | from mindspore.common.api import ms_function, _executor | ||||
| from mindspore.ops._grad.grad_base import bprop_getters | from mindspore.ops._grad.grad_base import bprop_getters | ||||
| from mindspore.ops.primitive import prim_attr_register, PrimitiveWithInfer | from mindspore.ops.primitive import prim_attr_register, PrimitiveWithInfer | ||||
| @@ -254,3 +255,60 @@ def test_bprop_with_wrong_output_shape(): | |||||
| net = BpropWithWrongOutputShapeCell() | net = BpropWithWrongOutputShapeCell() | ||||
| net.set_grad() | net.set_grad() | ||||
| grad_all(net)(Tensor(np.ones([64, 10]).astype(np.int32))) | grad_all(net)(Tensor(np.ones([64, 10]).astype(np.int32))) | ||||
| class AssignWhenInsertGrad(nn.Cell): | |||||
| """ NetWithNDarray definition """ | |||||
| def __init__(self): | |||||
| super(AssignWhenInsertGrad, self).__init__() | |||||
| self.gather = P.GatherV2() | |||||
| self.damping = Tensor(np.array([0.03, 0.03]).astype(np.float32)) | |||||
| self.cov_step = ms.Parameter(0, name="cov_step", requires_grad=False) | |||||
| self.freq = Tensor(278, ms.int32) | |||||
| self.getG = P.InsertGradientOf(self.save_gradient) | |||||
| def save_gradient(self, dout): | |||||
| self.cov_step = self.cov_step + self.freq | |||||
| return dout | |||||
| def construct(self, x): | |||||
| self.gather(self.damping, self.cov_step, 0) | |||||
| out = P.ReLU()(x) | |||||
| out = self.getG(out) | |||||
| return out | |||||
| grad_all = C.GradOperation('get_all', get_all=True) | |||||
| class GradNet(nn.Cell): | |||||
| def __init__(self, net): | |||||
| super(GradNet, self).__init__() | |||||
| self.net = net | |||||
| def construct(self, *inputs): | |||||
| out = self.net(*inputs) | |||||
| return out, grad_all(self.net)(*inputs) | |||||
| def test_assign_in_insert_grad(): | |||||
| context.set_context(mode=context.GRAPH_MODE) | |||||
| net = AssignWhenInsertGrad().to_float(ms.float16) | |||||
| input_data = np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32') | |||||
| net_back = GradNet(net) | |||||
| net_back(ms.Tensor(input_data)) | |||||
| class Assign(nn.Cell): | |||||
| """ NetWithNDarray definition """ | |||||
| def __init__(self): | |||||
| super(Assign, self).__init__() | |||||
| self.cov_step = ms.Parameter(0.0, name="cov_step", requires_grad=False) | |||||
| def construct(self, x): | |||||
| self.cov_step = self.cov_step + x | |||||
| return self.cov_step | |||||
| def test_assign(): | |||||
| context.set_context(mode=context.GRAPH_MODE) | |||||
| net = Assign() | |||||
| input_data = ms.Tensor(np.array(1).astype(np.int32)) | |||||
| net_back = GradNet(net) | |||||
| net_back(input_data) | |||||
| @@ -0,0 +1,144 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ test_cont_break """ | |||||
| import numpy as np | |||||
| import mindspore as ms | |||||
| from mindspore import Tensor, context, nn, ms_function | |||||
| from mindspore.nn import Cell | |||||
| from mindspore.ops import operations as P | |||||
| class WhileSubGraphParam(Cell): | |||||
| def __init__(self): | |||||
| super().__init__() | |||||
| self.update = ms.Parameter(Tensor(1, ms.float32), "update") | |||||
| def construct(self, x, y, z): | |||||
| out1 = z | |||||
| while x < y: | |||||
| self.update = self.update + 1 | |||||
| out1 = out1 + 1 | |||||
| x = x + 1 | |||||
| return out1, self.update | |||||
| def test_while_loop_phi(): | |||||
| context.set_context(mode=context.GRAPH_MODE, save_graphs=True) | |||||
| x = Tensor(0, ms.float32) | |||||
| y = Tensor(10, ms.float32) | |||||
| z = Tensor(100, ms.float32) | |||||
| net = WhileSubGraphParam() | |||||
| net(x, y, z) | |||||
| class WhileSubGraphParam2(Cell): | |||||
| def __init__(self): | |||||
| super().__init__() | |||||
| self.update = ms.Parameter(Tensor(1, ms.float32), "update") | |||||
| def construct(self, x, y, z): | |||||
| out1 = z | |||||
| i = self.update | |||||
| while x < y: | |||||
| i = i + 1 | |||||
| out1 = out1 + 1 | |||||
| x = x + 1 | |||||
| return out1, self.update | |||||
| def test_while_loop_phi_2(): | |||||
| context.set_context(mode=context.GRAPH_MODE, save_graphs=True) | |||||
| x = Tensor(0, ms.float32) | |||||
| y = Tensor(10, ms.float32) | |||||
| z = Tensor(100, ms.float32) | |||||
| net = WhileSubGraphParam2() | |||||
| net(x, y, z) | |||||
| class WhileSubGraphParam3(Cell): | |||||
| def __init__(self, initial_input_x): | |||||
| super().__init__() | |||||
| self.initial_input_x = initial_input_x | |||||
| self.X = ms.Parameter(initial_input_x, name="parameter_x") | |||||
| self.Y = ms.Parameter(self.initial_input_x, name="parameter_y") | |||||
| def construct(self): | |||||
| a = 0 | |||||
| while a < 3: | |||||
| self.X = self.X + self.Y | |||||
| a += 1 | |||||
| return self.X | |||||
| def test_while_loop_phi_3(): | |||||
| context.set_context(mode=context.GRAPH_MODE, save_graphs=True) | |||||
| x = Tensor(0, ms.float32) | |||||
| net = WhileSubGraphParam3(x) | |||||
| net() | |||||
| class ControlMixedWhileIf(nn.Cell): | |||||
| def __init__(self): | |||||
| super().__init__() | |||||
| self.assign = P.Assign() | |||||
| self.var = ms.Parameter(ms.Tensor([1], ms.float32), name="var") | |||||
| @ms_function | |||||
| def construct(self, x, y, z, c2, c4): | |||||
| out = self.assign(self.var, c4) | |||||
| while x < c2: | |||||
| y = self.assign(self.var, c4) | |||||
| while y < c2 and x < c2: | |||||
| if 2 * y < c2: | |||||
| y = y + 2 | |||||
| else: | |||||
| y = y + 1 | |||||
| out = out + y | |||||
| z = self.assign(self.var, c4) | |||||
| while z < c2: | |||||
| z = z + 1 | |||||
| out = out + z | |||||
| x = x + 1 | |||||
| out = out + x | |||||
| while x < 2 * c2: | |||||
| y = self.assign(self.var, c4) | |||||
| x = x + 1 | |||||
| while y < c2: | |||||
| z = self.assign(self.var, c4) | |||||
| while z < c2: | |||||
| z = z + 1 | |||||
| if x < c2: | |||||
| y = y - 1 | |||||
| else: | |||||
| y = y + 1 | |||||
| out = out + z | |||||
| out = out + y | |||||
| out = out + x | |||||
| return out | |||||
| def test_mixed_while_if(): | |||||
| context.set_context(mode=context.PYNATIVE_MODE) | |||||
| x = np.array(2).astype(np.int32) | |||||
| y = np.array(14).astype(np.int32) | |||||
| z = np.array(1).astype(np.int32) | |||||
| c2 = Tensor([14], ms.int32) | |||||
| c4 = Tensor([0], ms.int32) | |||||
| net = ControlMixedWhileIf() | |||||
| output = net(Tensor(x), Tensor(y), Tensor(z), c2, c4) | |||||
| expect = np.array(3318).astype(np.int32) | |||||
| assert np.allclose(expect, output.asnumpy(), 0.0001, 0.0001) | |||||
| context.set_context(mode=context.GRAPH_MODE) | |||||
| @@ -22,7 +22,13 @@ from mindspore.ops.vm_impl_registry import vm_impl_registry as vm_impl_getters | |||||
| from .vm_interface import vm | from .vm_interface import vm | ||||
| # pylint: disable=unused-argument | # pylint: disable=unused-argument | ||||
| @vm_impl_getters.register(P.Assign) | |||||
| def vm_impl_assign(self): | |||||
| """Generate vm_impl function for Assign""" | |||||
| def vm_impl(x, value): | |||||
| x.assign_value(value) | |||||
| return x | |||||
| return vm_impl | |||||
| @vm_impl_getters.register(P.ExpandDims) | @vm_impl_getters.register(P.ExpandDims) | ||||
| def vm_impl_expand_dims(self): | def vm_impl_expand_dims(self): | ||||