Browse Source

update parallel validation and example

pull/14370/head
Ziyan 4 years ago
parent
commit
df726d9611
2 changed files with 6 additions and 5 deletions
  1. +4
    -4
      mindspore/ops/primitive.py
  2. +2
    -1
      mindspore/train/model.py

+ 4
- 4
mindspore/ops/primitive.py View File

@@ -142,10 +142,10 @@ class Primitive(Primitive_):
Args:
strategy (tuple): Strategy describes the distributed parallel mode of the current primitive.
"""
if context.get_auto_parallel_context("parallel_mode") not in [context.ParallelMode.AUTO_PARALLEL,
context.ParallelMode.SEMI_AUTO_PARALLEL]:
logger.warning("Shard strategy is not valid in ", context.get_auto_parallel_context("parallel_mode"),
" mode. Please use semi auto or auto parallel mode.")
mode = context.get_auto_parallel_context("parallel_mode")
if mode not in [context.ParallelMode.AUTO_PARALLEL, context.ParallelMode.SEMI_AUTO_PARALLEL]:
logger.warning(f"The shard strategy {strategy} of {self.name} is not valid in {mode}. "
f"Please use semi auto or auto parallel mode.")
self.add_prim_attr("strategy", strategy)
return self



+ 2
- 1
mindspore/train/model.py View File

@@ -815,11 +815,12 @@ class Model:
>>> import mindspore as ms
>>> from mindspore import Model, context, Tensor
>>> from mindspore.context import ParallelMode
>>> from mindspore.communication import init
>>>
>>> context.set_context(mode=context.GRAPH_MODE)
>>> init()
>>> context.set_auto_parallel_context(full_batch=True, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
>>> input_data = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), ms.float32)
>>> input_data = Tensor(np.random.randint(0, 255, [1, 1, 32, 32]), ms.float32)
>>> model = Model(Net())
>>> model.infer_predict_layout(input_data)
"""


Loading…
Cancel
Save