|
|
|
@@ -22,6 +22,7 @@ from mindspore.ops import operations as P, functional as F |
|
|
|
from mindspore.common.initializer import initializer |
|
|
|
import mindspore.common.dtype as mstype |
|
|
|
from mindspore.common.api import _cell_graph_executor |
|
|
|
from mindspore.parallel._cost_model_context import _set_algo_single_loop |
|
|
|
from tests.dataset_mock import MindData |
|
|
|
|
|
|
|
|
|
|
|
@@ -126,6 +127,7 @@ _w1 = Tensor(np.ones([512, 128, 1]), dtype=ms.float32) |
|
|
|
|
|
|
|
def test_auto_parallel(): |
|
|
|
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=16, global_rank=0) |
|
|
|
_set_algo_single_loop(True) |
|
|
|
net = Full(_w1, 3) |
|
|
|
net.set_auto_parallel() |
|
|
|
net.set_train() |
|
|
|
|