Browse Source

dynamic shape check

tags/v1.1.0
wilfChen 5 years ago
parent
commit
2291b7f2e6
3 changed files with 27 additions and 15 deletions
  1. +0
    -1
      mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel.h
  2. +26
    -13
      mindspore/ccsrc/backend/optimizer/gpu/batch_norm_add_relu_grad_fusion.cc
  3. +1
    -1
      mindspore/ccsrc/backend/session/session_basic.cc

+ 0
- 1
mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel.h View File

@@ -251,7 +251,6 @@ class GpuKernel : public KernelMod {
device::DynamicKernelPtr dynamic_kernel_; device::DynamicKernelPtr dynamic_kernel_;
}; };
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore


+ 26
- 13
mindspore/ccsrc/backend/optimizer/gpu/batch_norm_add_relu_grad_fusion.cc View File

@@ -110,44 +110,57 @@ void ReplaceOutput(const FuncGraphPtr &graph, const AnfNodePtr &bn_grad, const A
manager->Replace(relu_grad, bn_add_relu_grad_output[kBNAddReluGradOutputNum - 1]); manager->Replace(relu_grad, bn_add_relu_grad_output[kBNAddReluGradOutputNum - 1]);
return; return;
} }
} // namespace
const BaseRef BatchNormAddReluGradFusion::DefinePattern() const {
VectorRef relu_grad = VectorRef({prim::kPrimReluGrad, dy_, y_});
VectorRef batch_norm_grad =
VectorRef({prim::kPrimFusedBatchNormGradEx, relu_grad, x_, scale_, save_mean_, save_var_, reserve_});
return batch_norm_grad;
}
const AnfNodePtr BatchNormAddReluGradFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
const EquivPtr &) const {
bool PatternCheck(const FuncGraphPtr &graph, const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
auto format_attr = AnfAlgo::GetCNodePrimitive(node)->GetAttr("data_format"); auto format_attr = AnfAlgo::GetCNodePrimitive(node)->GetAttr("data_format");
MS_EXCEPTION_IF_NULL(format_attr); MS_EXCEPTION_IF_NULL(format_attr);
auto format = GetValue<std::string>(format_attr); auto format = GetValue<std::string>(format_attr);
if (AnfAlgo::GetInputFormat(node, 0) != kOpFormat_NHWC && format != "NHWC") { if (AnfAlgo::GetInputFormat(node, 0) != kOpFormat_NHWC && format != "NHWC") {
return nullptr;
return false;
} }
auto relu_grad = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 0); auto relu_grad = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 0);
MS_EXCEPTION_IF_NULL(relu_grad); MS_EXCEPTION_IF_NULL(relu_grad);
auto relu_users = GetRealNodeUsedList(graph, relu_grad); auto relu_users = GetRealNodeUsedList(graph, relu_grad);
if (relu_users->size() != 2) { if (relu_users->size() != 2) {
return nullptr;
return false;
} }
// process pattern as Relu(TensorAdd(BN#0, BN#1)) // process pattern as Relu(TensorAdd(BN#0, BN#1))
auto tuple_getitem = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 5); auto tuple_getitem = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 5);
MS_EXCEPTION_IF_NULL(tuple_getitem); MS_EXCEPTION_IF_NULL(tuple_getitem);
if (!utils::isa<CNodePtr>(tuple_getitem) || AnfAlgo::GetCNodeName(tuple_getitem) != prim::kPrimTupleGetItem->name()) { if (!utils::isa<CNodePtr>(tuple_getitem) || AnfAlgo::GetCNodeName(tuple_getitem) != prim::kPrimTupleGetItem->name()) {
return nullptr;
return false;
} }
auto forward_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(tuple_getitem), 0); auto forward_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(tuple_getitem), 0);
if (AnfAlgo::GetCNodeName(forward_node) != kFusedBatchNormExWithAddAndActivation) { if (AnfAlgo::GetCNodeName(forward_node) != kFusedBatchNormExWithAddAndActivation) {
return false;
}
return true;
}
} // namespace
const BaseRef BatchNormAddReluGradFusion::DefinePattern() const {
VectorRef relu_grad = VectorRef({prim::kPrimReluGrad, dy_, y_});
VectorRef batch_norm_grad =
VectorRef({prim::kPrimFusedBatchNormGradEx, relu_grad, x_, scale_, save_mean_, save_var_, reserve_});
return batch_norm_grad;
}
const AnfNodePtr BatchNormAddReluGradFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
const EquivPtr &) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node);
if (!PatternCheck(graph, node)) {
return nullptr; return nullptr;
} }
auto relu_grad = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 0);
MS_EXCEPTION_IF_NULL(relu_grad);
auto dy = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(relu_grad), 0); auto dy = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(relu_grad), 0);
MS_EXCEPTION_IF_NULL(dy); MS_EXCEPTION_IF_NULL(dy);
auto y = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(relu_grad), 1); auto y = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(relu_grad), 1);


+ 1
- 1
mindspore/ccsrc/backend/session/session_basic.cc View File

@@ -1432,7 +1432,7 @@ void SessionBasic::RunGraphAsync(const GraphId &graph_id, const std::vector<tens
} }


bool IsDynamicShape(const NotNull<abstract::ShapePtr> &shape) { bool IsDynamicShape(const NotNull<abstract::ShapePtr> &shape) {
return std::any_of(shape->shape().begin(), shape->shape().end(), [](int s) { return s < 0; });
return std::any_of(shape->shape().begin(), shape->shape().end(), [](int64_t s) { return s < 0; });
} }


bool IsNodeOutputDynamicShape(const CNodePtr &anf_node_ptr) { bool IsNodeOutputDynamicShape(const CNodePtr &anf_node_ptr) {


Loading…
Cancel
Save