Browse Source

add mobilenet ascend st

tags/v1.1.0
xiaoyisd 5 years ago
parent
commit
c5ba1d49d9
2 changed files with 9 additions and 5 deletions
  1. +6
    -3
      tests/st/quantization/mobilenetv2_quant/test_mobilenetv2_quant.py
  2. +3
    -2
      tests/st/quantization/mobilenetv2_quant/utils.py

+ 6
- 3
tests/st/quantization/mobilenetv2_quant/test_mobilenetv2_quant.py View File

@@ -55,10 +55,10 @@ config_ascend_quant = ed({
dataset_path = "/home/workspace/mindspore_dataset/cifar-10-batches-bin/"


@pytest.mark.level1
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.env_single
def test_mobilenetv2_quant():
set_seed(1)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
@@ -111,9 +111,12 @@ def test_mobilenetv2_quant():
dataset_sink_mode=False)
print("============== End Training ==============")

export_time_used = 700
train_time = monitor.step_mseconds
print('train_time_used:{}'.format(train_time))
assert train_time < export_time_used
expect_avg_step_loss = 2.32
avg_step_loss = np.mean(np.array(monitor.losses))

print("average step loss:{}".format(avg_step_loss))
assert avg_step_loss < expect_avg_step_loss



+ 3
- 2
tests/st/quantization/mobilenetv2_quant/utils.py View File

@@ -45,7 +45,7 @@ class Monitor(Callback):
self.lr_init = lr_init
self.lr_init_len = len(lr_init)
self.step_threshold = step_threshold
self.step_mseconds = 0
self.step_mseconds = 50000

def epoch_begin(self, run_context):
self.losses = []
@@ -66,7 +66,8 @@ class Monitor(Callback):

def step_end(self, run_context):
cb_params = run_context.original_args()
self.step_mseconds = (time.time() - self.step_time) * 1000
step_mseconds = (time.time() - self.step_time) * 1000
self.step_mseconds = min(self.step_mseconds, step_mseconds)
step_loss = cb_params.net_outputs

if isinstance(step_loss, (tuple, list)) and isinstance(step_loss[0], Tensor):


Loading…
Cancel
Save