Browse Source

fix transform bug which high order pritimive is not convert to graph

tags/v0.5.0-beta
zhousiyi 5 years ago
parent
commit
e895f19e80
1 changed files with 10 additions and 4 deletions
  1. +10
    -4
      mindspore/ccsrc/vm/transform.cc

+ 10
- 4
mindspore/ccsrc/vm/transform.cc View File

@@ -499,15 +499,20 @@ void TraverseGraphMap(
for (auto &use : users) { for (auto &use : users) {
CNodePtr node = use.first->cast<CNodePtr>(); CNodePtr node = use.first->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
if (node->func_graph() != fg) {
continue;
}
int key = use.second; int key = use.second;
if (key != 0) { if (key != 0) {
MS_EXCEPTION_IF_NULL(node->input(0)); MS_EXCEPTION_IF_NULL(node->input(0));
bool key_is_const = node->input(0)->isa<ValueNode>(); bool key_is_const = node->input(0)->isa<ValueNode>();
PrimitivePtr value = GetValueNode<PrimitivePtr>(node->input(0)); PrimitivePtr value = GetValueNode<PrimitivePtr>(node->input(0));
bool is_prim_array_map = !(prim::kPrimArrayMap->name().compare(value->name()));
bool is_prim_array_reduce = !(prim::kPrimArrayReduce->name().compare(value->name()));
if (key == 1 && key_is_const && (is_prim_array_map || is_prim_array_reduce)) {
continue;
if (value != nullptr) {
bool is_prim_array_map = !(prim::kPrimArrayMap->name().compare(value->name()));
bool is_prim_array_reduce = !(prim::kPrimArrayReduce->name().compare(value->name()));
if (key == 1 && key_is_const && (is_prim_array_map || is_prim_array_reduce)) {
continue;
}
} }
FuncGraphPtr g = get_prim_graph(GetValueNode<PrimitivePtr>(const_primitive_node), FuncGraphPtr g = get_prim_graph(GetValueNode<PrimitivePtr>(const_primitive_node),
dyn_cast<AbstractFunction>(const_primitive_node->abstract())); dyn_cast<AbstractFunction>(const_primitive_node->abstract()));
@@ -554,6 +559,7 @@ FuncGraphPtr WrapPrimitives(const FuncGraphPtr &graph) {
FuncGraphTransaction tr = manager_ptr->Transact(); FuncGraphTransaction tr = manager_ptr->Transact();
auto &fgs = manager_ptr->func_graphs(); auto &fgs = manager_ptr->func_graphs();
TraverseGraphMap(manager_ptr, &tr, fgs, get_prim_graph); TraverseGraphMap(manager_ptr, &tr, fgs, get_prim_graph);
tr.Commit();


return graph; return graph;
} }


Loading…
Cancel
Save