| @@ -102,6 +102,51 @@ bool HasFraczGroupAttrAndSet(const AnfNodePtr &node, size_t index, int64_t group | |||
| return true; | |||
| } | |||
| std::vector<KernelWithIndex> GetCNodeNeighborFraczNodes(const FuncGraphManagerPtr &manager, const CNodePtr &cnode, | |||
| size_t index, int64_t groups) { | |||
| auto node_name = AnfAlgo::GetCNodeName(cnode); | |||
| auto input_num = AnfAlgo::GetInputTensorNum(cnode); | |||
| auto output_num = AnfAlgo::GetOutputTensorNum(cnode); | |||
| auto node_user = manager->node_users(); | |||
| std::vector<KernelWithIndex> ret; | |||
| if (node_name == kDependName || node_name == kLoadName) { | |||
| if (index != 0) { | |||
| return ret; | |||
| } | |||
| input_num = 1; | |||
| output_num = 1; | |||
| } | |||
| for (size_t i = 0; i < input_num; ++i) { | |||
| if (AnfAlgo::GetInputFormat(cnode, i) == kOpFormat_FRAC_Z) { | |||
| auto input = cnode->input(i + 1); | |||
| if (node_name == kTupleGetItemName) { | |||
| auto item_index = AnfAlgo::GetTupleGetItemOutIndex(cnode); | |||
| while (input->isa<CNode>() && AnfAlgo::GetCNodeName(input) == kDependName) { | |||
| AnfAlgo::SetNodeAttr(kAttrFracZGroup, MakeValue(groups), input); | |||
| input = input->cast<CNodePtr>()->input(1); | |||
| } | |||
| (void)ret.emplace_back(input, item_index); | |||
| } else { | |||
| (void)ret.emplace_back(input, 0); | |||
| } | |||
| } | |||
| } | |||
| if (kOptOperatorSet.find(node_name) == kOptOperatorSet.end()) { | |||
| for (size_t i = 0; i < output_num; ++i) { | |||
| if (AnfAlgo::GetOutputFormat(cnode, i) == kOpFormat_FRAC_Z) { | |||
| auto output = GetOutputItem(manager, cnode, groups, i); | |||
| if (output != nullptr) { | |||
| std::transform(node_user[output].begin(), node_user[output].end(), std::back_inserter(ret), | |||
| [](KernelWithIndex node_index) { | |||
| return KernelWithIndex{node_index.first, node_index.second - 1}; | |||
| }); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| return ret; | |||
| } | |||
| std::vector<KernelWithIndex> GetNeighborFraczNodes(const FuncGraphManagerPtr &manager, const AnfNodePtr &node, | |||
| size_t index, int64_t groups) { | |||
| std::vector<KernelWithIndex> ret; | |||
| @@ -129,43 +174,7 @@ std::vector<KernelWithIndex> GetNeighborFraczNodes(const FuncGraphManagerPtr &ma | |||
| }); | |||
| } | |||
| } else { | |||
| auto input_num = AnfAlgo::GetInputTensorNum(cnode); | |||
| auto output_num = AnfAlgo::GetOutputTensorNum(cnode); | |||
| if (node_name == kDependName || node_name == kLoadName) { | |||
| if (index != 0) { | |||
| return ret; | |||
| } | |||
| input_num = 1; | |||
| output_num = 1; | |||
| } | |||
| for (size_t i = 0; i < input_num; ++i) { | |||
| if (AnfAlgo::GetInputFormat(cnode, i) == kOpFormat_FRAC_Z) { | |||
| auto input = cnode->input(i + 1); | |||
| if (node_name == kTupleGetItemName) { | |||
| auto item_index = AnfAlgo::GetTupleGetItemOutIndex(cnode); | |||
| while (input->isa<CNode>() && AnfAlgo::GetCNodeName(input) == kDependName) { | |||
| AnfAlgo::SetNodeAttr(kAttrFracZGroup, MakeValue(groups), input); | |||
| input = input->cast<CNodePtr>()->input(1); | |||
| } | |||
| (void)ret.emplace_back(input, item_index); | |||
| } else { | |||
| (void)ret.emplace_back(input, 0); | |||
| } | |||
| } | |||
| } | |||
| if (kOptOperatorSet.find(node_name) == kOptOperatorSet.end()) { | |||
| for (size_t i = 0; i < output_num; ++i) { | |||
| if (AnfAlgo::GetOutputFormat(cnode, i) == kOpFormat_FRAC_Z) { | |||
| auto output = GetOutputItem(manager, cnode, groups, i); | |||
| if (output != nullptr) { | |||
| std::transform(node_user[output].begin(), node_user[output].end(), std::back_inserter(ret), | |||
| [](KernelWithIndex node_index) { | |||
| return KernelWithIndex{node_index.first, node_index.second - 1}; | |||
| }); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| ret = GetCNodeNeighborFraczNodes(manager, cnode, index, groups); | |||
| } | |||
| return ret; | |||
| } | |||
| @@ -28,8 +28,8 @@ constexpr int64_t kInputXDimP = -1; | |||
| constexpr int64_t kInputYDimR = -2; | |||
| std::vector<size_t> CalCdistBroadCastShape(std::vector<size_t> x_shape, std::vector<size_t> y_shape) { | |||
| x_shape.insert(x_shape.end() + kInputXDimP, 1); | |||
| y_shape.insert(y_shape.end() + kInputYDimR, 1); | |||
| (void)x_shape.insert(x_shape.end() + kInputXDimP, 1); | |||
| (void)y_shape.insert(y_shape.end() + kInputYDimR, 1); | |||
| if (x_shape.size() != y_shape.size()) { | |||
| MS_EXCEPTION(ValueError) << "For Cdist, input_x and input_y should have the same rank."; | |||
| } | |||
| @@ -39,13 +39,13 @@ std::vector<size_t> CalCdistBroadCastShape(std::vector<size_t> x_shape, std::vec | |||
| auto length = x_shape.size(); | |||
| std::vector<size_t> broadcast_shape; | |||
| std::copy(x_shape.begin(), x_shape.end() - SizeToLong(length), std::back_inserter(broadcast_shape)); | |||
| for (int64_t i = -length; i < 0; i++) { | |||
| if (x_shape[length + i] == 1) { | |||
| broadcast_shape.push_back(y_shape[length + i]); | |||
| } else if (y_shape[length + i] == 1) { | |||
| broadcast_shape.push_back(x_shape[length + i]); | |||
| } else if (x_shape[length + i] == y_shape[length + i]) { | |||
| broadcast_shape.push_back(x_shape[length + i]); | |||
| for (size_t i = length; i > 0; --i) { | |||
| if (x_shape[length - i] == 1) { | |||
| broadcast_shape.push_back(y_shape[length - i]); | |||
| } else if (y_shape[length - i] == 1) { | |||
| broadcast_shape.push_back(x_shape[length - i]); | |||
| } else if (x_shape[length - i] == y_shape[length - i]) { | |||
| broadcast_shape.push_back(x_shape[length - i]); | |||
| } else { | |||
| MS_EXCEPTION(ValueError) << "The two input shape can not broadcast, x_shape: " << x_shape << ", y_shape" | |||
| << y_shape; | |||
| @@ -64,7 +64,7 @@ AnfNodePtr AddBroadCastToNode(const FuncGraphPtr &func_graph, const AnfNodePtr & | |||
| auto expand_dims = func_graph->NewCNode(expand_dims_inputs); | |||
| auto dtype = AnfAlgo::GetOutputInferDataType(input_node, 0); | |||
| auto expand_shape = AnfAlgo::GetOutputInferShape(input_node, 0); | |||
| expand_shape.insert(expand_shape.end() + dim, 1); | |||
| (void)expand_shape.insert(expand_shape.end() + dim, 1); | |||
| AnfAlgo::SetOutputInferTypeAndShape({dtype}, {expand_shape}, expand_dims.get()); | |||
| AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(dim), expand_dims); | |||
| AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), expand_dims); | |||
| @@ -27,10 +27,9 @@ constexpr size_t kDiagInputNum = 1; | |||
| constexpr size_t kDiagInputMaxDim = 4; | |||
| template <typename T> | |||
| void SetAssistTensorData(void *data, T value, size_t dims_size) { | |||
| void SetAssistTensorData(void *data, const T &value, size_t dims_size) { | |||
| MS_EXCEPTION_IF_NULL(data); | |||
| auto tensor_data = reinterpret_cast<T *>(data); | |||
| MS_EXCEPTION_IF_NULL(tensor_data); | |||
| for (size_t i = 0; i < dims_size; ++i) { | |||
| tensor_data[(1 + dims_size) * i] = value; | |||
| } | |||
| @@ -46,7 +45,7 @@ ValueNodePtr DiagFission::CreateAssistNode(const FuncGraphPtr &func_graph, const | |||
| for (size_t i = 0; i < ori_shape.size(); i++) { | |||
| dims = dims * ori_shape[i]; | |||
| } | |||
| output_shape.insert(output_shape.end(), ori_shape.begin(), ori_shape.end()); | |||
| (void)output_shape.insert(output_shape.end(), ori_shape.begin(), ori_shape.end()); | |||
| auto type = AnfAlgo::GetOutputInferDataType(node, 0); | |||
| std::vector<int64_t> assist_shape; | |||
| std::transform(output_shape.begin(), output_shape.end(), std::back_inserter(assist_shape), SizeToLong); | |||
| @@ -95,7 +94,7 @@ const AnfNodePtr DiagFission::Process(const FuncGraphPtr &graph, const AnfNodePt | |||
| } | |||
| std::vector<AnfNodePtr> new_inputs{NewValueNode(std::make_shared<Primitive>(prim::kPrimDiag->name()))}; | |||
| auto assist_const = CreateAssistNode(graph, diag_cnode, input_shape); | |||
| new_inputs.insert(new_inputs.end(), diag_cnode->inputs().begin() + 1, diag_cnode->inputs().end()); | |||
| (void)new_inputs.insert(new_inputs.end(), diag_cnode->inputs().begin() + 1, diag_cnode->inputs().end()); | |||
| new_inputs.push_back(assist_const); | |||
| CNodePtr new_cnode = graph->NewCNode(new_inputs); | |||
| MS_EXCEPTION_IF_NULL(new_cnode); | |||
| @@ -45,7 +45,8 @@ const AnfNodePtr DiagPartFission::Process(const FuncGraphPtr &func_graph, const | |||
| } | |||
| std::vector<AnfNodePtr> new_node_inputs{NewValueNode(std::make_shared<Primitive>(prim::kPrimDiagPart->name()))}; | |||
| auto assist_node = CreateAssistNode(func_graph, diag_part_cnode, out_shape); | |||
| new_node_inputs.insert(new_node_inputs.end(), diag_part_cnode->inputs().begin() + 1, diag_part_cnode->inputs().end()); | |||
| (void)new_node_inputs.insert(new_node_inputs.end(), diag_part_cnode->inputs().begin() + 1, | |||
| diag_part_cnode->inputs().end()); | |||
| new_node_inputs.push_back(assist_node); | |||
| CNodePtr new_cnode = func_graph->NewCNode(new_node_inputs); | |||
| MS_EXCEPTION_IF_NULL(new_cnode); | |||
| @@ -143,7 +143,7 @@ AnfNodePtr ConstructFilter(const FuncGraphPtr &func_graph, const std::vector<int | |||
| // assist tensor 1 | |||
| int64_t c1 = (fc + kC0 - 1) / kC0; | |||
| std::vector<int64_t> assist_shape = {c1 * kd * kh * kw, 1, kC0, kC0}; // frac_z_3d | |||
| auto infer_shape = {IntToSize(1), LongToSize(fc), LongToSize(kd), LongToSize(kh), LongToSize(kw)}; | |||
| std::vector<size_t> infer_shape = {IntToSize(1), LongToSize(fc), LongToSize(kd), LongToSize(kh), LongToSize(kw)}; | |||
| float val = 1.0 / (kd * kh * kw); | |||
| if (divisor_override) { | |||
| val = 1.0 / divisor_override; | |||
| @@ -151,30 +151,8 @@ AnfNodePtr ConstructFilter(const FuncGraphPtr &func_graph, const std::vector<int | |||
| val = 1.0; | |||
| } | |||
| // create value node | |||
| tensor::TensorPtr assist_tensor = std::make_shared<tensor::Tensor>(kNumberTypeFloat16, assist_shape); | |||
| MS_EXCEPTION_IF_NULL(assist_tensor); | |||
| TensorTypePtr tensor_type = std::make_shared<TensorType>(kFloat16); | |||
| tensor::DeviceInfo device_info{kOpFormat_FRACTAL_Z_3D, tensor_type, kOpFormat_FRACTAL_Z_3D}; | |||
| assist_tensor->set_device_info(device_info); | |||
| auto tensor_data = reinterpret_cast<float16 *>(assist_tensor->data_c()); | |||
| int64_t cnt = c1 * kd * kh * kw; | |||
| for (int64_t i = 0; i < cnt; ++i) { | |||
| for (int64_t j = 0; j < kC0; ++j) { | |||
| for (int64_t k = 0; k < kC0; ++k) { | |||
| float t = j == k ? val : 0; | |||
| *tensor_data = float16(t); | |||
| ++tensor_data; | |||
| } | |||
| } | |||
| } | |||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat16, assist_shape); | |||
| auto kernel_graph = func_graph->cast<KernelGraphPtr>(); | |||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||
| auto value_node = kernel_graph->NewValueNode(x_abstract, assist_tensor); | |||
| kernel_graph->AddValueNodeToGraph(value_node); | |||
| AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat16}, {infer_shape}, value_node.get()); | |||
| return value_node; | |||
| return ConstructFilterValueNode(func_graph, val, assist_shape, infer_shape, cnt); | |||
| } | |||
| AnfNodePtr ConstructMultiplier(const FuncGraphPtr &func_graph, int64_t fn, int64_t fc, int64_t fd, int64_t fh, | |||
| @@ -235,6 +213,33 @@ AnfNodePtr ConstructMultiplier(const FuncGraphPtr &func_graph, int64_t fn, int64 | |||
| } | |||
| } // namespace | |||
| AnfNodePtr ConstructFilterValueNode(const FuncGraphPtr &func_graph, float val, const std::vector<int64_t> &assist_shape, | |||
| const std::vector<size_t> &infer_shape, int64_t cnt) { | |||
| tensor::TensorPtr assist_tensor = std::make_shared<tensor::Tensor>(kNumberTypeFloat16, assist_shape); | |||
| MS_EXCEPTION_IF_NULL(assist_tensor); | |||
| TensorTypePtr tensor_type = std::make_shared<TensorType>(kFloat16); | |||
| tensor::DeviceInfo device_info{kOpFormat_FRACTAL_Z_3D, tensor_type, kOpFormat_FRACTAL_Z_3D}; | |||
| assist_tensor->set_device_info(device_info); | |||
| auto tensor_data = reinterpret_cast<float16 *>(assist_tensor->data_c()); | |||
| for (int64_t i = 0; i < cnt; ++i) { | |||
| for (int64_t j = 0; j < kC0; ++j) { | |||
| for (int64_t k = 0; k < kC0; ++k) { | |||
| float t = j == k ? val : 0; | |||
| *tensor_data = float16(t); | |||
| ++tensor_data; | |||
| } | |||
| } | |||
| } | |||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat16, assist_shape); | |||
| auto kernel_graph = func_graph->cast<KernelGraphPtr>(); | |||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||
| auto value_node = kernel_graph->NewValueNode(x_abstract, assist_tensor); | |||
| kernel_graph->AddValueNodeToGraph(value_node); | |||
| AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat16}, {infer_shape}, value_node.get()); | |||
| return value_node; | |||
| } | |||
| const BaseRef AvgPool3DFusion::DefinePattern() const { | |||
| VarPtr Xs = std::make_shared<SeqVar>(); | |||
| return VectorRef({prim::kPrimAvgPool3D, Xs}); | |||
| @@ -31,6 +31,9 @@ class AvgPool3DFusion : public PatternProcessPass { | |||
| const BaseRef DefinePattern() const override; | |||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | |||
| }; | |||
| AnfNodePtr ConstructFilterValueNode(const FuncGraphPtr &func_graph, float val, const std::vector<int64_t> &assist_shape, | |||
| const std::vector<size_t> &infer_shape, int64_t cnt); | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -19,6 +19,7 @@ | |||
| #include <memory> | |||
| #include <string> | |||
| #include <algorithm> | |||
| #include "backend/optimizer/ascend/ir_fusion/avgpool_3d_fusion.h" | |||
| #include "backend/session/anf_runtime_algorithm.h" | |||
| #include "backend/optimizer/common/helper.h" | |||
| #include "base/core_ops.h" | |||
| @@ -105,7 +106,7 @@ AnfNodePtr ConstructFilter(const FuncGraphPtr &func_graph, const std::vector<int | |||
| // assist tensor 1 | |||
| int64_t c1 = (fc + kC0 - 1) / kC0; | |||
| std::vector<int64_t> assist_shape = {c1 * kd * kh * kw, 1, kC0, kC0}; // frac_z_3d | |||
| auto infer_shape = {IntToSize(1), LongToSize(fc), LongToSize(kd), LongToSize(kh), LongToSize(kw)}; | |||
| std::vector<size_t> infer_shape = {IntToSize(1), LongToSize(fc), LongToSize(kd), LongToSize(kh), LongToSize(kw)}; | |||
| float val = 1.0; | |||
| if (divisor_override) { | |||
| val = 1.0 / divisor_override; | |||
| @@ -113,29 +114,8 @@ AnfNodePtr ConstructFilter(const FuncGraphPtr &func_graph, const std::vector<int | |||
| val = 1.0 / (kd * kh * kw); | |||
| } | |||
| // create value node | |||
| tensor::TensorPtr assist_tensor = std::make_shared<tensor::Tensor>(kNumberTypeFloat16, assist_shape); | |||
| TensorTypePtr tensor_type = std::make_shared<TensorType>(kFloat16); | |||
| tensor::DeviceInfo device_info{kOpFormat_FRACTAL_Z_3D, tensor_type, kOpFormat_FRACTAL_Z_3D}; | |||
| assist_tensor->set_device_info(device_info); | |||
| auto tensor_data = reinterpret_cast<float16 *>(assist_tensor->data_c()); | |||
| int64_t cnt = c1 * kd * kh * kw; | |||
| for (int64_t i = 0; i < cnt; ++i) { | |||
| for (int64_t j = 0; j < kC0; ++j) { | |||
| for (int64_t k = 0; k < kC0; ++k) { | |||
| float t = j == k ? val : 0; | |||
| *tensor_data = float16(t); | |||
| ++tensor_data; | |||
| } | |||
| } | |||
| } | |||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat16, assist_shape); | |||
| auto kernel_graph = func_graph->cast<KernelGraphPtr>(); | |||
| auto value_node = kernel_graph->NewValueNode(x_abstract, assist_tensor); | |||
| MS_EXCEPTION_IF_NULL(value_node); | |||
| kernel_graph->AddValueNodeToGraph(value_node); | |||
| AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat16}, {infer_shape}, value_node.get()); | |||
| return value_node; | |||
| return ConstructFilterValueNode(func_graph, val, assist_shape, infer_shape, cnt); | |||
| } | |||
| AnfNodePtr ConstructMultiplier(const FuncGraphPtr &func_graph, const std::vector<size_t> &ori_shape, | |||
| @@ -102,7 +102,7 @@ CNodePtr CreateAllToAllvNode(const FuncGraphPtr &graph, const CNodePtr &all_to_a | |||
| MS_LOG(EXCEPTION) << "The node " << split->DebugString() << " should have at least one output, but got 0."; | |||
| } | |||
| std::vector<AnfNodePtr> all_to_all_v_input = {NewValueNode(std::make_shared<Primitive>(kAllToAllVOpName))}; | |||
| all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs.begin(), split_outputs.end()); | |||
| (void)all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs.begin(), split_outputs.end()); | |||
| auto all_to_all_v = graph->NewCNode(all_to_all_v_input); | |||
| MS_EXCEPTION_IF_NULL(all_to_all_v); | |||
| auto single_shape = AnfAlgo::GetOutputInferShape(split_outputs[0], 0); | |||
| @@ -135,7 +135,7 @@ CNodePtr CreateConcatNode(const FuncGraphPtr &graph, const CNodePtr &all_to_all, | |||
| MS_LOG(EXCEPTION) << "The node " << all_to_all_v->DebugString() << " should have at least one output, but got 0."; | |||
| } | |||
| std::vector<AnfNodePtr> concat_input = {NewValueNode(std::make_shared<Primitive>(kConcatOpName))}; | |||
| concat_input.insert(concat_input.end(), all_to_all_v_outputs.begin(), all_to_all_v_outputs.end()); | |||
| (void)concat_input.insert(concat_input.end(), all_to_all_v_outputs.begin(), all_to_all_v_outputs.end()); | |||
| auto concat = graph->NewCNode(concat_input); | |||
| MS_EXCEPTION_IF_NULL(concat); | |||
| auto single_shape = AnfAlgo::GetOutputInferShape(all_to_all_v_outputs[0], 0); | |||
| @@ -85,7 +85,6 @@ void GraphOptimizer::AddPassManager(const PassManagerPtr &pass_manager) { | |||
| FuncGraphPtr GraphOptimizer::Optimize(const FuncGraphPtr &func_graph, bool run_only_once) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| run_only_once_ = (pass_managers_.size() == 1) ? true : run_only_once; | |||
| // Performance risk by creating new manager each time | |||
| // cppcheck-suppress * | |||
| auto manager = Manage(func_graph, true); | |||
| @@ -269,7 +269,7 @@ EquivPtr PatternEngine::Match(const BaseRef &pattern, const BaseRef &expr, const | |||
| } | |||
| // 2. check equal | |||
| if (PatternEngine::AnfNodeEqual(pattern_ref, expr_ref)) { | |||
| if (opt::AnfEqual(pattern_ref, expr_ref)) { | |||
| return equiv; | |||
| } | |||
| @@ -301,57 +301,6 @@ EquivPtr PatternEngine::Match(const BaseRef &pattern, const BaseRef &expr, const | |||
| return equiv; | |||
| } | |||
| bool PatternEngine::AnfNodeEqual(const BaseRef &a, const BaseRef &b) { | |||
| if (utils::isa<AnfNodePtr>(a) && utils::isa<AnfNodePtr>(b)) { | |||
| auto a_node = utils::cast<AnfNodePtr>(a); | |||
| auto b_node = utils::cast<AnfNodePtr>(b); | |||
| MS_EXCEPTION_IF_NULL(a_node); | |||
| MS_EXCEPTION_IF_NULL(b_node); | |||
| if (IsValueNode<Primitive>(a_node) && IsValueNode<Primitive>(b_node)) { | |||
| auto a_value_node = a_node->cast<ValueNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(a_value_node); | |||
| auto a_value = a_value_node->value(); | |||
| MS_EXCEPTION_IF_NULL(a_value); | |||
| auto a_prim = a_value->cast<PrimitivePtr>(); | |||
| MS_EXCEPTION_IF_NULL(a_prim); | |||
| auto b_value_node = b_node->cast<ValueNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(b_value_node); | |||
| auto b_value = b_value_node->value(); | |||
| MS_EXCEPTION_IF_NULL(b_value); | |||
| auto b_prim = b_value->cast<PrimitivePtr>(); | |||
| MS_EXCEPTION_IF_NULL(b_prim); | |||
| return a_prim->name() == b_prim->name(); | |||
| } else if (a_node->isa<ValueNode>() && b_node->isa<ValueNode>()) { | |||
| auto a_value_node_ptr = a_node->cast<ValueNodePtr>(); | |||
| if (a_value_node_ptr == nullptr) { | |||
| MS_LOG(EXCEPTION) << "cast value node ptr fail"; | |||
| } | |||
| auto a_value_ptr = a_value_node_ptr->value(); | |||
| if (a_value_ptr == nullptr) { | |||
| MS_LOG(EXCEPTION) << "value ptr is nullptr"; | |||
| } | |||
| auto b_value_node_ptr = b_node->cast<ValueNodePtr>(); | |||
| if (b_value_node_ptr == nullptr) { | |||
| MS_LOG(EXCEPTION) << "cast value node ptr fail"; | |||
| } | |||
| auto b_value_ptr = b_value_node_ptr->value(); | |||
| if (b_value_ptr == nullptr) { | |||
| MS_LOG(EXCEPTION) << "value ptr is nullptr"; | |||
| } | |||
| return (*a_value_ptr) == (*b_value_ptr); | |||
| } | |||
| MS_LOG(DEBUG) << "check AnfNodePtr equal"; | |||
| } | |||
| if (utils::isa<FuncGraphPtr>(a) && utils::isa<FuncGraphPtr>(b)) { | |||
| MS_LOG(DEBUG) << "check GraphPtr equal"; | |||
| } | |||
| return a == b; | |||
| } | |||
| bool PatternEngine::CNodeTypeEqual(const BaseRef &a, const BaseRef &b) { | |||
| // To matchCNode and Kernel's type | |||
| if (utils::isa<CNode>(a) && utils::isa<CNode>(b)) { | |||
| @@ -179,7 +179,6 @@ class PatternEngine { | |||
| VectorRef *const values_expr) const; | |||
| bool ToVector(const VectorRef &pattern_ref, const VectorRef &expr_ref, VectorRef *const values_pattern, | |||
| VectorRef *const values_expr) const; | |||
| static bool AnfNodeEqual(const BaseRef &a, const BaseRef &b); | |||
| static bool CNodeTypeEqual(const BaseRef &a, const BaseRef &b); | |||
| std::shared_ptr<Visitor> visitor_; | |||
| }; | |||