diff --git a/mindspore/ops/primitive.py b/mindspore/ops/primitive.py
index 618181fe7b..759cebb804 100644
--- a/mindspore/ops/primitive.py
+++ b/mindspore/ops/primitive.py
@@ -142,10 +142,10 @@ class Primitive(Primitive_):
         Args:
             strategy (tuple): Strategy describes the distributed parallel mode of the current primitive.
         """
-        if context.get_auto_parallel_context("parallel_mode") not in [context.ParallelMode.AUTO_PARALLEL,
-                                                                      context.ParallelMode.SEMI_AUTO_PARALLEL]:
-            logger.warning("Shard strategy is not valid in ", context.get_auto_parallel_context("parallel_mode"),
-                           " mode. Please use semi auto or auto parallel mode.")
+        mode = context.get_auto_parallel_context("parallel_mode")
+        if mode not in [context.ParallelMode.AUTO_PARALLEL, context.ParallelMode.SEMI_AUTO_PARALLEL]:
+            logger.warning(f"The shard strategy {strategy} of {self.name} is not valid in {mode}. "
+                           f"Please use semi auto or auto parallel mode.")
         self.add_prim_attr("strategy", strategy)
         return self
 
diff --git a/mindspore/train/model.py b/mindspore/train/model.py
index 3182cf62c0..b99dffadce 100755
--- a/mindspore/train/model.py
+++ b/mindspore/train/model.py
@@ -815,11 +815,12 @@ class Model:
             >>> import mindspore as ms
             >>> from mindspore import Model, context, Tensor
             >>> from mindspore.context import ParallelMode
+            >>> from mindspore.communication import init
             >>>
             >>> context.set_context(mode=context.GRAPH_MODE)
             >>> init()
             >>> context.set_auto_parallel_context(full_batch=True, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
-            >>> input_data = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), ms.float32)
+            >>> input_data = Tensor(np.random.randint(0, 255, [1, 1, 32, 32]), ms.float32)
             >>> model = Model(Net())
             >>> model.infer_predict_layout(input_data)
         """
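
For reference, a minimal sketch (not part of the patch) of how the reworked warning branch in Primitive.shard would be exercised. It assumes a MindSpore install; in the default "stand_alone" parallel mode the strategy is neither AUTO_PARALLEL nor SEMI_AUTO_PARALLEL, so the new f-string warning fires and now names the strategy, the primitive, and the offending mode instead of passing multiple positional arguments to logger.warning (which the old code did, and which does not concatenate them).

# Sketch only: demonstrates the warning path of the patched Primitive.shard.
from mindspore import context
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE)
# parallel_mode defaults to "stand_alone", so the shard strategy is not valid
# here and the patched warning branch is taken.
matmul = P.MatMul().shard(((1, 1), (1, 1)))
# Expected log wording per the patch, e.g.:
# "The shard strategy ((1, 1), (1, 1)) of MatMul is not valid in stand_alone.
#  Please use semi auto or auto parallel mode."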