Browse Source

code check clean

tags/v1.6.0
yuchaojie 4 years ago
parent
commit
ed25dd2d21
11 changed files with 99 additions and 155 deletions
  1. +46
    -37
      mindspore/ccsrc/backend/optimizer/ascend/format_type/set_fracz_group_attr.cc
  2. +10
    -10
      mindspore/ccsrc/backend/optimizer/ascend/ir_fission/cdist_fission.cc
  3. +3
    -4
      mindspore/ccsrc/backend/optimizer/ascend/ir_fission/diag_fission.cc
  4. +2
    -1
      mindspore/ccsrc/backend/optimizer/ascend/ir_fission/diag_part_fission.cc
  5. +29
    -24
      mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/avgpool_3d_fusion.cc
  6. +3
    -0
      mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/avgpool_3d_fusion.h
  7. +3
    -23
      mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/avgpool_3d_grad_fusion.cc
  8. +2
    -2
      mindspore/ccsrc/backend/optimizer/ascend/mindir/all_to_all_unify_mindir.cc
  9. +0
    -1
      mindspore/ccsrc/backend/optimizer/common/optimizer.cc
  10. +1
    -52
      mindspore/ccsrc/backend/optimizer/common/pattern_engine.cc
  11. +0
    -1
      mindspore/ccsrc/backend/optimizer/common/pattern_engine.h

+ 46
- 37
mindspore/ccsrc/backend/optimizer/ascend/format_type/set_fracz_group_attr.cc View File

@@ -102,6 +102,51 @@ bool HasFraczGroupAttrAndSet(const AnfNodePtr &node, size_t index, int64_t group
return true;
}

// Collects the neighbor (node, index) pairs of `cnode` that are connected to it through the
// FRACTAL_Z format, so the fracz-group attribute `groups` can be propagated to them.
//   - Backward direction: every input of `cnode` whose input format is FRAC_Z is collected.
//   - Forward direction: for operators NOT in kOptOperatorSet, the users of every FRAC_Z-format
//     output are collected as well.
// `index` is the position through which `cnode` was reached; for Depend/Load nodes only
// position 0 participates (the code narrows them to a single input/output below).
std::vector<KernelWithIndex> GetCNodeNeighborFraczNodes(const FuncGraphManagerPtr &manager, const CNodePtr &cnode,
                                                        size_t index, int64_t groups) {
  auto node_name = AnfAlgo::GetCNodeName(cnode);
  auto input_num = AnfAlgo::GetInputTensorNum(cnode);
  auto output_num = AnfAlgo::GetOutputTensorNum(cnode);
  auto node_user = manager->node_users();
  std::vector<KernelWithIndex> ret;
  if (node_name == kDependName || node_name == kLoadName) {
    // Depend/Load: only the first position carries the data edge; anything else is ignored.
    if (index != 0) {
      return ret;
    }
    input_num = 1;
    output_num = 1;
  }
  for (size_t i = 0; i < input_num; ++i) {
    if (AnfAlgo::GetInputFormat(cnode, i) == kOpFormat_FRAC_Z) {
      // input(0) is the primitive value node, so tensor input i lives at input(i + 1).
      auto input = cnode->input(i + 1);
      if (node_name == kTupleGetItemName) {
        auto item_index = AnfAlgo::GetTupleGetItemOutIndex(cnode);
        // Walk through any chain of Depend nodes, tagging each with the fracz-group
        // attribute, until the real producer is found.
        while (input->isa<CNode>() && AnfAlgo::GetCNodeName(input) == kDependName) {
          AnfAlgo::SetNodeAttr(kAttrFracZGroup, MakeValue(groups), input);
          input = input->cast<CNodePtr>()->input(1);
        }
        (void)ret.emplace_back(input, item_index);
      } else {
        (void)ret.emplace_back(input, 0);
      }
    }
  }
  if (kOptOperatorSet.find(node_name) == kOptOperatorSet.end()) {
    for (size_t i = 0; i < output_num; ++i) {
      if (AnfAlgo::GetOutputFormat(cnode, i) == kOpFormat_FRAC_Z) {
        auto output = GetOutputItem(manager, cnode, groups, i);
        if (output != nullptr) {
          // node_users() reports the user's input position; the -1 presumably converts that
          // 1-based anf-input position to a 0-based tensor index — confirm against KernelWithIndex usage.
          std::transform(node_user[output].begin(), node_user[output].end(), std::back_inserter(ret),
                         [](KernelWithIndex node_index) {
                           return KernelWithIndex{node_index.first, node_index.second - 1};
                         });
        }
      }
    }
  }
  return ret;
}

std::vector<KernelWithIndex> GetNeighborFraczNodes(const FuncGraphManagerPtr &manager, const AnfNodePtr &node,
size_t index, int64_t groups) {
std::vector<KernelWithIndex> ret;
@@ -129,43 +174,7 @@ std::vector<KernelWithIndex> GetNeighborFraczNodes(const FuncGraphManagerPtr &ma
});
}
} else {
auto input_num = AnfAlgo::GetInputTensorNum(cnode);
auto output_num = AnfAlgo::GetOutputTensorNum(cnode);
if (node_name == kDependName || node_name == kLoadName) {
if (index != 0) {
return ret;
}
input_num = 1;
output_num = 1;
}
for (size_t i = 0; i < input_num; ++i) {
if (AnfAlgo::GetInputFormat(cnode, i) == kOpFormat_FRAC_Z) {
auto input = cnode->input(i + 1);
if (node_name == kTupleGetItemName) {
auto item_index = AnfAlgo::GetTupleGetItemOutIndex(cnode);
while (input->isa<CNode>() && AnfAlgo::GetCNodeName(input) == kDependName) {
AnfAlgo::SetNodeAttr(kAttrFracZGroup, MakeValue(groups), input);
input = input->cast<CNodePtr>()->input(1);
}
(void)ret.emplace_back(input, item_index);
} else {
(void)ret.emplace_back(input, 0);
}
}
}
if (kOptOperatorSet.find(node_name) == kOptOperatorSet.end()) {
for (size_t i = 0; i < output_num; ++i) {
if (AnfAlgo::GetOutputFormat(cnode, i) == kOpFormat_FRAC_Z) {
auto output = GetOutputItem(manager, cnode, groups, i);
if (output != nullptr) {
std::transform(node_user[output].begin(), node_user[output].end(), std::back_inserter(ret),
[](KernelWithIndex node_index) {
return KernelWithIndex{node_index.first, node_index.second - 1};
});
}
}
}
}
ret = GetCNodeNeighborFraczNodes(manager, cnode, index, groups);
}
return ret;
}


+ 10
- 10
mindspore/ccsrc/backend/optimizer/ascend/ir_fission/cdist_fission.cc View File

@@ -28,8 +28,8 @@ constexpr int64_t kInputXDimP = -1;
constexpr int64_t kInputYDimR = -2;

std::vector<size_t> CalCdistBroadCastShape(std::vector<size_t> x_shape, std::vector<size_t> y_shape) {
x_shape.insert(x_shape.end() + kInputXDimP, 1);
y_shape.insert(y_shape.end() + kInputYDimR, 1);
(void)x_shape.insert(x_shape.end() + kInputXDimP, 1);
(void)y_shape.insert(y_shape.end() + kInputYDimR, 1);
if (x_shape.size() != y_shape.size()) {
MS_EXCEPTION(ValueError) << "For Cdist, input_x and input_y should have the same rank.";
}
@@ -39,13 +39,13 @@ std::vector<size_t> CalCdistBroadCastShape(std::vector<size_t> x_shape, std::vec
auto length = x_shape.size();
std::vector<size_t> broadcast_shape;
std::copy(x_shape.begin(), x_shape.end() - SizeToLong(length), std::back_inserter(broadcast_shape));
for (int64_t i = -length; i < 0; i++) {
if (x_shape[length + i] == 1) {
broadcast_shape.push_back(y_shape[length + i]);
} else if (y_shape[length + i] == 1) {
broadcast_shape.push_back(x_shape[length + i]);
} else if (x_shape[length + i] == y_shape[length + i]) {
broadcast_shape.push_back(x_shape[length + i]);
for (size_t i = length; i > 0; --i) {
if (x_shape[length - i] == 1) {
broadcast_shape.push_back(y_shape[length - i]);
} else if (y_shape[length - i] == 1) {
broadcast_shape.push_back(x_shape[length - i]);
} else if (x_shape[length - i] == y_shape[length - i]) {
broadcast_shape.push_back(x_shape[length - i]);
} else {
MS_EXCEPTION(ValueError) << "The two input shape can not broadcast, x_shape: " << x_shape << ", y_shape"
<< y_shape;
@@ -64,7 +64,7 @@ AnfNodePtr AddBroadCastToNode(const FuncGraphPtr &func_graph, const AnfNodePtr &
auto expand_dims = func_graph->NewCNode(expand_dims_inputs);
auto dtype = AnfAlgo::GetOutputInferDataType(input_node, 0);
auto expand_shape = AnfAlgo::GetOutputInferShape(input_node, 0);
expand_shape.insert(expand_shape.end() + dim, 1);
(void)expand_shape.insert(expand_shape.end() + dim, 1);
AnfAlgo::SetOutputInferTypeAndShape({dtype}, {expand_shape}, expand_dims.get());
AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(dim), expand_dims);
AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), expand_dims);


+ 3
- 4
mindspore/ccsrc/backend/optimizer/ascend/ir_fission/diag_fission.cc View File

@@ -27,10 +27,9 @@ constexpr size_t kDiagInputNum = 1;
constexpr size_t kDiagInputMaxDim = 4;

template <typename T>
void SetAssistTensorData(void *data, T value, size_t dims_size) {
void SetAssistTensorData(void *data, const T &value, size_t dims_size) {
MS_EXCEPTION_IF_NULL(data);
auto tensor_data = reinterpret_cast<T *>(data);
MS_EXCEPTION_IF_NULL(tensor_data);
for (size_t i = 0; i < dims_size; ++i) {
tensor_data[(1 + dims_size) * i] = value;
}
@@ -46,7 +45,7 @@ ValueNodePtr DiagFission::CreateAssistNode(const FuncGraphPtr &func_graph, const
for (size_t i = 0; i < ori_shape.size(); i++) {
dims = dims * ori_shape[i];
}
output_shape.insert(output_shape.end(), ori_shape.begin(), ori_shape.end());
(void)output_shape.insert(output_shape.end(), ori_shape.begin(), ori_shape.end());
auto type = AnfAlgo::GetOutputInferDataType(node, 0);
std::vector<int64_t> assist_shape;
std::transform(output_shape.begin(), output_shape.end(), std::back_inserter(assist_shape), SizeToLong);
@@ -95,7 +94,7 @@ const AnfNodePtr DiagFission::Process(const FuncGraphPtr &graph, const AnfNodePt
}
std::vector<AnfNodePtr> new_inputs{NewValueNode(std::make_shared<Primitive>(prim::kPrimDiag->name()))};
auto assist_const = CreateAssistNode(graph, diag_cnode, input_shape);
new_inputs.insert(new_inputs.end(), diag_cnode->inputs().begin() + 1, diag_cnode->inputs().end());
(void)new_inputs.insert(new_inputs.end(), diag_cnode->inputs().begin() + 1, diag_cnode->inputs().end());
new_inputs.push_back(assist_const);
CNodePtr new_cnode = graph->NewCNode(new_inputs);
MS_EXCEPTION_IF_NULL(new_cnode);


+ 2
- 1
mindspore/ccsrc/backend/optimizer/ascend/ir_fission/diag_part_fission.cc View File

@@ -45,7 +45,8 @@ const AnfNodePtr DiagPartFission::Process(const FuncGraphPtr &func_graph, const
}
std::vector<AnfNodePtr> new_node_inputs{NewValueNode(std::make_shared<Primitive>(prim::kPrimDiagPart->name()))};
auto assist_node = CreateAssistNode(func_graph, diag_part_cnode, out_shape);
new_node_inputs.insert(new_node_inputs.end(), diag_part_cnode->inputs().begin() + 1, diag_part_cnode->inputs().end());
(void)new_node_inputs.insert(new_node_inputs.end(), diag_part_cnode->inputs().begin() + 1,
diag_part_cnode->inputs().end());
new_node_inputs.push_back(assist_node);
CNodePtr new_cnode = func_graph->NewCNode(new_node_inputs);
MS_EXCEPTION_IF_NULL(new_cnode);


+ 29
- 24
mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/avgpool_3d_fusion.cc View File

@@ -143,7 +143,7 @@ AnfNodePtr ConstructFilter(const FuncGraphPtr &func_graph, const std::vector<int
// assist tensor 1
int64_t c1 = (fc + kC0 - 1) / kC0;
std::vector<int64_t> assist_shape = {c1 * kd * kh * kw, 1, kC0, kC0}; // frac_z_3d
auto infer_shape = {IntToSize(1), LongToSize(fc), LongToSize(kd), LongToSize(kh), LongToSize(kw)};
std::vector<size_t> infer_shape = {IntToSize(1), LongToSize(fc), LongToSize(kd), LongToSize(kh), LongToSize(kw)};
float val = 1.0 / (kd * kh * kw);
if (divisor_override) {
val = 1.0 / divisor_override;
@@ -151,30 +151,8 @@ AnfNodePtr ConstructFilter(const FuncGraphPtr &func_graph, const std::vector<int
val = 1.0;
}
// create value node
tensor::TensorPtr assist_tensor = std::make_shared<tensor::Tensor>(kNumberTypeFloat16, assist_shape);
MS_EXCEPTION_IF_NULL(assist_tensor);
TensorTypePtr tensor_type = std::make_shared<TensorType>(kFloat16);
tensor::DeviceInfo device_info{kOpFormat_FRACTAL_Z_3D, tensor_type, kOpFormat_FRACTAL_Z_3D};
assist_tensor->set_device_info(device_info);
auto tensor_data = reinterpret_cast<float16 *>(assist_tensor->data_c());
int64_t cnt = c1 * kd * kh * kw;
for (int64_t i = 0; i < cnt; ++i) {
for (int64_t j = 0; j < kC0; ++j) {
for (int64_t k = 0; k < kC0; ++k) {
float t = j == k ? val : 0;
*tensor_data = float16(t);
++tensor_data;
}
}
}

auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat16, assist_shape);
auto kernel_graph = func_graph->cast<KernelGraphPtr>();
MS_EXCEPTION_IF_NULL(kernel_graph);
auto value_node = kernel_graph->NewValueNode(x_abstract, assist_tensor);
kernel_graph->AddValueNodeToGraph(value_node);
AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat16}, {infer_shape}, value_node.get());
return value_node;
return ConstructFilterValueNode(func_graph, val, assist_shape, infer_shape, cnt);
}

AnfNodePtr ConstructMultiplier(const FuncGraphPtr &func_graph, int64_t fn, int64_t fc, int64_t fd, int64_t fh,
@@ -235,6 +213,33 @@ AnfNodePtr ConstructMultiplier(const FuncGraphPtr &func_graph, int64_t fn, int64
}
} // namespace

// Builds a constant float16 filter tensor in FRACTAL_Z_3D device format and wires it into the
// kernel graph as a value node.
//   func_graph   - graph to attach the value node to; must cast to a KernelGraph.
//   val          - value written on the diagonal of every kC0 x kC0 fractal block.
//   assist_shape - device-format shape of the assist tensor.
//   infer_shape  - inferred (logical) output shape recorded on the value node.
//   cnt          - number of kC0 x kC0 fractal blocks to fill.
// Returns the newly created value node.
AnfNodePtr ConstructFilterValueNode(const FuncGraphPtr &func_graph, float val, const std::vector<int64_t> &assist_shape,
                                    const std::vector<size_t> &infer_shape, int64_t cnt) {
  auto filter_tensor = std::make_shared<tensor::Tensor>(kNumberTypeFloat16, assist_shape);
  MS_EXCEPTION_IF_NULL(filter_tensor);
  TensorTypePtr fp16_type = std::make_shared<TensorType>(kFloat16);
  tensor::DeviceInfo dev_info{kOpFormat_FRACTAL_Z_3D, fp16_type, kOpFormat_FRACTAL_Z_3D};
  filter_tensor->set_device_info(dev_info);

  // Each fractal block is a scaled identity: `val` on the diagonal, zero elsewhere.
  auto data_ptr = reinterpret_cast<float16 *>(filter_tensor->data_c());
  const float16 diag_elem(val);
  const float16 zero_elem(0.0f);
  for (int64_t block = 0; block < cnt; ++block) {
    for (int64_t row = 0; row < kC0; ++row) {
      for (int64_t col = 0; col < kC0; ++col) {
        *data_ptr++ = (row == col) ? diag_elem : zero_elem;
      }
    }
  }

  auto tensor_abstract = std::make_shared<abstract::AbstractTensor>(kFloat16, assist_shape);
  auto kernel_graph = func_graph->cast<KernelGraphPtr>();
  MS_EXCEPTION_IF_NULL(kernel_graph);
  auto filter_node = kernel_graph->NewValueNode(tensor_abstract, filter_tensor);
  kernel_graph->AddValueNodeToGraph(filter_node);
  AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat16}, {infer_shape}, filter_node.get());
  return filter_node;
}

const BaseRef AvgPool3DFusion::DefinePattern() const {
VarPtr Xs = std::make_shared<SeqVar>();
return VectorRef({prim::kPrimAvgPool3D, Xs});


+ 3
- 0
mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/avgpool_3d_fusion.h View File

@@ -31,6 +31,9 @@ class AvgPool3DFusion : public PatternProcessPass {
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};

AnfNodePtr ConstructFilterValueNode(const FuncGraphPtr &func_graph, float val, const std::vector<int64_t> &assist_shape,
const std::vector<size_t> &infer_shape, int64_t cnt);
} // namespace opt
} // namespace mindspore



+ 3
- 23
mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/avgpool_3d_grad_fusion.cc View File

@@ -19,6 +19,7 @@
#include <memory>
#include <string>
#include <algorithm>
#include "backend/optimizer/ascend/ir_fusion/avgpool_3d_fusion.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/optimizer/common/helper.h"
#include "base/core_ops.h"
@@ -105,7 +106,7 @@ AnfNodePtr ConstructFilter(const FuncGraphPtr &func_graph, const std::vector<int
// assist tensor 1
int64_t c1 = (fc + kC0 - 1) / kC0;
std::vector<int64_t> assist_shape = {c1 * kd * kh * kw, 1, kC0, kC0}; // frac_z_3d
auto infer_shape = {IntToSize(1), LongToSize(fc), LongToSize(kd), LongToSize(kh), LongToSize(kw)};
std::vector<size_t> infer_shape = {IntToSize(1), LongToSize(fc), LongToSize(kd), LongToSize(kh), LongToSize(kw)};
float val = 1.0;
if (divisor_override) {
val = 1.0 / divisor_override;
@@ -113,29 +114,8 @@ AnfNodePtr ConstructFilter(const FuncGraphPtr &func_graph, const std::vector<int
val = 1.0 / (kd * kh * kw);
}
// create value node
tensor::TensorPtr assist_tensor = std::make_shared<tensor::Tensor>(kNumberTypeFloat16, assist_shape);
TensorTypePtr tensor_type = std::make_shared<TensorType>(kFloat16);
tensor::DeviceInfo device_info{kOpFormat_FRACTAL_Z_3D, tensor_type, kOpFormat_FRACTAL_Z_3D};
assist_tensor->set_device_info(device_info);
auto tensor_data = reinterpret_cast<float16 *>(assist_tensor->data_c());
int64_t cnt = c1 * kd * kh * kw;
for (int64_t i = 0; i < cnt; ++i) {
for (int64_t j = 0; j < kC0; ++j) {
for (int64_t k = 0; k < kC0; ++k) {
float t = j == k ? val : 0;
*tensor_data = float16(t);
++tensor_data;
}
}
}

auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat16, assist_shape);
auto kernel_graph = func_graph->cast<KernelGraphPtr>();
auto value_node = kernel_graph->NewValueNode(x_abstract, assist_tensor);
MS_EXCEPTION_IF_NULL(value_node);
kernel_graph->AddValueNodeToGraph(value_node);
AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat16}, {infer_shape}, value_node.get());
return value_node;
return ConstructFilterValueNode(func_graph, val, assist_shape, infer_shape, cnt);
}

AnfNodePtr ConstructMultiplier(const FuncGraphPtr &func_graph, const std::vector<size_t> &ori_shape,


+ 2
- 2
mindspore/ccsrc/backend/optimizer/ascend/mindir/all_to_all_unify_mindir.cc View File

@@ -102,7 +102,7 @@ CNodePtr CreateAllToAllvNode(const FuncGraphPtr &graph, const CNodePtr &all_to_a
MS_LOG(EXCEPTION) << "The node " << split->DebugString() << " should have at least one output, but got 0.";
}
std::vector<AnfNodePtr> all_to_all_v_input = {NewValueNode(std::make_shared<Primitive>(kAllToAllVOpName))};
all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs.begin(), split_outputs.end());
(void)all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs.begin(), split_outputs.end());
auto all_to_all_v = graph->NewCNode(all_to_all_v_input);
MS_EXCEPTION_IF_NULL(all_to_all_v);
auto single_shape = AnfAlgo::GetOutputInferShape(split_outputs[0], 0);
@@ -135,7 +135,7 @@ CNodePtr CreateConcatNode(const FuncGraphPtr &graph, const CNodePtr &all_to_all,
MS_LOG(EXCEPTION) << "The node " << all_to_all_v->DebugString() << " should have at least one output, but got 0.";
}
std::vector<AnfNodePtr> concat_input = {NewValueNode(std::make_shared<Primitive>(kConcatOpName))};
concat_input.insert(concat_input.end(), all_to_all_v_outputs.begin(), all_to_all_v_outputs.end());
(void)concat_input.insert(concat_input.end(), all_to_all_v_outputs.begin(), all_to_all_v_outputs.end());
auto concat = graph->NewCNode(concat_input);
MS_EXCEPTION_IF_NULL(concat);
auto single_shape = AnfAlgo::GetOutputInferShape(all_to_all_v_outputs[0], 0);


+ 0
- 1
mindspore/ccsrc/backend/optimizer/common/optimizer.cc View File

@@ -85,7 +85,6 @@ void GraphOptimizer::AddPassManager(const PassManagerPtr &pass_manager) {
FuncGraphPtr GraphOptimizer::Optimize(const FuncGraphPtr &func_graph, bool run_only_once) {
MS_EXCEPTION_IF_NULL(func_graph);
run_only_once_ = (pass_managers_.size() == 1) ? true : run_only_once;
// Performance risk by creating new manager each time
// cppcheck-suppress *
auto manager = Manage(func_graph, true);



+ 1
- 52
mindspore/ccsrc/backend/optimizer/common/pattern_engine.cc View File

@@ -269,7 +269,7 @@ EquivPtr PatternEngine::Match(const BaseRef &pattern, const BaseRef &expr, const
}

// 2. check equal
if (PatternEngine::AnfNodeEqual(pattern_ref, expr_ref)) {
if (opt::AnfEqual(pattern_ref, expr_ref)) {
return equiv;
}

@@ -301,57 +301,6 @@ EquivPtr PatternEngine::Match(const BaseRef &pattern, const BaseRef &expr, const
return equiv;
}

// Node-level equality used by pattern matching:
//   - Two Primitive value nodes are equal when their primitive names match.
//   - Two other value nodes are equal when their contained values compare equal.
//   - Every other combination falls back to BaseRef::operator==.
bool PatternEngine::AnfNodeEqual(const BaseRef &a, const BaseRef &b) {
  if (utils::isa<AnfNodePtr>(a) && utils::isa<AnfNodePtr>(b)) {
    auto a_node = utils::cast<AnfNodePtr>(a);
    auto b_node = utils::cast<AnfNodePtr>(b);
    MS_EXCEPTION_IF_NULL(a_node);
    MS_EXCEPTION_IF_NULL(b_node);
    if (IsValueNode<Primitive>(a_node) && IsValueNode<Primitive>(b_node)) {
      // Both sides hold a Primitive: compare by primitive name only.
      auto a_value_node = a_node->cast<ValueNodePtr>();
      MS_EXCEPTION_IF_NULL(a_value_node);
      auto a_value = a_value_node->value();
      MS_EXCEPTION_IF_NULL(a_value);
      auto a_prim = a_value->cast<PrimitivePtr>();
      MS_EXCEPTION_IF_NULL(a_prim);

      auto b_value_node = b_node->cast<ValueNodePtr>();
      MS_EXCEPTION_IF_NULL(b_value_node);
      auto b_value = b_value_node->value();
      MS_EXCEPTION_IF_NULL(b_value);
      auto b_prim = b_value->cast<PrimitivePtr>();
      MS_EXCEPTION_IF_NULL(b_prim);

      return a_prim->name() == b_prim->name();
    } else if (a_node->isa<ValueNode>() && b_node->isa<ValueNode>()) {
      // Generic value nodes: compare the wrapped Value objects.
      auto a_value_node_ptr = a_node->cast<ValueNodePtr>();
      if (a_value_node_ptr == nullptr) {
        MS_LOG(EXCEPTION) << "cast value node ptr fail";
      }
      auto a_value_ptr = a_value_node_ptr->value();
      if (a_value_ptr == nullptr) {
        MS_LOG(EXCEPTION) << "value ptr is nullptr";
      }

      auto b_value_node_ptr = b_node->cast<ValueNodePtr>();
      if (b_value_node_ptr == nullptr) {
        MS_LOG(EXCEPTION) << "cast value node ptr fail";
      }
      auto b_value_ptr = b_value_node_ptr->value();
      if (b_value_ptr == nullptr) {
        MS_LOG(EXCEPTION) << "value ptr is nullptr";
      }

      return (*a_value_ptr) == (*b_value_ptr);
    }
    MS_LOG(DEBUG) << "check AnfNodePtr equal";
  }
  if (utils::isa<FuncGraphPtr>(a) && utils::isa<FuncGraphPtr>(b)) {
    MS_LOG(DEBUG) << "check GraphPtr equal";
  }
  // Fallback: BaseRef's own equality (pointer/value comparison).
  return a == b;
}

bool PatternEngine::CNodeTypeEqual(const BaseRef &a, const BaseRef &b) {
// To matchCNode and Kernel's type
if (utils::isa<CNode>(a) && utils::isa<CNode>(b)) {


+ 0
- 1
mindspore/ccsrc/backend/optimizer/common/pattern_engine.h View File

@@ -179,7 +179,6 @@ class PatternEngine {
VectorRef *const values_expr) const;
bool ToVector(const VectorRef &pattern_ref, const VectorRef &expr_ref, VectorRef *const values_pattern,
VectorRef *const values_expr) const;
static bool AnfNodeEqual(const BaseRef &a, const BaseRef &b);
static bool CNodeTypeEqual(const BaseRef &a, const BaseRef &b);
std::shared_ptr<Visitor> visitor_;
};


Loading…
Cancel
Save