Browse Source

!6532 GPU fix BnCast

Merge pull request !6532 from VectorSL/bncast
tags/v1.0.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
149285b6f2
4 changed files with 9 additions and 5 deletions
  1. +1
    -1
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_gpu_kernel.h
  2. +1
    -1
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_grad_gpu_kernel.h
  3. +4
    -1
      mindspore/ccsrc/backend/optimizer/gpu/replace_bn_cast_fusion.cc
  4. +3
    -2
      mindspore/ccsrc/backend/optimizer/gpu/replace_bn_grad_cast_fusion.cc

+ 1
- 1
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_gpu_kernel.h View File

@@ -144,7 +144,7 @@ class PoolingGpuFwdKernel : public GpuKernel {
void SetPoolingMode(const CNodePtr &kernel_node) {
mode_ = AnfAlgo::GetCNodeName(kernel_node);
if (mode_ == "AvgPool") {
pooling_mode_ = CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING;
pooling_mode_ = CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING;
pad_value_ = 0.0;
} else {
pooling_mode_ = CUDNN_POOLING_MAX;


+ 1
- 1
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_grad_gpu_kernel.h View File

@@ -207,7 +207,7 @@ class PoolingGradGpuKernel : public GpuKernel {
void SetPoolingMode(const CNodePtr &kernel_node) {
mode_ = AnfAlgo::GetCNodeName(kernel_node);
if (mode_ == "AvgPoolGradGpu") {
pooling_mode_ = CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING;
pooling_mode_ = CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING;
pad_value_ = 0.0;
} else {
pooling_mode_ = CUDNN_POOLING_MAX;


+ 4
- 1
mindspore/ccsrc/backend/optimizer/gpu/replace_bn_cast_fusion.cc View File

@@ -37,13 +37,16 @@ const AnfNodePtr ReplaceBNCastFusion::Process(const FuncGraphPtr &graph, const A
const EquivPtr &equiv) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(equiv);
auto fbn2 = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 0);
auto x_after = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(fbn2), 0);
auto x_before = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(x_after), 0);
MS_EXCEPTION_IF_NULL(fbn2);
MS_EXCEPTION_IF_NULL(x_after);
MS_EXCEPTION_IF_NULL(x_before);
// only deal with x_after with fp32: x 16->32->bn->16->32
if (AnfAlgo::GetOutputInferDataType(x_after, 0) == kNumberTypeFloat16) {
return nullptr;
}
std::vector<TypeId> outputs_type;
std::vector<std::vector<size_t>> outputs_shape;
auto manager = graph->manager();


+ 3
- 2
mindspore/ccsrc/backend/optimizer/gpu/replace_bn_grad_cast_fusion.cc View File

@@ -68,8 +68,9 @@ const AnfNodePtr ReplaceBNGradCastFusion::Process(const FuncGraphPtr &graph, con
auto dy_before = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(dy_after), 0);
auto x_ = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(fbn2g), 1);
MS_EXCEPTION_IF_NULL(x_);
// if x_type is fp32, the cast is necessary.
if (AnfAlgo::GetOutputInferDataType(x_, 0) == kNumberTypeFloat32) {
// skip the fusion if x_type is fp32 (the cast is necessary) or if dy_after is fp16; the handled pattern has dy_after in fp32: dy 16->32->bng->16->32.
if (AnfAlgo::GetOutputInferDataType(x_, 0) == kNumberTypeFloat32 ||
AnfAlgo::GetOutputInferDataType(dy_after, 0) == kNumberTypeFloat16) {
return nullptr;
}
MS_EXCEPTION_IF_NULL(fbn2g);


Loading…
Cancel
Save