|
|
|
@@ -17,7 +17,7 @@ |
|
|
|
import numpy as np |
|
|
|
import pytest |
|
|
|
|
|
|
|
from mindspore import context, Tensor, Parameter, ParameterTuple |
|
|
|
from mindspore import context, Tensor, Parameter, ParameterTuple, nn |
|
|
|
from mindspore._checkparam import _check_str_by_regular |
|
|
|
from mindspore.common import dtype as mstype |
|
|
|
from mindspore.common.initializer import initializer |
|
|
|
@@ -229,3 +229,25 @@ def test_parameter_lazy_init(): |
|
|
|
para.set_parameter_data(initializer('ones', [1, 2], mstype.float32), slice_shape=True) |
|
|
|
assert np.array_equal(para.default_input.asnumpy(), np.ones((1, 2))) |
|
|
|
context.reset_auto_parallel_context() |
|
|
|
|
|
|
|
|
|
|
|
def test_parameter_as_output(): |
|
|
|
context.reset_auto_parallel_context() |
|
|
|
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel") |
|
|
|
initial_input = initializer('One', shape=(2,), dtype=mstype.int32) |
|
|
|
updated_input = Tensor([2, 2], mstype.int32) |
|
|
|
class Net(nn.Cell): |
|
|
|
def __init__(self, initial, updated): |
|
|
|
super().__init__() |
|
|
|
self.initial = initial |
|
|
|
self.updated = updated |
|
|
|
self.p = Parameter(self.initial, name="weight") |
|
|
|
self.new_p = self.p.init_data() |
|
|
|
self.new_p.set_parameter_data(self.updated) |
|
|
|
def construct(self): |
|
|
|
return self.new_p |
|
|
|
|
|
|
|
net = Net(initial_input, updated_input) |
|
|
|
output = net() |
|
|
|
assert np.array_equal(output.asnumpy(), np.array([2, 2], np.int32)) |
|
|
|
context.reset_auto_parallel_context() |