Browse Source

match side effect in lite

tags/v1.2.0-rc1
xuanyue 4 years ago
parent
commit
46a22f4a7b
7 changed files with 49 additions and 81 deletions
  1. +1
    -1
      mindspore/lite/test/CMakeLists.txt
  2. +0
    -47
      mindspore/lite/tools/anf_exporter/anf_exporter.cc
  3. +0
    -1
      mindspore/lite/tools/anf_exporter/anf_exporter.h
  4. +1
    -1
      mindspore/lite/tools/converter/CMakeLists.txt
  5. +2
    -2
      mindspore/lite/tools/converter/anf_transform.cc
  6. +38
    -22
      mindspore/lite/tools/optimizer/graph/redundant_op_remove_pass.cc
  7. +7
    -7
      mindspore/lite/tools/optimizer/graph/redundant_op_remove_pass.h

+ 1
- 1
mindspore/lite/test/CMakeLists.txt View File

@@ -225,7 +225,7 @@ if(ENABLE_CONVERTER)
${LITE_DIR}/tools/optimizer/graph/update_conv2d_param_pass.cc
${LITE_DIR}/tools/optimizer/graph/unused_cast_node_remove_pass.cc
${LITE_DIR}/tools/optimizer/graph/unused_transpose_node_remove_pass.cc
${LITE_DIR}/tools/optimizer/graph/identity_remove_pass.cc
${LITE_DIR}/tools/optimizer/graph/redundant_op_remove_pass.cc
${LITE_DIR}/tools/optimizer/graph/infershape_pass.cc
${LITE_DIR}/tools/optimizer/graph/slice_prepose_pass.cc
${LITE_DIR}/tools/optimizer/graph/mindir_adjust_pass.cc


+ 0
- 47
mindspore/lite/tools/anf_exporter/anf_exporter.cc View File

@@ -59,41 +59,6 @@ void AnfExporter::RemoveIfMakeTuple(const CNodePtr &cnode) {
}
}

void AnfExporter::RemoveIfDepend(const CNodePtr &cnode) {
bool hasDepend = false;
std::vector<AnfNodePtr> inputs;
inputs.clear();

inputs.emplace_back(cnode->input(0));
for (size_t i = 1; i < cnode->inputs().size(); ++i) {
AnfNodePtr inputNode = cnode->input(i);
if (!inputNode->isa<CNode>()) {
inputs.emplace_back(cnode->input(i));
continue;
}
auto dependNode = utils::cast<CNodePtr>(inputNode);
if (IsPrimitiveCNode(dependNode, schema::PrimitiveType_Depend) ||
IsPrimitiveCNode(dependNode, schema::PrimitiveType_ControlDepend)) {
hasDepend = true;
bool maskOut = (dependNode->inputs().size() == 3);
for (size_t j = 1; j < dependNode->inputs().size(); ++j) {
AnfNodePtr dependInputNode = dependNode->input(j);
if (dependInputNode->isa<CNode>()) {
inputs.emplace_back(dependInputNode);
if (maskOut) {
break;
}
}
}
} else {
inputs.emplace_back(cnode->input(i));
}
}
if (hasDepend) {
cnode->set_inputs(inputs);
}
}

int AnfExporter::ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &meta_graph,
const std::shared_ptr<PrimitiveC> &primitive,
const std::unique_ptr<schema::CNodeT> &dst_node) {
@@ -286,23 +251,11 @@ int AnfExporter::Anf2Fb(const FuncGraphPtr &func_graph, const std::unique_ptr<sc
break;
}
}

#ifdef SUPPORT_TRAIN
RemoveIfMakeTuple(cnode);
RemoveIfDepend(cnode);
#endif

if ((primitive_c->Type() == schema::PrimitiveType_TupleGetItem) ||
#ifdef SUPPORT_TRAIN
(primitive_c->Type() == schema::PrimitiveType_Depend) ||
(primitive_c->Type() == schema::PrimitiveType_ControlDepend) ||
#endif
(primitive_c->Type() == schema::PrimitiveType_MakeTuple)) {
continue;
}
#ifndef SUPPORT_TRAIN
RemoveIfMakeTuple(cnode);
#endif
auto primT = primitive_c->primitiveT();
auto node = std::make_unique<schema::CNodeT>();
if (node == nullptr) {


+ 0
- 1
mindspore/lite/tools/anf_exporter/anf_exporter.h View File

@@ -41,7 +41,6 @@ class AnfExporter {
int SetOpInputNode(const CNodePtr &cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
schema::CNodeT *fb_node);
static void RemoveIfMakeTuple(const CNodePtr &cnode);
static void RemoveIfDepend(const CNodePtr &cnode);

protected:
int ConvertInputCNode(const std::shared_ptr<AnfNode> &input_anode, schema::CNodeT *output_cnode);


+ 1
- 1
mindspore/lite/tools/converter/CMakeLists.txt View File

@@ -59,7 +59,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
../optimizer/graph/update_conv2d_param_pass.cc
../optimizer/graph/unused_cast_node_remove_pass.cc
../optimizer/graph/unused_transpose_node_remove_pass.cc
../optimizer/graph/identity_remove_pass.cc
../optimizer/graph/redundant_op_remove_pass.cc
../optimizer/graph/infershape_pass.cc
../optimizer/graph/slice_prepose_pass.cc
../optimizer/graph/mindir_adjust_pass.cc


+ 2
- 2
mindspore/lite/tools/converter/anf_transform.cc View File

@@ -34,7 +34,7 @@
#include "tools/optimizer/fusion/bidirection_tf_gru_cell_fusion.h"
#include "tools/optimizer/graph/mindir_adjust_pass.h"
#include "tools/optimizer/graph/mindir_inputs_adjust_pass.h"
#include "tools/optimizer/graph/identity_remove_pass.h"
#include "tools/optimizer/graph/redundant_op_remove_pass.h"
#include "tools/optimizer/graph/weight_format_hardcode_pass.h"
#include "tools/optimizer/graph/weight_format_transform_pass.h"
#include "tools/optimizer/graph/clip_convert_activation_pass.h"
@@ -144,7 +144,7 @@ int AnfTransform::AddConvertPass(const std::shared_ptr<opt::GraphOptimizer> &opt
int AnfTransform::AddConstFoldPass(const std::shared_ptr<opt::GraphOptimizer> &optimizer,
const converter::Flags *config) {
auto const_fold_pm = std::make_shared<opt::PassManager>("const fold fusion pass manager", false);
const_fold_pm->AddPass(std::make_shared<opt::RemoveIdentityOpPass>());
const_fold_pm->AddPass(std::make_shared<opt::RemoveRedundantOpPass>());
if (!config->trainModel) {
auto inne_context_ptr = std::make_shared<lite::InnerContext>();
inne_context_ptr->Init();


mindspore/lite/tools/optimizer/graph/identity_remove_pass.cc → mindspore/lite/tools/optimizer/graph/redundant_op_remove_pass.cc View File

@@ -13,37 +13,41 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/optimizer/graph/identity_remove_pass.h"
#include "tools/optimizer/graph/redundant_op_remove_pass.h"
#include <memory>
#include "mindspore/lite/include/errorcode.h"
#include "src/ops/primitive_c.h"

namespace mindspore::opt {
int RemoveIdentityOpPass::ReplaceIdentity(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager) {
namespace {
constexpr size_t InputDoubleNum = 2;
constexpr size_t InputTripleNum = 3;
constexpr auto kNameLoad = "Load";
constexpr auto kNameUpdateState = "UpdateState";
} // namespace
int RemoveRedundantOpPass::ReplaceOp(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager) {
if (!utils::isa<CNodePtr>(anf_node)) {
MS_LOG(DEBUG) << "anf node is node a cnode.";
return lite::RET_NO_CHANGE;
}
auto type = opt::GetCNodeType(anf_node);
if (type != schema::PrimitiveType_Identity) {
MS_LOG(DEBUG) << "anf node is not a identity node.";
return lite::RET_NO_CHANGE;
}
auto identity_cnode = anf_node->cast<CNodePtr>();
if (identity_cnode->inputs().size() != lite::kDoubleNum) {
MS_LOG(DEBUG) << "The node inputs size is bigger than 1";
remove_cnode_.insert(anf_node);
return lite::RET_NO_CHANGE;
} else {
bool replace_succ = manager->Replace(anf_node, identity_cnode->input(1));
if (!replace_succ) {
MS_LOG(ERROR) << "replace identity failed.";
return lite::RET_ERROR;
auto cnode = anf_node->cast<CNodePtr>();
if (type == schema::PrimitiveType_Identity) {
if (cnode->size() != InputDoubleNum) {
MS_LOG(DEBUG) << "The node inputs size is bigger than 1";
remove_cnode_.insert(anf_node);
return lite::RET_NO_CHANGE;
}
}
bool replace_succ = manager->Replace(anf_node, cnode->input(1));
if (!replace_succ) {
MS_LOG(ERROR) << "replace redundant op failed.";
return lite::RET_ERROR;
}
return RET_OK;
}

int RemoveIdentityOpPass::ReplaceTupleGetItem(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager) {
int RemoveRedundantOpPass::ReplaceTupleGetItem(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager) {
if (!utils::isa<CNodePtr>(anf_node)) {
MS_LOG(DEBUG) << "anf node is node a cnode.";
return lite::RET_NO_CHANGE;
@@ -53,7 +57,7 @@ int RemoveIdentityOpPass::ReplaceTupleGetItem(const AnfNodePtr &anf_node, const
return lite::RET_NO_CHANGE;
}
auto cnode = anf_node->cast<CNodePtr>();
if (cnode->inputs().size() != 3) {
if (cnode->inputs().size() != InputTripleNum) {
MS_LOG(ERROR) << "TupleGetItem should have 3 inputs, got " << cnode->inputs().size();
return RET_ERROR;
}
@@ -81,7 +85,7 @@ int RemoveIdentityOpPass::ReplaceTupleGetItem(const AnfNodePtr &anf_node, const
return lite::RET_OK;
}

bool RemoveIdentityOpPass::Run(const FuncGraphPtr &func_graph) {
bool RemoveRedundantOpPass::Run(const FuncGraphPtr &func_graph) {
MS_ASSERT(func_graph != nullptr);
auto manager = func_graph->manager();
MS_ASSERT(manager != nullptr);
@@ -93,10 +97,22 @@ bool RemoveIdentityOpPass::Run(const FuncGraphPtr &func_graph) {
}
auto type = opt::GetCNodeType(node);
if (type == schema::PrimitiveType_Identity) {
status = ReplaceIdentity(node, manager);
} else if (type == schema::PrimitiveType_TupleGetItem) {
status = ReplaceOp(node, manager);
}
if (CheckPrimitiveType(node, std::make_shared<Primitive>(kNameLoad))) {
status = ReplaceOp(node, manager);
}
if (CheckPrimitiveType(node, std::make_shared<Primitive>(kNameUpdateState))) {
status = ReplaceOp(node, manager);
}
if (type == schema::PrimitiveType_Depend ||
type == schema::PrimitiveType_ControlDepend) { // ControlDepend delete next version.
status = ReplaceOp(node, manager);
}
if (type == schema::PrimitiveType_TupleGetItem) {
status = ReplaceTupleGetItem(node, manager);
} else if (type == schema::PrimitiveType_If || type == schema::PrimitiveType_While) {
}
if (type == schema::PrimitiveType_If || type == schema::PrimitiveType_While) {
auto sub_func_graph = GetValueNode<FuncGraphPtr>(node->cast<CNodePtr>()->input(1));
if (sub_func_graph == nullptr) {
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);

mindspore/lite/tools/optimizer/graph/identity_remove_pass.h → mindspore/lite/tools/optimizer/graph/redundant_op_remove_pass.h View File

@@ -14,8 +14,8 @@
* limitations under the License.
*/

#ifndef MINDSPORE_LITE_SRC_PASS_REMOVE_IDENTITY_PASS_H_
#define MINDSPORE_LITE_SRC_PASS_REMOVE_IDENTITY_PASS_H_
#ifndef MINDSPORE_LITE_SRC_PASS_REDUNDANT_OP_REMOVE_PASS_H_
#define MINDSPORE_LITE_SRC_PASS_REDUNDANT_OP_REMOVE_PASS_H_
#include <string>
#include <set>
#include "backend/optimizer/common/pass.h"
@@ -24,11 +24,11 @@

using mindspore::lite::converter::FmkType;
namespace mindspore::opt {
class RemoveIdentityOpPass : public Pass {
class RemoveRedundantOpPass : public Pass {
public:
RemoveIdentityOpPass() : Pass("remove_identity_pass") {}
~RemoveIdentityOpPass() override = default;
int ReplaceIdentity(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager);
RemoveRedundantOpPass() : Pass("remove_redundant_op_pass") {}
~RemoveRedundantOpPass() override = default;
int ReplaceOp(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager);
int ReplaceTupleGetItem(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager);
bool Run(const FuncGraphPtr &graph) override;

@@ -36,4 +36,4 @@ class RemoveIdentityOpPass : public Pass {
std::set<AnfNodePtr> remove_cnode_;
};
} // namespace mindspore::opt
#endif // MINDSPORE_LITE_SRC_PASS_REMOVE_IDENTITY_PASS_H_
#endif // MINDSPORE_LITE_SRC_PASS_REDUNDANT_OP_REMOVE_PASS_H_

Loading…
Cancel
Save