Browse Source

change_pipeline_key_word

tags/v1.2.0-rc1
lichenever 4 years ago
parent
commit
a2b2727ba8
3 changed files with 5 additions and 2 deletions
  1. +1
    -1
      mindspore/ccsrc/pipeline/jit/parse/parse_base.h
  2. +3
    -0
      mindspore/ops/operations/_inner_ops.py
  3. +1
    -1
      tests/ut/python/parallel/test_pipeline_split.py

+ 1
- 1
mindspore/ccsrc/pipeline/jit/parse/parse_base.h View File

@@ -136,7 +136,7 @@ const char PYTHON_EXTERN_MINDSPORE_FLAG[] = "_mindspore_flags";
// define the parse constant
const int64_t MAX_COMPARISON_OPS_SUPPORTED = 1;
const char CUSTOM_BPROP_NAME[] = "bprop";
const char STAGE_NAME[] = "stage";
const char STAGE_NAME[] = "pipeline_stage";

// define the Namespace name
const char RESOLVE_NAMESPACE_NAME_AST[] = "Ast"; // for ast type namespace


+ 3
- 0
mindspore/ops/operations/_inner_ops.py View File

@@ -471,6 +471,9 @@ class Receive(PrimitiveWithInfer):
self.shape = shape
self.dtype = dtype
self.group = group
valid_type = [mstype.float16, mstype.float32, mstype.int32, mstype.int8, mstype.uint8]
args = {"dtype": dtype}
validator.check_scalar_or_tensor_types_same(args, valid_type, self.name)

def infer_shape(self, x_shape=None):
return self.shape


+ 1
- 1
tests/ut/python/parallel/test_pipeline_split.py View File

@@ -77,7 +77,7 @@ class Net(nn.Cell):
self.block = nn.CellList()
for i in range(2):
cell = MatMulCell(strategy1, strategy2, param)
cell.stage = i
cell.pipeline_stage = i
self.block.append(cell)

def construct(self, x):


Loading…
Cancel
Save