|
|
|
@@ -21,7 +21,6 @@ from mindspore import context |
|
|
|
from mindspore.common.api import _executor |
|
|
|
from mindspore.ops import composite as C |
|
|
|
from mindspore.ops import operations as P |
|
|
|
from mindspore.ops.operations.comm_ops import _VirtualDataset |
|
|
|
from tests.ut.python.ops.test_math_ops import VirtualLoss |
|
|
|
|
|
|
|
context.set_context(mode=context.GRAPH_MODE) |
|
|
|
@@ -33,7 +32,6 @@ grad_all = C.GradOperation(get_all=True) |
|
|
|
class Net(nn.Cell): |
|
|
|
def __init__(self, strategy1, strategy2, num_segments): |
|
|
|
super(Net, self).__init__() |
|
|
|
self.virtual_dataset = _VirtualDataset() |
|
|
|
self.merge_op = P.UnsortedSegmentSum().shard((strategy1, strategy2)) |
|
|
|
self.num_segments = num_segments |
|
|
|
|
|
|
|
@@ -54,8 +52,8 @@ class GradWrap(nn.Cell): |
|
|
|
class NetWithLoss(nn.Cell): |
|
|
|
def __init__(self, network): |
|
|
|
super(NetWithLoss, self).__init__() |
|
|
|
self.loss = VirtualLoss() |
|
|
|
self.network = network |
|
|
|
self.loss = VirtualLoss() |
|
|
|
|
|
|
|
def construct(self, x, y): |
|
|
|
predict = self.network(x, y) |
|
|
|
@@ -63,13 +61,13 @@ class NetWithLoss(nn.Cell): |
|
|
|
|
|
|
|
|
|
|
|
def compile_graph(x, y, segments, strategy1, strategy2, auto=False): |
|
|
|
net = GradWrap(NetWithLoss(Net(strategy1, strategy2, segments))) |
|
|
|
net.set_auto_parallel() |
|
|
|
net.set_train() |
|
|
|
if auto: |
|
|
|
context.set_auto_parallel_context(parallel_mode="auto_parallel") |
|
|
|
else: |
|
|
|
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel") |
|
|
|
net = GradWrap(NetWithLoss(Net(strategy1, strategy2, segments))) |
|
|
|
net.set_auto_parallel() |
|
|
|
net.set_train() |
|
|
|
_executor.compile(net, x, y) |
|
|
|
|
|
|
|
|
|
|
|
@@ -151,3 +149,13 @@ def test_unsortedsegmentsum_model_parallel_index_vector_slice_3d(): |
|
|
|
strategy1 = (2, 1, 2) |
|
|
|
strategy2 = (2, 1) |
|
|
|
compile_graph(x, y, num_segments, strategy1, strategy2) |
|
|
|
|
|
|
|
|
|
|
|
def test_unsortedsegmentsum_model_parallel_repeat_caculate(): |
|
|
|
context.set_auto_parallel_context(device_num=4, global_rank=0) |
|
|
|
x = Tensor(np.ones((4, 4, 8)), ms.float32) |
|
|
|
y = Tensor(np.ones((4, 4)), ms.int32) |
|
|
|
num_segments = 16 |
|
|
|
strategy1 = (1, 1, 1) |
|
|
|
strategy2 = (1, 1) |
|
|
|
compile_graph(x, y, num_segments, strategy1, strategy2) |