Browse Source

!9428 [GraphKernel] Modify cast handling to accelerate AMP performance when graph kernel is enabled

From: @tronzhang
Reviewed-by: @ryanww,@gaoxiong1
Signed-off-by: @gaoxiong1
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
a9e586dbe7
4 changed files with 28 additions and 13 deletions
  1. +7
    -4
      mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_cse.cc
  2. +13
    -5
      mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_cse.h
  3. +1
    -1
      mindspore/ccsrc/backend/optimizer/graph_kernel/shape_ops_splitter.cc
  4. +7
    -3
      mindspore/ccsrc/backend/session/gpu_session.cc

+ 7
- 4
mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_cse.cc View File

@@ -16,6 +16,7 @@


#include "backend/optimizer/graph_kernel/graph_kernel_cse.h" #include "backend/optimizer/graph_kernel/graph_kernel_cse.h"


#include <algorithm>
#include <memory> #include <memory>
#include <string> #include <string>
#include <utility> #include <utility>
@@ -26,13 +27,15 @@
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
namespace { namespace {
bool IsCNodePrimitveEqual(const CNodePtr &main, const CNodePtr &node) {
bool IsCNodePrimitveEqual(const CNodePtr &main, const CNodePtr &node, const std::vector<PrimitivePtr> &black_list) {
auto main_primitive = AnfAlgo::GetCNodePrimitive(main); auto main_primitive = AnfAlgo::GetCNodePrimitive(main);
auto node_primitive = AnfAlgo::GetCNodePrimitive(node); auto node_primitive = AnfAlgo::GetCNodePrimitive(node);
if (main_primitive != nullptr && node_primitive != nullptr) { if (main_primitive != nullptr && node_primitive != nullptr) {
// Some ops such as Reshape is not real op, cse these type will not get gain. And for ops fusion, keep these op // Some ops such as Reshape is not real op, cse these type will not get gain. And for ops fusion, keep these op
// alone can prevent some redundant output case (input -> reshape -> output). // alone can prevent some redundant output case (input -> reshape -> output).
if (main_primitive->name() != node_primitive->name() || IsPrimitiveCNode(node, prim::kPrimReshape)) {
if (main_primitive->name() != node_primitive->name() ||
std::any_of(black_list.begin(), black_list.end(),
[&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); })) {
return false; return false;
} }


@@ -125,12 +128,12 @@ bool GraphKernelBackendCSE::CheckEqualCnodeInputs(const AnfNodePtr &main, const
return false; return false;
} }
} }
return IsCNodePrimitveEqual(c_main, c_node);
return IsCNodePrimitveEqual(c_main, c_node, black_list_);
} }


bool GraphKernelCSE::Run(const FuncGraphPtr &func_graph) { bool GraphKernelCSE::Run(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(func_graph);
auto graphkernel_backend_cse = std::make_shared<GraphKernelBackendCSE>();
auto graphkernel_backend_cse = std::make_shared<GraphKernelBackendCSE>(black_list_);
return graphkernel_backend_cse->Cse(func_graph, func_graph->manager()); return graphkernel_backend_cse->Cse(func_graph, func_graph->manager());
} }
} // namespace opt } // namespace opt


+ 13
- 5
mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_cse.h View File

@@ -13,27 +13,35 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CSE_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CSE_H_
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_CSE_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_CSE_H_


#include <vector>
#include "backend/optimizer/pass/common_subexpression_elimination.h" #include "backend/optimizer/pass/common_subexpression_elimination.h"


namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
class GraphKernelCSE : public Pass { class GraphKernelCSE : public Pass {
public: public:
GraphKernelCSE() : Pass("graph_kernel_cse") {}
explicit GraphKernelCSE(const std::vector<PrimitivePtr> &black_list = {})
: Pass("graph_kernel_cse"), black_list_(black_list) {}
~GraphKernelCSE() override = default; ~GraphKernelCSE() override = default;
bool Run(const FuncGraphPtr &func_graph) override; bool Run(const FuncGraphPtr &func_graph) override;

private:
std::vector<PrimitivePtr> black_list_;
}; };


class GraphKernelBackendCSE : public BackendCSE { class GraphKernelBackendCSE : public BackendCSE {
public: public:
GraphKernelBackendCSE() = default;
explicit GraphKernelBackendCSE(const std::vector<PrimitivePtr> &black_list = {}) : black_list_(black_list) {}
~GraphKernelBackendCSE() override = default; ~GraphKernelBackendCSE() override = default;
bool CheckEqualKernelBuildInfo(const AnfNodePtr &main, const AnfNodePtr &node) const override; bool CheckEqualKernelBuildInfo(const AnfNodePtr &main, const AnfNodePtr &node) const override;
bool CheckEqualCnodeInputs(const AnfNodePtr &main, const AnfNodePtr &node) const override; bool CheckEqualCnodeInputs(const AnfNodePtr &main, const AnfNodePtr &node) const override;

private:
std::vector<PrimitivePtr> black_list_;
}; };
} // namespace opt } // namespace opt
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CSE_H_
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_CSE_H_

+ 1
- 1
mindspore/ccsrc/backend/optimizer/graph_kernel/shape_ops_splitter.cc View File

@@ -34,7 +34,7 @@ namespace mindspore {
namespace opt { namespace opt {
namespace { namespace {
bool IsMultiUserShapeOps(AnfNodePtr node, const FuncGraphManagerPtr &mng) { bool IsMultiUserShapeOps(AnfNodePtr node, const FuncGraphManagerPtr &mng) {
std::vector<PrimitivePtr> shape_ops = {prim::kPrimReshape};
std::vector<PrimitivePtr> shape_ops = {prim::kPrimReshape, prim::kPrimCast};
auto &users = mng->node_users(); auto &users = mng->node_users();
return std::any_of(shape_ops.begin(), shape_ops.end(), return std::any_of(shape_ops.begin(), shape_ops.end(),
[&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); }) && [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); }) &&


+ 7
- 3
mindspore/ccsrc/backend/session/gpu_session.cc View File

@@ -120,7 +120,9 @@ void GPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
pm->AddPass(std::make_shared<opt::AdamFusion>()); pm->AddPass(std::make_shared<opt::AdamFusion>());
pm->AddPass(std::make_shared<opt::ApplyMomentumWeightDecayScaleFusion>()); pm->AddPass(std::make_shared<opt::ApplyMomentumWeightDecayScaleFusion>());
pm->AddPass(std::make_shared<opt::ApplyMomentumScaleFusion>()); pm->AddPass(std::make_shared<opt::ApplyMomentumScaleFusion>());
pm->AddPass(std::make_shared<opt::CastAllFusion>("cast_all"));
if (!(context_ptr->get_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL))) {
pm->AddPass(std::make_shared<opt::CastAllFusion>("cast_all"));
}
pm->AddPass(std::make_shared<opt::CombineMomentumFusion>("combine_momentum")); pm->AddPass(std::make_shared<opt::CombineMomentumFusion>("combine_momentum"));
pm->AddPass(std::make_shared<opt::ReplaceMomentumCastFusion>()); pm->AddPass(std::make_shared<opt::ReplaceMomentumCastFusion>());
pm->AddPass(std::make_shared<opt::ReplaceAddNFusion>()); pm->AddPass(std::make_shared<opt::ReplaceAddNFusion>());
@@ -165,15 +167,17 @@ void GPUSession::GraphKernelOptimize(const std::shared_ptr<KernelGraph> &kernel_
} }
auto optimizer = std::make_shared<opt::GraphOptimizer>(); auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>("graph_kernel_pm"); auto pm = std::make_shared<opt::PassManager>("graph_kernel_pm");
std::vector<PrimitivePtr> black_list = {prim::kPrimReshape, prim::kPrimCast};
pm->AddPass(std::make_shared<opt::GraphKernelExpander>()); pm->AddPass(std::make_shared<opt::GraphKernelExpander>());
pm->AddPass(std::make_shared<opt::ShapeOpsSplitter>()); pm->AddPass(std::make_shared<opt::ShapeOpsSplitter>());
pm->AddPass(std::make_shared<opt::BasicOpsFusion>()); pm->AddPass(std::make_shared<opt::BasicOpsFusion>());
pm->AddPass(std::make_shared<opt::CompositeOpsFusion>()); pm->AddPass(std::make_shared<opt::CompositeOpsFusion>());
pm->AddPass(std::make_shared<opt::GraphKernelCSE>());
pm->AddPass(std::make_shared<opt::GraphKernelCSE>(black_list));
pm->AddPass(std::make_shared<opt::ArithmeticSimplify>()); pm->AddPass(std::make_shared<opt::ArithmeticSimplify>());
pm->AddPass(std::make_shared<opt::GraphKernelCSE>());
pm->AddPass(std::make_shared<opt::GraphKernelCSE>(black_list));
pm->AddPass(std::make_shared<opt::TensorPromotion>()); pm->AddPass(std::make_shared<opt::TensorPromotion>());
pm->AddPass(std::make_shared<opt::GraphKernelSplitter>()); pm->AddPass(std::make_shared<opt::GraphKernelSplitter>());
pm->AddPass(std::make_shared<opt::GraphKernelCSE>());
// After Simplify and Splitter, a lot of redundant getitem/maketuple // After Simplify and Splitter, a lot of redundant getitem/maketuple
// will be exposed, use GetitemTuple Pass to delete them. // will be exposed, use GetitemTuple Pass to delete them.
pm->AddPass(std::make_shared<opt::GetitemTuple>()); pm->AddPass(std::make_shared<opt::GetitemTuple>());


Loading…
Cancel
Save