Merge pull request !14 from zheng-huanhuan/2_mastertags/v0.2.0-alpha
| @@ -0,0 +1,118 @@ | |||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| import sys | |||
| import time | |||
| import numpy as np | |||
| import pytest | |||
| from scipy.special import softmax | |||
| from mindspore import Model | |||
| from mindspore import Tensor | |||
| from mindspore import context | |||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
| from mindarmour.attacks.iterative_gradient_method import MomentumDiverseInputIterativeMethod | |||
| from mindarmour.utils.logger import LogUtil | |||
| from mindarmour.evaluations.attack_evaluation import AttackEvaluate | |||
| from lenet5_net import LeNet5 | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
| sys.path.append("..") | |||
| from data_processing import generate_mnist_dataset | |||
| LOGGER = LogUtil.get_instance() | |||
| TAG = 'M_DI2_FGSM_Test' | |||
| LOGGER.set_level('INFO') | |||
| @pytest.mark.level1 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.env_card | |||
| @pytest.mark.component_mindarmour | |||
| def test_momentum_diverse_input_iterative_method(): | |||
| """ | |||
| M-DI2-FGSM Attack Test | |||
| """ | |||
| # upload trained network | |||
| ckpt_name = './trained_ckpt_file/checkpoint_lenet-10_1875.ckpt' | |||
| net = LeNet5() | |||
| load_dict = load_checkpoint(ckpt_name) | |||
| load_param_into_net(net, load_dict) | |||
| # get test data | |||
| data_list = "./MNIST_unzip/test" | |||
| batch_size = 32 | |||
| ds = generate_mnist_dataset(data_list, batch_size, sparse=False) | |||
| # prediction accuracy before attack | |||
| model = Model(net) | |||
| batch_num = 32 # the number of batches of attacking samples | |||
| test_images = [] | |||
| test_labels = [] | |||
| predict_labels = [] | |||
| i = 0 | |||
| for data in ds.create_tuple_iterator(): | |||
| i += 1 | |||
| images = data[0].astype(np.float32) | |||
| labels = data[1] | |||
| test_images.append(images) | |||
| test_labels.append(labels) | |||
| pred_labels = np.argmax(model.predict(Tensor(images)).asnumpy(), | |||
| axis=1) | |||
| predict_labels.append(pred_labels) | |||
| if i >= batch_num: | |||
| break | |||
| predict_labels = np.concatenate(predict_labels) | |||
| true_labels = np.argmax(np.concatenate(test_labels), axis=1) | |||
| accuracy = np.mean(np.equal(predict_labels, true_labels)) | |||
| LOGGER.info(TAG, "prediction accuracy before attacking is : %s", accuracy) | |||
| # attacking | |||
| attack = MomentumDiverseInputIterativeMethod(net) | |||
| start_time = time.clock() | |||
| adv_data = attack.batch_generate(np.concatenate(test_images), | |||
| np.concatenate(test_labels), batch_size=32) | |||
| stop_time = time.clock() | |||
| pred_logits_adv = model.predict(Tensor(adv_data)).asnumpy() | |||
| # rescale predict confidences into (0, 1). | |||
| pred_logits_adv = softmax(pred_logits_adv, axis=1) | |||
| pred_labels_adv = np.argmax(pred_logits_adv, axis=1) | |||
| accuracy_adv = np.mean(np.equal(pred_labels_adv, true_labels)) | |||
| LOGGER.info(TAG, "prediction accuracy after attacking is : %s", accuracy_adv) | |||
| attack_evaluate = AttackEvaluate(np.concatenate(test_images).transpose(0, 2, 3, 1), | |||
| np.concatenate(test_labels), | |||
| adv_data.transpose(0, 2, 3, 1), | |||
| pred_logits_adv) | |||
| LOGGER.info(TAG, 'mis-classification rate of adversaries is : %s', | |||
| attack_evaluate.mis_classification_rate()) | |||
| LOGGER.info(TAG, 'The average confidence of adversarial class is : %s', | |||
| attack_evaluate.avg_conf_adv_class()) | |||
| LOGGER.info(TAG, 'The average confidence of true class is : %s', | |||
| attack_evaluate.avg_conf_true_class()) | |||
| LOGGER.info(TAG, 'The average distance (l0, l2, linf) between original ' | |||
| 'samples and adversarial samples are: %s', | |||
| attack_evaluate.avg_lp_distance()) | |||
| LOGGER.info(TAG, 'The average structural similarity between original ' | |||
| 'samples and adversarial samples are: %s', | |||
| attack_evaluate.avg_ssim()) | |||
| LOGGER.info(TAG, 'The average costing time is %s', | |||
| (stop_time - start_time)/(batch_num*batch_size)) | |||
| if __name__ == '__main__': | |||
| test_momentum_diverse_input_iterative_method() | |||
| @@ -26,6 +26,8 @@ __all__ = ['FastGradientMethod', | |||
| 'BasicIterativeMethod', | |||
| 'MomentumIterativeMethod', | |||
| 'ProjectedGradientDescent', | |||
| 'DiverseInputIterativeMethod', | |||
| 'MomentumDiverseInputIterativeMethod', | |||
| 'DeepFool', | |||
| 'CarliniWagnerL2Attack', | |||
| 'JSMAAttack', | |||
| @@ -46,7 +46,7 @@ class GradientMethod(Attack): | |||
| Default: None. | |||
| bounds (tuple): Upper and lower bounds of data, indicating the data range. | |||
| In form of (clip_min, clip_max). Default: None. | |||
| loss_fn (Loss): Loss function for optimization. | |||
| loss_fn (Loss): Loss function for optimization. Default: None. | |||
| """ | |||
| def __init__(self, network, eps=0.07, alpha=None, bounds=None, | |||
| @@ -151,7 +151,7 @@ class FastGradientMethod(GradientMethod): | |||
| Possible values: np.inf, 1 or 2. Default: 2. | |||
| is_targeted (bool): If True, targeted attack. If False, untargeted | |||
| attack. Default: False. | |||
| loss_fn (Loss): Loss function for optimization. | |||
| loss_fn (Loss): Loss function for optimization. Default: None. | |||
| Examples: | |||
| >>> attack = FastGradientMethod(network) | |||
| @@ -214,7 +214,7 @@ class RandomFastGradientMethod(FastGradientMethod): | |||
| Possible values: np.inf, 1 or 2. Default: 2. | |||
| is_targeted (bool): If True, targeted attack. If False, untargeted | |||
| attack. Default: False. | |||
| loss_fn (Loss): Loss function for optimization. | |||
| loss_fn (Loss): Loss function for optimization. Default: None. | |||
| Raises: | |||
| ValueError: eps is smaller than alpha! | |||
| @@ -255,7 +255,7 @@ class FastGradientSignMethod(GradientMethod): | |||
| In form of (clip_min, clip_max). Default: (0.0, 1.0). | |||
| is_targeted (bool): If True, targeted attack. If False, untargeted | |||
| attack. Default: False. | |||
| loss_fn (Loss): Loss function for optimization. | |||
| loss_fn (Loss): Loss function for optimization. Default: None. | |||
| Examples: | |||
| >>> attack = FastGradientSignMethod(network) | |||
| @@ -314,7 +314,7 @@ class RandomFastGradientSignMethod(FastGradientSignMethod): | |||
| In form of (clip_min, clip_max). Default: (0.0, 1.0). | |||
| is_targeted (bool): True: targeted attack. False: untargeted attack. | |||
| Default: False. | |||
| loss_fn (Loss): Loss function for optimization. | |||
| loss_fn (Loss): Loss function for optimization. Default: None. | |||
| Raises: | |||
| ValueError: eps is smaller than alpha! | |||
| @@ -350,7 +350,7 @@ class LeastLikelyClassMethod(FastGradientSignMethod): | |||
| Default: None. | |||
| bounds (tuple): Upper and lower bounds of data, indicating the data range. | |||
| In form of (clip_min, clip_max). Default: (0.0, 1.0). | |||
| loss_fn (Loss): Loss function for optimization. | |||
| loss_fn (Loss): Loss function for optimization. Default: None. | |||
| Examples: | |||
| >>> attack = LeastLikelyClassMethod(network) | |||
| @@ -15,6 +15,7 @@ | |||
| from abc import abstractmethod | |||
| import numpy as np | |||
| from PIL import Image, ImageOps | |||
| from mindspore.nn import SoftmaxCrossEntropyWithLogits | |||
| from mindspore import Tensor | |||
| @@ -115,7 +116,7 @@ class IterativeGradientMethod(Attack): | |||
| bounds (tuple): Upper and lower bounds of data, indicating the data range. | |||
| In form of (clip_min, clip_max). Default: (0.0, 1.0). | |||
| nb_iter (int): Number of iteration. Default: 5. | |||
| loss_fn (Loss): Loss function for optimization. | |||
| loss_fn (Loss): Loss function for optimization. Default: None. | |||
| """ | |||
| def __init__(self, network, eps=0.3, eps_iter=0.1, bounds=(0.0, 1.0), nb_iter=5, | |||
| loss_fn=None): | |||
| @@ -178,14 +179,13 @@ class BasicIterativeMethod(IterativeGradientMethod): | |||
| is_targeted (bool): If True, targeted attack. If False, untargeted | |||
| attack. Default: False. | |||
| nb_iter (int): Number of iteration. Default: 5. | |||
| loss_fn (Loss): Loss function for optimization. | |||
| loss_fn (Loss): Loss function for optimization. Default: None. | |||
| attack (class): The single step gradient method of each iteration. In | |||
| this class, FGSM is used. | |||
| Examples: | |||
| >>> attack = BasicIterativeMethod(network) | |||
| """ | |||
| def __init__(self, network, eps=0.3, eps_iter=0.1, bounds=(0.0, 1.0), | |||
| is_targeted=False, nb_iter=5, loss_fn=None): | |||
| super(BasicIterativeMethod, self).__init__(network, | |||
| @@ -227,14 +227,22 @@ class BasicIterativeMethod(IterativeGradientMethod): | |||
| clip_min, clip_max = self._bounds | |||
| clip_diff = clip_max - clip_min | |||
| for _ in range(self._nb_iter): | |||
| adv_x = self._attack.generate(inputs, labels) | |||
| if 'self.prob' in globals(): | |||
| d_inputs = _transform_inputs(inputs, self.prob) | |||
| else: | |||
| d_inputs = inputs | |||
| adv_x = self._attack.generate(d_inputs, labels) | |||
| perturs = np.clip(adv_x - arr_x, (0 - self._eps)*clip_diff, | |||
| self._eps*clip_diff) | |||
| adv_x = arr_x + perturs | |||
| inputs = adv_x | |||
| else: | |||
| for _ in range(self._nb_iter): | |||
| adv_x = self._attack.generate(inputs, labels) | |||
| if 'self.prob' in globals(): | |||
| d_inputs = _transform_inputs(inputs, self.prob) | |||
| else: | |||
| d_inputs = inputs | |||
| adv_x = self._attack.generate(d_inputs, labels) | |||
| adv_x = np.clip(adv_x, arr_x - self._eps, arr_x + self._eps) | |||
| inputs = adv_x | |||
| return adv_x | |||
| @@ -261,7 +269,7 @@ class MomentumIterativeMethod(IterativeGradientMethod): | |||
| decay_factor (float): Decay factor in iterations. Default: 1.0. | |||
| norm_level (Union[int, numpy.inf]): Order of the norm. Possible values: | |||
| np.inf, 1 or 2. Default: 'inf'. | |||
| loss_fn (Loss): Loss function for optimization. | |||
| loss_fn (Loss): Loss function for optimization. Default: None. | |||
| """ | |||
| def __init__(self, network, eps=0.3, eps_iter=0.1, bounds=(0.0, 1.0), | |||
| @@ -303,9 +311,13 @@ class MomentumIterativeMethod(IterativeGradientMethod): | |||
| clip_min, clip_max = self._bounds | |||
| clip_diff = clip_max - clip_min | |||
| for _ in range(self._nb_iter): | |||
| gradient = self._gradient(inputs, labels) | |||
| if 'self.prob' in globals(): | |||
| d_inputs = _transform_inputs(inputs, self.prob) | |||
| else: | |||
| d_inputs = inputs | |||
| gradient = self._gradient(d_inputs, labels) | |||
| momentum = self._decay_factor*momentum + gradient | |||
| adv_x = inputs + self._eps_iter*np.sign(momentum) | |||
| adv_x = d_inputs + self._eps_iter*np.sign(momentum) | |||
| perturs = np.clip(adv_x - arr_x, (0 - self._eps)*clip_diff, | |||
| self._eps*clip_diff) | |||
| adv_x = arr_x + perturs | |||
| @@ -313,12 +325,15 @@ class MomentumIterativeMethod(IterativeGradientMethod): | |||
| inputs = adv_x | |||
| else: | |||
| for _ in range(self._nb_iter): | |||
| gradient = self._gradient(inputs, labels) | |||
| if 'self.prob' in globals(): | |||
| d_inputs = _transform_inputs(inputs, self.prob) | |||
| else: | |||
| d_inputs = inputs | |||
| gradient = self._gradient(d_inputs, labels) | |||
| momentum = self._decay_factor*momentum + gradient | |||
| adv_x = inputs + self._eps_iter*np.sign(momentum) | |||
| adv_x = d_inputs + self._eps_iter*np.sign(momentum) | |||
| adv_x = np.clip(adv_x, arr_x - self._eps, arr_x + self._eps) | |||
| inputs = adv_x | |||
| return adv_x | |||
| def _gradient(self, inputs, labels): | |||
| @@ -372,7 +387,7 @@ class ProjectedGradientDescent(BasicIterativeMethod): | |||
| nb_iter (int): Number of iteration. Default: 5. | |||
| norm_level (Union[int, numpy.inf]): Order of the norm. Possible values: | |||
| np.inf, 1 or 2. Default: 'inf'. | |||
| loss_fn (Loss): Loss function for optimization. | |||
| loss_fn (Loss): Loss function for optimization. Default: None. | |||
| """ | |||
| def __init__(self, network, eps=0.3, eps_iter=0.1, bounds=(0.0, 1.0), | |||
| @@ -430,3 +445,114 @@ class ProjectedGradientDescent(BasicIterativeMethod): | |||
| adv_x = np.clip(adv_x, arr_x - self._eps, arr_x + self._eps) | |||
| inputs = adv_x | |||
| return adv_x | |||
| class DiverseInputIterativeMethod(BasicIterativeMethod): | |||
| """ | |||
| The Diverse Input Iterative Method attack. | |||
| References: `Xie, Cihang and Zhang, et al., "Improving Transferability of | |||
| Adversarial Examples With Input Diversity," in CVPR, 2019 <https://arxiv.org/abs/1803.06978>`_ | |||
| Args: | |||
| network (Cell): Target model. | |||
| eps (float): Proportion of adversarial perturbation generated by the | |||
| attack to data range. Default: 0.3. | |||
| bounds (tuple): Upper and lower bounds of data, indicating the data range. | |||
| In form of (clip_min, clip_max). Default: (0.0, 1.0). | |||
| is_targeted (bool): If True, targeted attack. If False, untargeted | |||
| attack. Default: False. | |||
| prob (float): Transformation probability. Default: 0.5. | |||
| loss_fn (Loss): Loss function for optimization. Default: None. | |||
| """ | |||
| def __init__(self, network, eps=0.3, bounds=(0.0, 1.0), | |||
| is_targeted=False, prob=0.5, loss_fn=None): | |||
| # reference to paper hyper parameters setting. | |||
| eps_iter = 16*2/255 | |||
| nb_iter = int(min(eps*255 + 4, 1.25*255*eps)) | |||
| super(DiverseInputIterativeMethod, self).__init__(network, | |||
| eps=eps, | |||
| eps_iter=eps_iter, | |||
| bounds=bounds, | |||
| is_targeted=is_targeted, | |||
| nb_iter=nb_iter, | |||
| loss_fn=loss_fn) | |||
| # FGSM default alpha is None equal alpha=1 | |||
| self.prob = check_param_type('prob', prob, float) | |||
| class MomentumDiverseInputIterativeMethod(MomentumIterativeMethod): | |||
| """ | |||
| The Momentum Diverse Input Iterative Method attack. | |||
| References: `Xie, Cihang and Zhang, et al., "Improving Transferability of | |||
| Adversarial Examples With Input Diversity," in CVPR, 2019 <https://arxiv.org/abs/1803.06978>`_ | |||
| Args: | |||
| network (Cell): Target model. | |||
| eps (float): Proportion of adversarial perturbation generated by the | |||
| attack to data range. Default: 0.3. | |||
| bounds (tuple): Upper and lower bounds of data, indicating the data range. | |||
| In form of (clip_min, clip_max). Default: (0.0, 1.0). | |||
| is_targeted (bool): If True, targeted attack. If False, untargeted | |||
| attack. Default: False. | |||
| norm_level (Union[int, numpy.inf]): Order of the norm. Possible values: | |||
| np.inf, 1 or 2. Default: 'l1'. | |||
| prob (float): Transformation probability. Default: 0.5. | |||
| loss_fn (Loss): Loss function for optimization. Default: None. | |||
| """ | |||
| def __init__(self, network, eps=0.3, bounds=(0.0, 1.0), | |||
| is_targeted=False, norm_level='l1', prob=0.5, loss_fn=None): | |||
| eps_iter = 16*2 / 255 | |||
| nb_iter = int(min(eps*255 + 4, 1.25*255*eps)) | |||
| super(MomentumDiverseInputIterativeMethod, self).__init__(network=network, | |||
| eps=eps, | |||
| eps_iter=eps_iter, | |||
| bounds=bounds, | |||
| nb_iter=nb_iter, | |||
| is_targeted=is_targeted, | |||
| norm_level=norm_level, | |||
| loss_fn=loss_fn) | |||
| self.prob = check_param_type('prob', prob, float) | |||
| def _transform_inputs(inputs, prob, low=29, high=33, full_aug=False): | |||
| """ | |||
| Inputs data augmentation. | |||
| Args: | |||
| inputs (Union[np.int8, np.float]): Inputs. | |||
| prob (float): The probability of augmentation. | |||
| low (int): Lower bound of resize image width. Default: 29. | |||
| high (int): Upper bound of resize image height. Default: 33. | |||
| full_aug (bool): type of augmentation method, use interpolation and padding | |||
| as default. Default: False. | |||
| Returns: | |||
| numpy.ndarray, the augmentation data. | |||
| """ | |||
| raw_shape = inputs[0].shape | |||
| tran_mask = np.random.uniform(0, 1, size=inputs.shape[0]) < prob | |||
| tran_inputs = inputs[tran_mask] | |||
| raw_inputs = inputs[tran_mask == 0] | |||
| tran_outputs = [] | |||
| for sample in tran_inputs: | |||
| width = np.random.choice(np.arange(low, high)) | |||
| # resize | |||
| sample = (sample*255).astype(np.uint8) | |||
| d_image = Image.fromarray(sample, mode='L').resize((width, width), Image.NEAREST) | |||
| # pad | |||
| left_pad = (raw_shape[0] - width) // 2 | |||
| right_pad = raw_shape[0] - width - left_pad | |||
| top_pad = (raw_shape[1] - width) // 2 | |||
| bottom_pad = raw_shape[1] - width - top_pad | |||
| p_sample = ImageOps.expand(d_image, | |||
| border=(left_pad, top_pad, right_pad, bottom_pad)) | |||
| tran_outputs.append(np.array(p_sample).astype(np.float) / 255) | |||
| if full_aug: | |||
| # gaussian noise | |||
| tran_outputs = np.random.normal(tran_outputs.shape) + tran_outputs | |||
| tran_outputs.extend(raw_inputs) | |||
| if not np.any(tran_outputs-raw_inputs): | |||
| LOGGER.error(TAG, 'the transform function does not take effect.') | |||
| return tran_outputs | |||
| @@ -242,7 +242,7 @@ def normalize_value(value, norm_level): | |||
| Raises: | |||
| NotImplementedError: If norm_level is not in [1, 2 , np.inf, '1', '2', | |||
| 'inf] | |||
| 'inf', 'l1', 'l2'] | |||
| """ | |||
| norm_level = check_norm_level(norm_level) | |||
| ori_shape = value.shape | |||
| @@ -1,6 +1,7 @@ | |||
| numpy >= 1.17.0 | |||
| scipy >= 1.3.3 | |||
| matplotlib >= 3.1.3 | |||
| Pillow >= 2.0.0 | |||
| pytest >= 4.3.1 | |||
| wheel >= 0.32.0 | |||
| setuptools >= 40.8.0 | |||
| @@ -95,7 +95,8 @@ setup( | |||
| install_requires=[ | |||
| 'scipy >= 1.3.3', | |||
| 'numpy >= 1.17.0', | |||
| 'matplotlib >= 3.1.3' | |||
| 'matplotlib >= 3.1.3', | |||
| 'Pillow >= 2.0.0' | |||
| ], | |||
| ) | |||
| print(find_packages()) | |||
| @@ -25,6 +25,8 @@ from mindarmour.attacks import BasicIterativeMethod | |||
| from mindarmour.attacks import MomentumIterativeMethod | |||
| from mindarmour.attacks import ProjectedGradientDescent | |||
| from mindarmour.attacks import IterativeGradientMethod | |||
| from mindarmour.attacks import DiverseInputIterativeMethod | |||
| from mindarmour.attacks import MomentumDiverseInputIterativeMethod | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
| @@ -91,7 +93,7 @@ def test_momentum_iterative_method(): | |||
| for i in range(5): | |||
| attack = MomentumIterativeMethod(Net(), nb_iter=i + 1) | |||
| ms_adv_x = attack.generate(input_np, label) | |||
| assert np.any(ms_adv_x != input_np), 'Basic iterative method: generate' \ | |||
| assert np.any(ms_adv_x != input_np), 'Momentum iterative method: generate' \ | |||
| ' value must not be equal to' \ | |||
| ' original value.' | |||
| @@ -119,6 +121,48 @@ def test_projected_gradient_descent_method(): | |||
| ' original value.' | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.env_card | |||
| @pytest.mark.component_mindarmour | |||
| def test_diverse_input_iterative_method(): | |||
| """ | |||
| Diverse input iterative method unit test. | |||
| """ | |||
| input_np = np.asarray([[0.1, 0.2, 0.7]], np.float32) | |||
| label = np.asarray([2], np.int32) | |||
| label = np.eye(3)[label].astype(np.float32) | |||
| for i in range(5): | |||
| attack = DiverseInputIterativeMethod(Net()) | |||
| ms_adv_x = attack.generate(input_np, label) | |||
| assert np.any(ms_adv_x != input_np), 'Diverse input iterative method: generate' \ | |||
| ' value must not be equal to' \ | |||
| ' original value.' | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.env_card | |||
| @pytest.mark.component_mindarmour | |||
| def test_momentum_diverse_input_iterative_method(): | |||
| """ | |||
| Momentum diverse input iterative method unit test. | |||
| """ | |||
| input_np = np.asarray([[0.1, 0.2, 0.7]], np.float32) | |||
| label = np.asarray([2], np.int32) | |||
| label = np.eye(3)[label].astype(np.float32) | |||
| for i in range(5): | |||
| attack = MomentumDiverseInputIterativeMethod(Net()) | |||
| ms_adv_x = attack.generate(input_np, label) | |||
| assert np.any(ms_adv_x != input_np), 'Momentum diverse input iterative method: ' \ | |||
| 'generate value must not be equal to' \ | |||
| ' original value.' | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||