| @@ -46,12 +46,12 @@ std::optional<std::string> Common::GetRealPath(const std::string &input_path) { | |||
| return std::nullopt; | |||
| } | |||
| #if defined(SYSTEM_ENV_POSIX) | |||
| if (nullptr == realpath(prefix_path.c_str(), real_path)) { | |||
| if (realpath(prefix_path.c_str(), real_path) == nullptr) { | |||
| MS_LOG(ERROR) << "dir " << prefix_path << " does not exist."; | |||
| return std::nullopt; | |||
| } | |||
| #elif defined(SYSTEM_ENV_WINDOWS) | |||
| if (nullptr == _fullpath(real_path, prefix_path.c_str(), PATH_MAX)) { | |||
| if (_fullpath(real_path, prefix_path.c_str(), PATH_MAX) == nullptr) { | |||
| MS_LOG(ERROR) << "dir " << prefix_path << " does not exist."; | |||
| return std::nullopt; | |||
| } | |||
| @@ -273,12 +273,13 @@ std::string Common::AddId(const std::string &filename, const std::string &suffix | |||
| static size_t g_id = 0; | |||
| std::ostringstream s; | |||
| auto i = filename.rfind(suffix); | |||
| int spaces = 4; | |||
| if (i >= filename.size()) { | |||
| s << filename; | |||
| s << "_" << std::setfill('0') << std::setw(4) << g_id; | |||
| s << "_" << std::setfill('0') << std::setw(spaces) << g_id; | |||
| } else { | |||
| s << filename.substr(0, i); | |||
| s << "_" << std::setfill('0') << std::setw(4) << g_id; | |||
| s << "_" << std::setfill('0') << std::setw(spaces) << g_id; | |||
| if (i + 1 < filename.size()) { | |||
| s << filename.substr(i); | |||
| } | |||
| @@ -48,7 +48,28 @@ DataType InferType(const AnyPtrList &list) { | |||
| return DataType::kUnknown; | |||
| } | |||
| enum OpType { ADD, SUB, MUL, DIV, MOD }; | |||
| template <typename T> | |||
| bool IsAddOverflow(const T &x, const T &y, const T &max, const T &min) { | |||
| return (y > 0 && (max - y) < x) || (y < 0 && (min - y) > x); | |||
| } | |||
| template <typename T> | |||
| bool IsSubOverflow(const T &x, const T &y, const T &max, const T &min) { | |||
| return (y < 0 && (max + y) < x) || (y > 0 && (min + y) > x); | |||
| } | |||
| template <typename T> | |||
| bool IsMulOverflow(const T &x, const T &y, const T &max, const T &min) { | |||
| return (x > 0 && y > 0 && (max / y) < x) || (x < 0 && y < 0 && (max / y) > x) || (x > 0 && y < 0 && (min / y) < x) || | |||
| (x < 0 && y > 0 && (min / y) > x); | |||
| } | |||
| template <typename T> | |||
| bool IsDivOverflow(const T &x, const T &y, const T &max, const T &min) { | |||
| return (x == min && static_cast<int64_t>(y) == -1); | |||
| } | |||
| enum class OpType { ADD, SUB, MUL, DIV, MOD }; | |||
| template <typename T> | |||
| bool IsSignedIntOverflow(T x, T y, OpType opType) { | |||
| @@ -56,20 +77,19 @@ bool IsSignedIntOverflow(T x, T y, OpType opType) { | |||
| auto min = std::numeric_limits<T>::min(); | |||
| if (opType == OpType::ADD) { | |||
| return (y > 0 && (max - y) < x) || (y < 0 && (min - y) > x); | |||
| return IsAddOverflow<T>(x, y, max, min); | |||
| } | |||
| if (opType == OpType::SUB) { | |||
| return (y < 0 && (max + y) < x) || (y > 0 && (min + y) > x); | |||
| return IsSubOverflow<T>(x, y, max, min); | |||
| } | |||
| if (opType == OpType::MUL) { | |||
| return (x > 0 && y > 0 && (max / y) < x) || (x < 0 && y < 0 && (max / y) > x) || | |||
| (x > 0 && y < 0 && (min / y) < x) || (x < 0 && y > 0 && (min / y) > x); | |||
| return IsMulOverflow<T>(x, y, max, min); | |||
| } | |||
| if (opType == OpType::DIV || opType == OpType::MOD) { | |||
| return x == min && static_cast<int64_t>(y) == -1; | |||
| return IsDivOverflow<T>(x, y, max, min); | |||
| } | |||
| MS_LOG(EXCEPTION) << "Unsupported operation type."; | |||
| @@ -199,7 +199,7 @@ AnfNodePtr DoCast(const AnfNodePtr ¶m, const TypeId &type_id, const FuncGrap | |||
| void DoAutoCast(const std::string &func_name, const std::vector<Signature> &signature, | |||
| const std::vector<TypePtr> &input_types, const FuncGraphPtr &graph, | |||
| std::vector<AnfNodePtr> *const op_inputs, const std::set<size_t> &write_indices) { | |||
| const std::set<size_t> &write_indices, std::vector<AnfNodePtr> *const op_inputs) { | |||
| std::vector<SignatureEnumDType> dtypes; | |||
| (void)std::transform(signature.begin(), signature.end(), std::back_inserter(dtypes), | |||
| [](const Signature &sig) { return sig.dtype; }); | |||
| @@ -244,12 +244,8 @@ void DoAutoCast(const std::string &func_name, const std::vector<Signature> &sign | |||
| } | |||
| } | |||
| AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func_name, const ValuePtr &function, | |||
| const AbstractBasePtrList &args_spec_list, const std::vector<AnfNodePtr> ¶ms_list) { | |||
| // args: original inputs | |||
| auto &signature = GetSignature(function); | |||
| std::size_t sig_size = signature.size(); | |||
| auto has_var = (sig_size > 0 && signature[sig_size - 1].kind == SignatureEnumKind::kKindVarPositional); | |||
| void CheckSigSize(const size_t &sig_size, const bool &has_var, const AbstractBasePtrList &args_spec_list, | |||
| const std::string &func_name) { | |||
| if (sig_size > 0) { | |||
| if (has_var) { | |||
| if (sig_size - 1 > args_spec_list.size()) { | |||
| @@ -260,6 +256,15 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func | |||
| MS_LOG(EXCEPTION) << "Function " << func_name << "'s input length is not equal to Signature length."; | |||
| } | |||
| } | |||
| } | |||
| AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func_name, const ValuePtr &function, | |||
| const AbstractBasePtrList &args_spec_list, const std::vector<AnfNodePtr> ¶ms_list) { | |||
| // args: original inputs | |||
| auto &signature = GetSignature(function); | |||
| std::size_t sig_size = signature.size(); | |||
| auto has_var = (sig_size > 0 && signature[sig_size - 1].kind == SignatureEnumKind::kKindVarPositional); | |||
| CheckSigSize(sig_size, has_var, args_spec_list, func_name); | |||
| std::vector<AnfNodePtr> op_inputs; | |||
| std::set<size_t> write_indices; | |||
| std::vector<TypePtr> input_types; | |||
| @@ -308,7 +313,7 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func | |||
| } | |||
| // process default | |||
| ProcessDefault(func_name, args_spec_list.size(), signature, has_var, &op_inputs); | |||
| DoAutoCast(func_name, signature, input_types, func_graph, &op_inputs, write_indices); | |||
| DoAutoCast(func_name, signature, input_types, func_graph, write_indices, &op_inputs); | |||
| return func_graph->NewCNodeInOrder(op_inputs); | |||
| } | |||
| } // namespace | |||
| @@ -22,7 +22,7 @@ namespace irpass { | |||
| #define UPPER_FLT_LIMIT (FLT_MAX / 2.0) | |||
| #define LOWER_FLT_LIMIT (-FLT_MAX / 2.0) | |||
| // Define the checking mode | |||
| enum ScalarCheckingMode : int64_t { GREATER_EQUAL = 0, LESS }; | |||
| enum class ScalarCheckingMode : int64_t { GREATER_EQUAL = 0, LESS }; | |||
| bool IsNodeScalarTrueWith(const AnfNodePtr &node, const ScalarCheckingMode &checking_mode, const float &check_value) { | |||
| auto value_node = node->cast<ValueNodePtr>(); | |||
| @@ -38,7 +38,7 @@ bool IsNodeScalarTrueWith(const AnfNodePtr &node, const ScalarCheckingMode &chec | |||
| auto scalar = value->cast<ScalarPtr>(); | |||
| if (scalar != nullptr) { | |||
| if (scalar->isa<FloatImm>()) { | |||
| if (checking_mode == GREATER_EQUAL) { | |||
| if (checking_mode == ScalarCheckingMode::GREATER_EQUAL) { | |||
| return GetValue<float>(scalar) >= check_value; | |||
| } | |||
| return GetValue<float>(scalar) < check_value; | |||
| @@ -56,7 +56,7 @@ bool IsNodeScalarTrueWith(const AnfNodePtr &node, const ScalarCheckingMode &chec | |||
| TypeId tensor_type = tensor_ptr->Dtype()->type_id(); | |||
| if ((tensor_type == TypeId::kNumberTypeFloat32) || (tensor_type == TypeId::kNumberTypeFloat)) { | |||
| float *data = reinterpret_cast<float *>(tensor_ptr->data_c()); | |||
| if (checking_mode == GREATER_EQUAL) { | |||
| if (checking_mode == ScalarCheckingMode::GREATER_EQUAL) { | |||
| return data[0] >= check_value; | |||
| } | |||
| return data[0] < check_value; | |||
| @@ -66,7 +66,9 @@ bool IsNodeScalarTrueWith(const AnfNodePtr &node, const ScalarCheckingMode &chec | |||
| } | |||
| // check if a value is greater or equal 0.0 | |||
| bool IsNodeScalarPositive(const AnfNodePtr &node) { return IsNodeScalarTrueWith(node, GREATER_EQUAL, 0.0); } | |||
| bool IsNodeScalarPositive(const AnfNodePtr &node) { | |||
| return IsNodeScalarTrueWith(node, ScalarCheckingMode::GREATER_EQUAL, 0.0); | |||
| } | |||
| bool IsCNodePositive(const AnfNodePtr &node) { | |||
| if (IsPrimitiveCNode(node, prim::kPrimReduceSum) || IsPrimitiveCNode(node, prim::kPrimSqueeze)) { | |||
| @@ -87,10 +89,14 @@ bool IsCNodePositive(const AnfNodePtr &node) { | |||
| } | |||
| // check if a value is greater or equal UPPER_FLT_LIMIT | |||
| bool IsNodeScalarMaxFLT(const AnfNodePtr &node) { return IsNodeScalarTrueWith(node, GREATER_EQUAL, UPPER_FLT_LIMIT); } | |||
| bool IsNodeScalarMaxFLT(const AnfNodePtr &node) { | |||
| return IsNodeScalarTrueWith(node, ScalarCheckingMode::GREATER_EQUAL, UPPER_FLT_LIMIT); | |||
| } | |||
| // check if a value is smaller than LOWER_FLT_LIMIT | |||
| bool IsNodeScalarMinFLT(const AnfNodePtr &node) { return IsNodeScalarTrueWith(node, LESS, LOWER_FLT_LIMIT); } | |||
| bool IsNodeScalarMinFLT(const AnfNodePtr &node) { | |||
| return IsNodeScalarTrueWith(node, ScalarCheckingMode::LESS, LOWER_FLT_LIMIT); | |||
| } | |||
| AnfNodePtr ValueBasedEliminate::operator()(const OptimizerPtr &, const AnfNodePtr &node) { | |||
| PatternNode x, y, z; | |||
| @@ -57,13 +57,8 @@ using CompileGraphs = compile::CompileGraphs; | |||
| using abstract::AnalysisResult; | |||
| using mindspore::abstract::AnalysisContextPtr; | |||
| using mindspore::validator::Validate; | |||
| bool SimplifyDataStructuresPass(const ResourcePtr &res) { | |||
| MS_EXCEPTION_IF_NULL(res->func_graph()); | |||
| FuncGraphPtr func_graph = res->func_graph(); | |||
| bool changed = opt::SimplifyDataStructures(func_graph, res->manager()); | |||
| namespace { | |||
| void DoRenormalize(const bool &changed, const FuncGraphPtr &func_graph, const ResourcePtr &res) { | |||
| abstract::AbstractBasePtrList args_spec; | |||
| auto parameters = func_graph->parameters(); | |||
| (void)std::transform(parameters.begin(), parameters.end(), std::back_inserter(args_spec), | |||
| @@ -73,6 +68,15 @@ bool SimplifyDataStructuresPass(const ResourcePtr &res) { | |||
| res->set_func_graph(new_fg); | |||
| } | |||
| res->set_args_spec(args_spec); | |||
| } | |||
| } // namespace | |||
| bool SimplifyDataStructuresPass(const ResourcePtr &res) { | |||
| MS_EXCEPTION_IF_NULL(res->func_graph()); | |||
| FuncGraphPtr func_graph = res->func_graph(); | |||
| bool changed = opt::SimplifyDataStructures(func_graph, res->manager()); | |||
| DoRenormalize(changed, func_graph, res); | |||
| return true; | |||
| } | |||
| @@ -99,16 +103,7 @@ bool CleanAfterOptAPass(const ResourcePtr &res) { | |||
| FuncGraphPtr func_graph = res->func_graph(); | |||
| bool changed = opt::CleanAfterOptA(func_graph, res->manager()); | |||
| abstract::AbstractBasePtrList args_spec; | |||
| auto parameters = func_graph->parameters(); | |||
| (void)std::transform(parameters.begin(), parameters.end(), std::back_inserter(args_spec), | |||
| [](const AnfNodePtr &p) -> AbstractBasePtr { return p->abstract(); }); | |||
| if (changed) { | |||
| FuncGraphPtr new_fg = Renormalize(res, func_graph, args_spec); | |||
| res->set_func_graph(new_fg); | |||
| } | |||
| res->set_args_spec(args_spec); | |||
| DoRenormalize(changed, func_graph, res); | |||
| return true; | |||
| } | |||
| @@ -101,7 +101,8 @@ std::unordered_map<abstract::AbstractBasePtrList, int64_t, abstract::AbstractBas | |||
| namespace { | |||
| std::string GetBaseNameForIR(int64_t stage_idx, const std::string &action_name) { | |||
| std::ostringstream oss; | |||
| oss << std::setfill('0') << std::setw(2) << stage_idx << "_" << action_name; | |||
| int spaces = 2; | |||
| oss << std::setfill('0') << std::setw(spaces) << stage_idx << "_" << action_name; | |||
| return oss.str(); | |||
| } | |||
| @@ -138,8 +138,8 @@ class NoneOf(NoneOf_): | |||
| def __init__(self, patterns=None): | |||
| r""" | |||
| Args: | |||
| patterns(Union[list[:class:`mindspore.graph_utils.graph_pattern`]]: list of forbidden patterns, each element | |||
| should be one of the exposed Pattern instance. | |||
| patterns(Union[list[:class:`mindspore.graph_utils.graph_pattern`]]: list of forbidden patterns, each | |||
| element should be one of the exposed Pattern instance. | |||
| Raises: | |||
| TypeError: raise type error for invalid argument. | |||