Browse Source

[ms][pynative][lenet] fix training bug
pull/1/head
xiongkun 4 years ago
parent
commit
c4f7aad9db
3 changed files with 51 additions and 8 deletions
  1. +1
    -1
      mindspore/python/mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py
  2. +15
    -0
      tests/st/quantization/lenet_quant/test_lenet_quant.py
  3. +35
    -7
      tests/st/quantization/mobilenetv2_quant/test_mobilenetv2_quant.py

+ 1
- 1
mindspore/python/mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py View File

@@ -144,7 +144,7 @@ def fake_quant_perchannel_grad_param(x, min_val, max_val, channel_axis,

shape_c = [1] * len(x_shape)
shape_c[channel_axis_] = min_val.get("ori_shape")[0]
if x_format == "NC1HWC0" and channel_axis_ == 1:
if shape_c[channel_axis_] != x_shape[channel_axis_]:
shape_c = min_val.get("shape")
return x_shape, shape_c, x_dtype



+ 15
- 0
tests/st/quantization/lenet_quant/test_lenet_quant.py View File

@@ -182,3 +182,18 @@ def test_lenet_quant_ascend():
train_lenet_quant(optim_option="LEARNED_SCALE")
eval_quant(optim_option="LEARNED_SCALE")
export_lenet(optim_option="LEARNED_SCALE", file_format="AIR")


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_lenet_quant_ascend_pynative():
    """
    Smoke-test LeNet quantization-aware training on Ascend in PyNative mode.

    Features: test_lenet_quant_ascend_pynative
    Description: runs QAT training with the framework switched to PYNATIVE_MODE
    Expectation: None (training completes without raising)
    """
    # Select PyNative execution before any network/optimizer is built.
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    train_lenet_quant(optim_option="QAT")

+ 35
- 7
tests/st/quantization/mobilenetv2_quant/test_mobilenetv2_quant.py View File

@@ -55,13 +55,8 @@ config_ascend_quant = ed({
dataset_path = "/home/workspace/mindspore_dataset/cifar-10-batches-bin/"


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_single
def test_mobilenetv2_quant():
set_seed(1)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
def train():
"""train"""
config = config_ascend_quant
print("training configure: {}".format(config))

@@ -121,5 +116,38 @@ def test_mobilenetv2_quant():
assert avg_step_loss < expect_avg_step_loss


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_single
def test_mobilenetv2_quant():
    """
    Smoke-test MobileNetV2 quantization-aware training on Ascend in graph mode.

    Features: test_mobilenetv2_quant
    Description: runs the shared train() routine under GRAPH_MODE
    Expectation: None (train() performs its own loss assertions)
    """
    # Fix the RNG seed so the loss trajectory is reproducible.
    set_seed(1)
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    train()


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_single
def test_mobilenetv2_quant_pynative():
    """
    Smoke-test MobileNetV2 quantization-aware training on Ascend in PyNative mode.

    Features: test_mobilenetv2_quant_pynative
    Description: runs the shared train() routine under PYNATIVE_MODE
    Expectation: None (train() performs its own loss assertions)
    """
    # Fix the RNG seed so the loss trajectory is reproducible.
    set_seed(1)
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    train()


if __name__ == '__main__':
    # Allow running both quantization smoke tests directly, without pytest.
    test_mobilenetv2_quant()
    test_mobilenetv2_quant_pynative()

Loading…
Cancel
Save