Browse Source

Change pass to enable vmap and J to run in the same opt pass

feature/build-system-rewrite
l00591931 4 years ago
parent
commit
1c50572fe0
9 changed files with 200 additions and 50 deletions
  1. +0
    -1
      mindspore/ccsrc/frontend/optimizer/irpass.cc
  2. +3
    -37
      mindspore/ccsrc/frontend/optimizer/irpass/gradient_eliminate.cc
  3. +6
    -8
      mindspore/ccsrc/frontend/optimizer/irpass/gradient_eliminate.h
  4. +43
    -0
      mindspore/ccsrc/frontend/optimizer/irpass/meta_fg_eliminate.cc
  5. +42
    -0
      mindspore/ccsrc/frontend/optimizer/irpass/meta_fg_eliminate.h
  6. +47
    -0
      mindspore/ccsrc/frontend/optimizer/irpass/meta_fg_prim_eliminate.cc
  7. +52
    -0
      mindspore/ccsrc/frontend/optimizer/irpass/meta_fg_prim_eliminate.h
  8. +3
    -3
      mindspore/ccsrc/pipeline/jit/pass.cc
  9. +4
    -1
      mindspore/core/base/core_ops.h

+ 0
- 1
mindspore/ccsrc/frontend/optimizer/irpass.cc View File

@@ -21,7 +21,6 @@
#include "frontend/optimizer/irpass/convert.h"
#include "frontend/optimizer/irpass/environ_eliminate.h"
#include "frontend/optimizer/irpass/grad_var_prepare.h"
#include "frontend/optimizer/irpass/gradient_eliminate.h"
#include "frontend/optimizer/irpass/inline.h"
#include "frontend/optimizer/irpass/updatestate_eliminate.h"
#include "frontend/optimizer/irpass/load_eliminate.h"


+ 3
- 37
mindspore/ccsrc/frontend/optimizer/irpass/gradient_eliminate.cc View File

@@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -36,20 +36,6 @@ AnfNodePtr ExpandJPrimitive(const ValueNodePtr &vnode, const pipeline::ResourceB
return nullptr;
}

bool CheckIfEmbedJ(const CNodePtr &j_node) {
auto &value_node = j_node->input(1);
if (IsValueNode<Primitive>(value_node)) {
return false;
}
auto func_graph = GetValueNode<FuncGraphPtr>(value_node);
if (func_graph == nullptr) {
MS_LOG(EXCEPTION) << "Unexpected J node, input func graph should not be null, node: " << j_node->DebugString();
}
auto func_graph_manager = func_graph->manager();
MS_EXCEPTION_IF_NULL(func_graph_manager);
return func_graph_manager->func_graph_j_total(func_graph);
}

bool IsSideEffectOp(const AnfNodePtr &node) {
if (!node->isa<CNode>()) {
return false;
@@ -78,25 +64,17 @@ AnfNodePtr ExpandJ(const ValueNodePtr &vnode, const OptimizerPtr &optimizer) {
} // namespace internal

bool ExpandJPrim::operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) {
// Search all j nodes.
GetJPrim(func_graph);
// Get j nodes that don't have embed j nodes.
std::vector<CNodePtr> todo;
// If graph also contains J(FuncGraph) or J(Primitive), then ignore this graph.
// ExpandJ innermost graph or primitive first.
std::copy_if(j_nodes_.begin(), j_nodes_.end(), std::back_inserter(todo),
[](const CNodePtr &j_node) { return !internal::CheckIfEmbedJ(j_node); });
// Check whether need to eliminate forward cnodes in pynative mode.
if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
auto pynative_exec = pynative::PynativeExecutor::GetInstance();
auto grad_exec = pynative_exec->grad_executor();
bool eliminate_forward = grad_exec->eliminate_forward();
grad_exec->set_eliminate_forward(eliminate_forward && todo.empty());
grad_exec->set_eliminate_forward(eliminate_forward && prim_nodes_.empty());
}
// Expand j nodes that don't have embed j nodes.
bool change = false;
auto manager = optimizer->manager();
for (auto &j_node : todo) {
for (auto &j_node : prim_nodes_) {
auto expanded_j = internal::ExpandJ(j_node->input(1)->cast<ValueNodePtr>(), optimizer);
manager->Replace(j_node, expanded_j);
if (j_node->func_graph()->has_flag(FUNC_GRAPH_FLAG_K_GRAPH)) {
@@ -107,18 +85,6 @@ bool ExpandJPrim::operator()(const FuncGraphPtr &func_graph, const OptimizerPtr
}
return change;
}

void ExpandJPrim::GetJPrim(const FuncGraphPtr &func_graph) {
j_nodes_.clear();
AnfNodePtr ret = func_graph->get_return();
MS_EXCEPTION_IF_NULL(ret);
std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
for (auto &node : all_nodes) {
if (IsPrimitiveCNode(node, prim::kPrimJ)) {
j_nodes_.push_back(node->cast<CNodePtr>());
}
}
}
} // namespace irpass
} // namespace opt
} // namespace mindspore

+ 6
- 8
mindspore/ccsrc/frontend/optimizer/irpass/gradient_eliminate.h View File

@@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -27,21 +27,19 @@
#include "utils/ms_utils.h"
#include "frontend/operator/ops.h"
#include "frontend/optimizer/ad/grad.h"
#include "frontend/optimizer/irpass/meta_fg_prim_eliminate.h"

namespace mindspore {
namespace opt {
namespace irpass {
// {prim::kPrimJ, C}
class ExpandJPrim {
class ExpandJPrim : public ExpandMetaFGPrim {
public:
ExpandJPrim() = default;
ExpandJPrim() { prim_ = prim::kPrimJ; }
virtual ~ExpandJPrim() = default;
bool operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer);
void GetJPrim(const FuncGraphPtr &func_graph);

private:
std::vector<CNodePtr> j_nodes_;
bool operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) override;
};
using ExpandJPrimPtr = std::shared_ptr<ExpandJPrim>;
} // namespace irpass
} // namespace opt
} // namespace mindspore


+ 43
- 0
mindspore/ccsrc/frontend/optimizer/irpass/meta_fg_eliminate.cc View File

@@ -0,0 +1,43 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "frontend/optimizer/irpass/meta_fg_eliminate.h"
#include "frontend/optimizer/irpass/gradient_eliminate.h"

namespace mindspore {
namespace opt {
namespace irpass {
// Expands every registered kind of meta function graph in one optimization
// step. Candidates for all expanders are collected up front because expanding
// one kind of meta fg may change the number of outer-layer meta fgs.
// Returns true if any expander changed the graph.
bool ExpandMetaFg::operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) {
AnfNodePtr return_node = func_graph->get_return();
MS_EXCEPTION_IF_NULL(return_node);
std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(return_node);
// The expanding of meta fg may change the number of outer layer meta fgs.
// So, find all kinds of candidate meta fgs together and then expands them.
for (const auto &expand_meta_fg_element : expand_meta_fg_list_) {
expand_meta_fg_element->GetMetaFGPrim(all_nodes);
}
bool ret = false;
for (const auto &expand_meta_fg_element : expand_meta_fg_list_) {
// prim_nodes() returns a const reference; bind instead of copying the vector.
const auto &prim_nodes = expand_meta_fg_element->prim_nodes();
if (!prim_nodes.empty()) {
// Run the expander unconditionally before folding into `ret`:
// `ret = ret || (*elem)(...)` would short-circuit and skip every
// remaining expander once one of them reported a change.
bool changed = (*expand_meta_fg_element)(func_graph, optimizer);
ret = ret || changed;
}
}
return ret;
}
} // namespace irpass
} // namespace opt
} // namespace mindspore

+ 42
- 0
mindspore/ccsrc/frontend/optimizer/irpass/meta_fg_eliminate.h View File

@@ -0,0 +1,42 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_META_FG_ELIMINATE_H_
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_META_FG_ELIMINATE_H_

#include <vector>
#include <memory>
#include "base/core_ops.h"
#include "frontend/optimizer/irpass/gradient_eliminate.h"
#include "frontend/optimizer/irpass/meta_fg_prim_eliminate.h"

namespace mindspore {
namespace opt {
namespace irpass {
// Composite pass that expands all registered kinds of meta function graphs
// in a single optimization step (per the commit intent: let vmap and J run
// in the same opt). Currently only the J expander is registered; further
// expanders can be appended to expand_meta_fg_list_.
class ExpandMetaFg {
public:
// Registers the built-in expanders; ExpandJPrim handles {prim::kPrimJ, C}.
ExpandMetaFg() { (void)expand_meta_fg_list_.emplace_back(std::make_shared<ExpandJPrim>()); }
virtual ~ExpandMetaFg() = default;
// Collects candidate nodes for every expander first, then invokes each
// expander that found candidates; returns true if the graph changed.
bool operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer);
private:
std::vector<ExpandMetaFGPrimPtr> expand_meta_fg_list_;
};
} // namespace irpass
} // namespace opt
} // namespace mindspore

#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_META_FG_ELIMINATE_H_

+ 47
- 0
mindspore/ccsrc/frontend/optimizer/irpass/meta_fg_prim_eliminate.cc View File

@@ -0,0 +1,47 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "frontend/optimizer/irpass/meta_fg_prim_eliminate.h"

namespace mindspore {
namespace opt {
namespace irpass {
// Reports whether the meta-fg node wraps a func graph that the manager marks
// as containing J (func_graph_j_total), i.e. the node is not innermost.
// A node wrapping a bare Primitive can never embed another meta fg.
bool ExpandMetaFGPrim::CheckIfEmbedMetaFGPrim(const CNodePtr &node) const {
const auto &wrapped = node->input(1);
if (IsValueNode<Primitive>(wrapped)) {
return false;
}
auto inner_fg = GetValueNode<FuncGraphPtr>(wrapped);
if (inner_fg == nullptr) {
MS_LOG(EXCEPTION) << "Unexpected meta function graph node:" << node->DebugString();
}
auto mgr = inner_fg->manager();
MS_EXCEPTION_IF_NULL(mgr);
return mgr->func_graph_j_total(inner_fg);
}

// Scans all_nodes and stores in prim_nodes_ every CNode of the form
// {prim_, ...} whose wrapped func graph does not itself embed another meta fg,
// so innermost meta fgs are expanded first. prim_ must be set by the subclass.
void ExpandMetaFGPrim::GetMetaFGPrim(const std::vector<AnfNodePtr> &all_nodes) {
MS_EXCEPTION_IF_NULL(prim_);
prim_nodes_.clear();
for (auto &node : all_nodes) {
if (!IsPrimitiveCNode(node, prim_)) {
continue;
}
// Cast once; the original cast the node both in the check and in push_back.
auto cnode = node->cast<CNodePtr>();
if (!CheckIfEmbedMetaFGPrim(cnode)) {
prim_nodes_.push_back(cnode);
}
}
}
} // namespace irpass
} // namespace opt
} // namespace mindspore

+ 52
- 0
mindspore/ccsrc/frontend/optimizer/irpass/meta_fg_prim_eliminate.h View File

@@ -0,0 +1,52 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_META_FG_PRIM_ELIMINATE_H_
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_META_FG_PRIM_ELIMINATE_H_

#include <vector>
#include <algorithm>
#include <memory>

#include "frontend/optimizer/optimizer.h"
#include "frontend/optimizer/irpass.h"
#include "frontend/optimizer/anf_visitor.h"
#include "utils/ms_utils.h"
#include "frontend/operator/ops.h"
#include "frontend/optimizer/ad/grad.h"

namespace mindspore {
namespace opt {
namespace irpass {
// Abstract base for passes that expand one kind of meta-fg primitive node,
// e.g. {prim::kPrimJ, C}. A subclass sets prim_ in its constructor and
// implements operator() to perform the actual expansion.
class ExpandMetaFGPrim {
public:
ExpandMetaFGPrim() = default;
virtual ~ExpandMetaFGPrim() = default;
// True if the node wraps a func graph that the manager reports as containing
// J (func_graph_j_total), i.e. the node is not the innermost meta fg.
bool CheckIfEmbedMetaFGPrim(const CNodePtr &node) const;
// Candidate nodes gathered by the most recent GetMetaFGPrim() call.
const std::vector<CNodePtr> &prim_nodes() const { return prim_nodes_; }
// Expands the gathered candidates; returns true when the graph changed.
virtual bool operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) = 0;
// Collects {prim_, ...} CNodes from all_nodes that do not embed another
// meta fg, storing them in prim_nodes_.
void GetMetaFGPrim(const std::vector<AnfNodePtr> &all_nodes);

protected:
std::vector<CNodePtr> prim_nodes_;
// Target meta-fg primitive; nullptr until a subclass constructor sets it.
PrimitivePtr prim_{nullptr};
};
using ExpandMetaFGPrimPtr = std::shared_ptr<ExpandMetaFGPrim>;
} // namespace irpass
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_META_FG_PRIM_ELIMINATE_H_

+ 3
- 3
mindspore/ccsrc/pipeline/jit/pass.cc View File

@@ -48,7 +48,7 @@
#include "pipeline/pynative/pynative_execute.h"
#include "pipeline/jit/static_analysis/auto_monad.h"
#include "frontend/optimizer/irpass/branch_culling.h"
#include "frontend/optimizer/irpass/gradient_eliminate.h"
#include "frontend/optimizer/irpass/meta_fg_eliminate.h"
#include "frontend/optimizer/irpass/parameter_eliminate.h"
#include "frontend/optimizer/irpass/updatestate_eliminate.h"
#if ((defined ENABLE_CPU) && (!defined _WIN32))
@@ -232,7 +232,7 @@ bool parallel_mode() {
void AddParallelRenormalize(OptPassGroupMap *map_a) {
if (parallel_mode()) {
auto parallel_end_opt =
find_if(map_a->begin(), map_a->end(), [](auto opt_pair) { return opt_pair.first == "grad"; });
find_if(map_a->begin(), map_a->end(), [](auto opt_pair) { return opt_pair.first == "meta_fg_expand"; });
if (parallel_end_opt != map_a->end()) {
(void)map_a->insert(parallel_end_opt, {"parallel_renormalize", opt::OptPassConfig::Renormalize()});
}
@@ -357,7 +357,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
{"allreduce_fusion", opt::OptPassConfig(parallel::StepAllreduceFusion)},
{"virtual_dataset", virtual_dataset},
{"virtual_output", opt::OptPassConfig({irpass.virtual_output_eliminate_})},
{"grad", opt::OptPassConfig(opt::irpass::ExpandJPrim())},
{"meta_fg_expand", opt::OptPassConfig(opt::irpass::ExpandMetaFg())},
{"after_resolve", after_resolve_pass},
{"a_after_grad", a_after_grad},
{"renormalize", opt::OptPassConfig::Renormalize()},


+ 4
- 1
mindspore/core/base/core_ops.h View File

@@ -159,6 +159,9 @@ constexpr auto kCSRReduceSum = "CSRReduceSum";
constexpr auto kCSRMV = "CSRMV";
constexpr auto kCSRMul = "CSRMul";

// Meta Function Graph
constexpr auto kJ = "J";

// Others
constexpr auto kMakeTuple = "MakeTuple";
constexpr auto kAssign = "Assign";
@@ -822,7 +825,7 @@ MS_CORE_API inline const PrimitivePtr kPrimPyInterpret = std::make_shared<Primit

// Other primitive not used by backend but used in core;
MS_CORE_API inline const PrimitivePtr kPrimStateSetItem = std::make_shared<Primitive>("state_setitem");
MS_CORE_API inline const PrimitivePtr kPrimJ = std::make_shared<Primitive>("J", kSideEffectPropagate);
MS_CORE_API inline const PrimitivePtr kPrimJ = std::make_shared<Primitive>(kJ, kSideEffectPropagate);
MS_CORE_API inline const PrimitivePtr kPrimShard = std::make_shared<Primitive>("Shard", kSideEffectPropagate);

// Used to build graph which have keyword arguments


Loading…
Cancel
Save