|
|
|
@@ -32,8 +32,6 @@ class TestStepParallel : public UT::Common { |
|
|
|
void TearDown() {} |
|
|
|
}; |
|
|
|
|
|
|
|
void TestStepParallel::SetUp() { UT::InitPythonPath(); } |
|
|
|
|
|
|
|
void Init_Device_Manager() { |
|
|
|
RankList dev_list; |
|
|
|
|
|
|
|
@@ -52,6 +50,11 @@ void Init_Device_Manager() { |
|
|
|
g_device_manager->Init(dev_list, local_dev, stage_map, "hccl"); |
|
|
|
} |
|
|
|
|
|
|
|
void TestStepParallel::SetUp() { |
|
|
|
UT::InitPythonPath(); |
|
|
|
Init_Device_Manager(); |
|
|
|
} |
|
|
|
|
|
|
|
CNodePtr Make_Node(Shape x, Shape y, Shape out, int64_t condition = 0) { |
|
|
|
FuncGraphPtr func_graph = std::make_shared<FuncGraph>(); |
|
|
|
ParameterPtr param1 = func_graph->add_parameter(); |
|
|
|
@@ -345,7 +348,6 @@ TEST_F(TestStepParallel, CreatOpInstance1) { |
|
|
|
} |
|
|
|
|
|
|
|
TEST_F(TestStepParallel, OperatorInstance) { |
|
|
|
Init_Device_Manager(); |
|
|
|
// creat attrs and prim |
|
|
|
PrimitivePtr prim = NewValueNode(prim::kPrimMatMul)->value()->cast<PrimitivePtr>(); |
|
|
|
ValuePtr transpose_a = MakeValue(false); |
|
|
|
@@ -369,7 +371,6 @@ TEST_F(TestStepParallel, OperatorInstance) { |
|
|
|
} |
|
|
|
|
|
|
|
TEST_F(TestStepParallel, ExtractInformation) { |
|
|
|
Init_Device_Manager(); |
|
|
|
FuncGraphManagerPtr manager = Make_Manager(); |
|
|
|
FuncGraphSet graphs = manager->func_graphs(); |
|
|
|
FuncGraphPtr graph = *graphs.begin(); |
|
|
|
@@ -379,7 +380,6 @@ TEST_F(TestStepParallel, ExtractInformation) { |
|
|
|
} |
|
|
|
|
|
|
|
TEST_F(TestStepParallel, ExtractInformation2) { |
|
|
|
Init_Device_Manager(); |
|
|
|
FuncGraphManagerPtr manager = Make_Manager(2); |
|
|
|
FuncGraphSet graphs = manager->func_graphs(); |
|
|
|
FuncGraphPtr graph = *graphs.begin(); |
|
|
|
@@ -389,7 +389,6 @@ TEST_F(TestStepParallel, ExtractInformation2) { |
|
|
|
} |
|
|
|
|
|
|
|
TEST_F(TestStepParallel, ExtractInformation3) { |
|
|
|
Init_Device_Manager(); |
|
|
|
FuncGraphManagerPtr manager = Make_Manager(3); |
|
|
|
FuncGraphSet graphs = manager->func_graphs(); |
|
|
|
FuncGraphPtr graph = *graphs.begin(); |
|
|
|
@@ -399,7 +398,6 @@ TEST_F(TestStepParallel, ExtractInformation3) { |
|
|
|
} |
|
|
|
|
|
|
|
TEST_F(TestStepParallel, ForwardCommunication1) { |
|
|
|
Init_Device_Manager(); |
|
|
|
ValuePtr attr0_value = MakeValue(REDUCE_OP_SUM); |
|
|
|
ValuePtr attr1_value = MakeValue("0-1-2"); |
|
|
|
Attr attr0 = std::make_pair("op", attr0_value); |
|
|
|
@@ -499,7 +497,6 @@ TEST_F(TestStepParallel, ForwardCommunication3) { |
|
|
|
} |
|
|
|
|
|
|
|
TEST_F(TestStepParallel, GetTensorInLayout) { |
|
|
|
Init_Device_Manager(); |
|
|
|
// creat attrs and prim |
|
|
|
FuncGraphPtr func_graph = std::make_shared<FuncGraph>(); |
|
|
|
Shape inputs_x_dims = {64, 32}; |
|
|
|
|