| @@ -136,7 +136,7 @@ const char PYTHON_EXTERN_MINDSPORE_FLAG[] = "_mindspore_flags"; | |||||
| // define the parse constant | // define the parse constant | ||||
| const int64_t MAX_COMPARISON_OPS_SUPPORTED = 1; | const int64_t MAX_COMPARISON_OPS_SUPPORTED = 1; | ||||
| const char CUSTOM_BPROP_NAME[] = "bprop"; | const char CUSTOM_BPROP_NAME[] = "bprop"; | ||||
| const char STAGE_NAME[] = "stage"; | |||||
| const char STAGE_NAME[] = "pipeline_stage"; | |||||
| // define the Namespace name | // define the Namespace name | ||||
| const char RESOLVE_NAMESPACE_NAME_AST[] = "Ast"; // for ast type namespace | const char RESOLVE_NAMESPACE_NAME_AST[] = "Ast"; // for ast type namespace | ||||
| @@ -471,6 +471,9 @@ class Receive(PrimitiveWithInfer): | |||||
| self.shape = shape | self.shape = shape | ||||
| self.dtype = dtype | self.dtype = dtype | ||||
| self.group = group | self.group = group | ||||
| valid_type = [mstype.float16, mstype.float32, mstype.int32, mstype.int8, mstype.uint8] | |||||
| args = {"dtype": dtype} | |||||
| validator.check_scalar_or_tensor_types_same(args, valid_type, self.name) | |||||
| def infer_shape(self, x_shape=None): | def infer_shape(self, x_shape=None): | ||||
| return self.shape | return self.shape | ||||
| @@ -77,7 +77,7 @@ class Net(nn.Cell): | |||||
| self.block = nn.CellList() | self.block = nn.CellList() | ||||
| for i in range(2): | for i in range(2): | ||||
| cell = MatMulCell(strategy1, strategy2, param) | cell = MatMulCell(strategy1, strategy2, param) | ||||
| cell.stage = i | |||||
| cell.pipeline_stage = i | |||||
| self.block.append(cell) | self.block.append(cell) | ||||
| def construct(self, x): | def construct(self, x): | ||||