|
|
|
@@ -13,11 +13,12 @@ |
|
|
|
# limitations under the License. |
|
|
|
# ============================================================================ |
|
|
|
import numpy as np |
|
|
|
|
|
|
|
import mindspore as ms |
|
|
|
from mindspore import context, Tensor, Parameter |
|
|
|
from mindspore.nn import Cell, TrainOneStepCell, Momentum |
|
|
|
from mindspore.ops import operations as P |
|
|
|
from mindspore.common.api import _executor |
|
|
|
from mindspore.nn import Cell |
|
|
|
from mindspore.ops import operations as P |
|
|
|
|
|
|
|
|
|
|
|
class Net(Cell): |
|
|
|
@@ -42,7 +43,7 @@ class EvalNet(Cell): |
|
|
|
def construct(self, x, b): |
|
|
|
out = self.network(x, b) |
|
|
|
out = self.relu(out) |
|
|
|
return out |
|
|
|
return out |
|
|
|
|
|
|
|
|
|
|
|
_x = Tensor(np.ones([8, 8]), dtype=ms.float32) |
|
|
|
@@ -54,15 +55,15 @@ def test_train_and_eval(): |
|
|
|
context.set_context(save_graphs=True, mode=0) |
|
|
|
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16) |
|
|
|
strategy1 = ((4, 4), (4, 4)) |
|
|
|
strategy2 = ((4, 4), ) |
|
|
|
strategy2 = ((4, 4),) |
|
|
|
net = Net(_w1, strategy1, strategy2) |
|
|
|
eval_net = EvalNet(net, strategy2=strategy2) |
|
|
|
net.set_train() |
|
|
|
net.set_auto_parallel() |
|
|
|
_executor.compile(net, _x, _b, phase='train', auto_parallel_mode=True) |
|
|
|
_executor.compile(net, _x, _b, phase='train', auto_parallel_mode=True) |
|
|
|
|
|
|
|
eval_net.set_train(mode=False) |
|
|
|
eval_net.set_auto_parallel() |
|
|
|
_executor.compile(eval_net, _x, _b, phase='eval', auto_parallel_mode=True) |
|
|
|
_executor.compile(eval_net, _x, _b, phase='eval', auto_parallel_mode=True) |
|
|
|
|
|
|
|
context.reset_auto_parallel_context() |
|
|
|
context.reset_auto_parallel_context() |