Browse Source

add global batch normalization

tags/v0.2.0-alpha
zhaojichen 6 years ago
parent
commit
c5120e770c
1 changed files with 0 additions and 8 deletions
  1. +0
    -8
      mindspore/nn/layer/normalization.py

+ 0
- 8
mindspore/nn/layer/normalization.py View File

@@ -27,14 +27,6 @@ from mindspore.communication import management
from mindspore._checkparam import check_int_positive
from ..cell import Cell

class _GlobalBNHelper(Cell):
def __init__(self, group):
super(_GlobalBNHelper, self).__init__()
self.group = group
self.reduce = P.AllReduce(P.ReduceOp.SUM, group=self.group).add_prim_attr('fusion', 1)
def construct(self, x):
x = self.reduce(x)
return x

class _BatchNorm(Cell):
"""Batch Normalization base class."""


Loading…
Cancel
Save