Browse Source

fix bn cast

tags/v1.0.0
VectorSL 5 years ago
parent
commit
50dc89332c
4 changed files with 9 additions and 5 deletions
  1. +1
    -1
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_gpu_kernel.h
  2. +1
    -1
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_grad_gpu_kernel.h
  3. +4
    -1
      mindspore/ccsrc/backend/optimizer/gpu/replace_bn_cast_fusion.cc
  4. +3
    -2
      mindspore/ccsrc/backend/optimizer/gpu/replace_bn_grad_cast_fusion.cc

+ 1
- 1
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_gpu_kernel.h View File

@@ -144,7 +144,7 @@ class PoolingGpuFwdKernel : public GpuKernel {
void SetPoolingMode(const CNodePtr &kernel_node) {
mode_ = AnfAlgo::GetCNodeName(kernel_node);
if (mode_ == "AvgPool") {
pooling_mode_ = CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING;
pooling_mode_ = CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING;
pad_value_ = 0.0;
} else {
pooling_mode_ = CUDNN_POOLING_MAX;


+ 1
- 1
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_grad_gpu_kernel.h View File

@@ -207,7 +207,7 @@ class PoolingGradGpuKernel : public GpuKernel {
cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)));
mode_ = AnfAlgo::GetCNodeName(kernel_node);
if (mode_ == "AvgPoolGradGpu") {
pooling_mode_ = CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING;
pooling_mode_ = CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING;
pad_value_ = 0.0;
} else {
pooling_mode_ = CUDNN_POOLING_MAX;


+ 4
- 1
mindspore/ccsrc/backend/optimizer/gpu/replace_bn_cast_fusion.cc View File

@@ -37,13 +37,16 @@ const AnfNodePtr ReplaceBNCastFusion::Process(const FuncGraphPtr &graph, const A
const EquivPtr &equiv) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(equiv);
auto fbn2 = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 0);
auto x_after = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(fbn2), 0);
auto x_before = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(x_after), 0);
MS_EXCEPTION_IF_NULL(fbn2);
MS_EXCEPTION_IF_NULL(x_after);
MS_EXCEPTION_IF_NULL(x_before);
// only deal with x_after with fp32: x 16->32->bn->16->32
if (AnfAlgo::GetOutputInferDataType(x_after, 0) == kNumberTypeFloat16) {
return nullptr;
}
std::vector<TypeId> outputs_type;
std::vector<std::vector<size_t>> outputs_shape;
auto manager = graph->manager();


+ 3
- 2
mindspore/ccsrc/backend/optimizer/gpu/replace_bn_grad_cast_fusion.cc View File

@@ -68,8 +68,9 @@ const AnfNodePtr ReplaceBNGradCastFusion::Process(const FuncGraphPtr &graph, con
auto dy_before = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(dy_after), 0);
auto x_ = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(fbn2g), 1);
MS_EXCEPTION_IF_NULL(x_);
// if x_type is fp32, the cast is necessary.
if (AnfAlgo::GetOutputInferDataType(x_, 0) == kNumberTypeFloat32) {
// if x_type is fp32 the cast is necessary; also skip when dy_after is fp16 (only handle dy_after in fp32: dy 16->32->bng->16->32).
if (AnfAlgo::GetOutputInferDataType(x_, 0) == kNumberTypeFloat32 ||
AnfAlgo::GetOutputInferDataType(dy_after, 0) == kNumberTypeFloat16) {
return nullptr;
}
MS_EXCEPTION_IF_NULL(fbn2g);


Loading…
Cancel
Save