Browse Source

add global batch normalization

tags/v0.2.0-alpha
zhaojichen 6 years ago
parent
commit
97e250d4f1
2 changed files with 0 additions and 16 deletions
  1. +0
    -13
      tests/ut/python/hccl_test/manage/api.py
  2. +0
    -3
      tests/ut/python/nn/test_batchnorm.py

+ 0
- 13
tests/ut/python/hccl_test/manage/api.py View File

@@ -21,7 +21,6 @@ class Hccl():
_instance = None
_rank_id = 0
_rank_size = 1
_group_size = 4

def __init__(self):
pass
@@ -48,10 +47,6 @@ class Hccl():
def rank_size(self):
return self._rank_size

@property
def group_size(self):
return self._group_size

@rank_size.setter
def rank_size(self, size):
self._rank_size = size
@@ -70,14 +65,6 @@ def get_rank_size(group=None):
return int(group.split("-")[0])
raise ValueError

def get_group_size(group=None):
hccl = Hccl()
if group is None:
return hccl.group_size
if isinstance(group, str):
return int(group.split("-")[0])
raise ValueError

# pylint: disable=unused-argument
def get_world_rank_from_group_rank(group, group_rank_id):
return group_rank_id


+ 0
- 3
tests/ut/python/nn/test_batchnorm.py View File

@@ -19,9 +19,6 @@ import pytest
import mindspore.nn as nn
from mindspore.common.api import _executor
from mindspore import Tensor, Parameter
from mindspore.communication.management import init
from mindspore import context
from mindspore import ParallelMode


def test_bn_pars_valid1():


Loading…
Cancel
Save