You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

test_auto_monad_gpu.py 22 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
import glob
import os
import re
import subprocess

import numpy as np
import pytest

import mindspore as ms
import mindspore.numpy as msnp
import mindspore.ops.operations as P
from mindspore import amp
from mindspore import context, Tensor
from mindspore.common import ParameterTuple
from mindspore.common.parameter import Parameter
from mindspore.nn import Cell
from mindspore.nn import Momentum
from mindspore.nn import ReLU, BatchNorm2d, Conv2d, ParameterUpdate
from mindspore.nn import SoftmaxCrossEntropyWithLogits
from mindspore.ops.composite import GradOperation
from tests.security_utils import security_off_wrap
# All tests in this file run against the GPU backend in graph mode by default;
# individual tests switch mode as needed.
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
class _Grad(Cell):
    """Wrap `network` so that calling this cell returns gradients computed by `grad`.

    Args:
        grad (GradOperation): configured gradient operation (get_all / get_by_list /
            sens_param flags already set).
        network (Cell): the network to differentiate.
        wrt_params (bool): if True, also differentiate w.r.t. the network's
            trainable parameters.
        real_inputs_count (int or None): when sens_param is True, the number of
            leading inputs that are real network inputs; the remaining inputs are
            treated as sensitivity (dout) values.
    """

    def __init__(self, grad, network, wrt_params=False, real_inputs_count=None):
        super().__init__()
        self.network = network
        self.grad = grad
        # Mirror the GradOperation's sens_param flag so construct() can tell
        # whether trailing inputs are sensitivity values.
        self.sens_param = self.grad.sens_param
        self.wrt_params = wrt_params
        self.real_inputs_count = real_inputs_count
        if self.wrt_params:
            self.params = ParameterTuple(self.network.trainable_params())

    def construct(self, *inputs):
        # No sens inputs to split off: forward everything to the grad call.
        if self.real_inputs_count is None or self.sens_param is False:
            if self.wrt_params:
                return self.grad(self.network, self.params)(*inputs)
            return self.grad(self.network)(*inputs)
        # Split inputs into real network inputs and sensitivity values; the
        # sens values are passed as a single trailing tuple.
        real_inputs = inputs[:self.real_inputs_count]
        sense_param_inputs = inputs[self.real_inputs_count:]
        if self.wrt_params:
            return self.grad(self.network, self.params)(*real_inputs, sense_param_inputs)
        return self.grad(self.network)(*real_inputs, sense_param_inputs)
  54. class GradOfAllInputs(_Grad):
  55. '''
  56. get grads of all inputs
  57. '''
  58. def __init__(self, network, sens_param=True, real_inputs_count=None):
  59. super().__init__(grad=GradOperation(get_all=True, sens_param=sens_param),
  60. network=network, real_inputs_count=real_inputs_count)
  61. class GradOfAllInputsAndParams(_Grad):
  62. '''
  63. get grads of all inputs and params
  64. '''
  65. def __init__(self, network, sens_param=True, real_inputs_count=None):
  66. super().__init__(grad=GradOperation(get_all=True, get_by_list=True, sens_param=sens_param),
  67. network=network, wrt_params=True, real_inputs_count=real_inputs_count)
  68. def _count_unequal_element(data_expected, data_me, rtol, atol):
  69. assert data_expected.shape == data_me.shape
  70. total_count = len(data_expected.flatten())
  71. error = np.abs(data_expected - data_me)
  72. greater = np.greater(error, atol + np.abs(data_me)*rtol)
  73. loss_count = np.count_nonzero(greater)
  74. assert (loss_count/total_count) < rtol, \
  75. "\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}".\
  76. format(data_expected[greater], data_me[greater], error[greater])
  77. def allclose_nparray(data_expected, data_me, rtol, atol, equal_nan=True):
  78. if np.any(np.isnan(data_expected)):
  79. assert np.allclose(data_expected, data_me, rtol,
  80. atol, equal_nan=equal_nan)
  81. elif not np.allclose(data_expected, data_me, rtol, atol, equal_nan=equal_nan):
  82. _count_unequal_element(data_expected, data_me, rtol, atol)
  83. else:
  84. assert True
  85. def clear_files():
  86. os.system("rm verbose_ir_files/*")
  87. def find_files(file, para):
  88. output = subprocess.check_output(
  89. ["grep '%s' verbose_ir_files/%s | wc -l" % (para, file)],
  90. shell=True)
  91. out = str(output, 'utf-8').strip()
  92. return out
  93. class SideEffectCastAll(Cell):
  94. def __init__(self):
  95. super().__init__()
  96. self.cast = P.Cast()
  97. self.dtype = ms.float16
  98. np.random.seed(5)
  99. inputs1 = np.random.randn(5, 5)
  100. inputs2 = np.random.randn(5, 5)
  101. self.parameter_a = Parameter(Tensor(inputs1, ms.float32), name="a")
  102. self.parameter_b = Parameter(Tensor(inputs2, ms.float32), name="b")
  103. self.assign = P.Assign()
  104. def construct(self, x, y):
  105. self.assign(self.parameter_a, x)
  106. self.assign(self.parameter_b, y)
  107. out_a = self.cast(self.parameter_a, self.dtype)
  108. out_b = self.cast(self.parameter_b, self.dtype)
  109. return out_a, out_b
  110. @security_off_wrap
  111. def test_side_effect_castall():
  112. clear_files()
  113. context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
  114. net = SideEffectCastAll()
  115. inputs1 = np.random.randn(5, 5)
  116. inputs2 = np.random.randn(5, 5)
  117. net(Tensor(inputs1, ms.float32), Tensor(inputs2, ms.float32))
  118. result = find_files('./hwopt*cast_all*.ir', 'CastAll')
  119. assert result == '2'
class SideEffectControlFlowAssignDependWhileNet(Cell):
    """While-loop net whose loop condition reads a parameter updated inside the loop."""

    def __init__(self):
        super().__init__()
        self.parameter1 = Parameter(
            Tensor([199.0], ms.float32), name="parameter1")
        self.assign = P.Assign()
        self.assignadd = P.AssignAdd()
        self.addn = P.AddN()

    def construct(self, x, y, z):
        # Reset parameter1 to x before the loop; the loop condition depends on
        # this side effect having executed first.
        self.assign(self.parameter1, x)
        while self.parameter1 < y:
            x = self.addn((x, x))
            # In-loop parameter update drives the loop toward termination.
            self.assignadd(self.parameter1, z)
        return x

    def grad_mindspore_impl(self, params1, params2, params3, grad_ys):
        # Gradients w.r.t. all inputs and trainable parameters.
        grad_net = GradOfAllInputsAndParams(self)
        grad_net.set_train()
        grad_out = grad_net(params1, params2, params3, grad_ys)
        return grad_out
  139. @pytest.mark.level0
  140. @pytest.mark.platform_x86_gpu_training
  141. @pytest.mark.env_onecard
  142. def test_side_effect_control_flow_assign_depend_while_net():
  143. net = SideEffectControlFlowAssignDependWhileNet()
  144. context.set_context(mode=context.GRAPH_MODE)
  145. out1 = net(Tensor([9.0], ms.float32), Tensor(
  146. [99.0], ms.float32), Tensor([1.0], ms.float32))
  147. net = SideEffectControlFlowAssignDependWhileNet()
  148. context.set_context(mode=context.PYNATIVE_MODE)
  149. out2 = net(Tensor([9.0], ms.float32), Tensor(
  150. [99.0], ms.float32), Tensor([1.0], ms.float32))
  151. allclose_nparray(out1.asnumpy(), out2.asnumpy(), 0.001, 0.001)
  152. class Addn(Cell):
  153. def __init__(self):
  154. super().__init__()
  155. self.parameter3 = Parameter(Tensor([1.0], ms.float32),
  156. name="parameter3")
  157. self.parameter4 = Parameter(Tensor([3.0], ms.float32),
  158. name="parameter4")
  159. self.addn = P.AddN()
  160. def construct(self, inputs):
  161. out = self.addn((inputs, self.parameter3, self.parameter4))
  162. return out
  163. class Relu(Cell):
  164. def __init__(self):
  165. super().__init__()
  166. self.relu = P.ReLU()
  167. def construct(self, inputs):
  168. out = self.relu(inputs)
  169. return out
class SideEffectTwoAssignTwoAddnDependencyNet(Cell):
    """Interleaves Assign and AddN so the auto-monad must preserve their order."""

    def __init__(self):
        super().__init__()
        self.parameter1 = Parameter(Tensor([1.0], ms.float32),
                                    name="parameter1")
        self.parameter2 = Parameter(Tensor([3.0], ms.float32),
                                    name="parameter2")
        self.assign = P.Assign()
        self.addN = P.AddN()

    def construct(self, inputs):
        # Each AddN reads a parameter right after it was assigned; reordering
        # any of these four statements would change the result.
        self.assign(self.parameter1, inputs)
        out = self.addN((inputs, self.parameter1, self.parameter2))
        self.assign(self.parameter2, inputs)
        out = self.addN((out, self.parameter1, self.parameter2))
        return out

    def grad_mindspore_impl(self, params, grad_ys):
        # Gradients w.r.t. the input and both parameters.
        grad_net = GradOfAllInputsAndParams(self)
        grad_net.set_train()
        grad_out = grad_net(params, grad_ys)
        return grad_out
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_ctrl_while_by_while_and_if_in_first_while():
    """Two sequential whiles with an if nested in the first; parameters drive conditions."""
    class Net(Cell):
        def __init__(self):
            super().__init__()
            self.relu = P.ReLU()
            self.sigmoid = P.Sigmoid()
            self.tanh = P.Tanh()
            self.add = P.Add()
            a = np.full((1,), 5, dtype=np.float32)
            self.a = Parameter(Tensor(a), name="a")
            b = np.full((1,), 4, dtype=np.float32)
            self.b = Parameter(Tensor(b), name="b")
            c = np.full((1,), 7, dtype=np.float32)
            self.c = Parameter(Tensor(c), name="c")

        def construct(self, x):
            out = x
            # First loop: runs while a < 7, contains a data-dependent if.
            while self.a < 7:
                if self.a < self.c:
                    out = self.relu(x)
                self.a += 1
            # Second loop: runs while c > 5.
            while self.c > 5:
                out = self.add(out, out)
                self.c -= 1
            return out

    context.set_context(mode=context.GRAPH_MODE)
    input_np_a = np.random.randn(2, 3, 4, 5).astype(np.float32)
    input_me_a = Tensor(input_np_a)
    net = Net()
    # Smoke test: compiling and running must not crash.
    net(input_me_a)
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_ctrl_while_by_while_and_while_in_first_while():
    """A while nested inside a while, followed by a second top-level while."""
    class Net(Cell):
        def __init__(self):
            super().__init__()
            self.relu = P.ReLU()
            self.sigmoid = P.Sigmoid()
            self.tanh = P.Tanh()
            self.add = P.Add()
            a = np.full((1,), 5, dtype=np.float32)
            self.a = Parameter(Tensor(a), name="a")
            b = np.full((1,), 4, dtype=np.float32)
            self.b = Parameter(Tensor(b), name="b")
            c = np.full((1,), 7, dtype=np.float32)
            self.c = Parameter(Tensor(c), name="c")

        def construct(self, x):
            out = x
            while self.a < self.c:
                out = self.relu(x)
                # Inner while mutates b on every outer iteration.
                while self.b > 1:
                    self.b -= 1
                self.a += 1
            while self.c > 5:
                out = self.add(out, out)
                self.c -= 1
            return out

    context.set_context(mode=context.GRAPH_MODE)
    input_np_a = np.random.randn(2, 3, 4, 5).astype(np.float32)
    input_me_a = Tensor(input_np_a)
    net = Net()
    # Smoke test: compiling and running must not crash.
    net(input_me_a)
  255. class InplaceNet(Cell):
  256. def __init__(self):
  257. super().__init__()
  258. self.bn1 = BatchNorm2d(num_features=4, eps=1e-4,
  259. momentum=0.9, gamma_init=1, beta_init=0,
  260. moving_mean_init=0, moving_var_init=1, data_format="NHWC")
  261. self.bn2 = BatchNorm2d(num_features=4, eps=1e-4,
  262. momentum=0.9, gamma_init=1, beta_init=0,
  263. moving_mean_init=0, moving_var_init=1, data_format="NHWC")
  264. self.add = P.Add()
  265. self.relu = ReLU()
  266. self.conv2d1 = Conv2d(in_channels=4, out_channels=4,
  267. kernel_size=2, data_format="NHWC")
  268. self.conv2d2 = Conv2d(in_channels=4, out_channels=4,
  269. kernel_size=2, data_format="NHWC")
  270. self.conv2d3 = Conv2d(in_channels=4, out_channels=4,
  271. kernel_size=2, data_format="NHWC")
  272. self.conv2d4 = Conv2d(in_channels=4, out_channels=4,
  273. kernel_size=2, data_format="NHWC")
  274. def construct(self, input_x):
  275. tmp_c1 = self.conv2d1(input_x)
  276. tmp_c2 = self.conv2d2(input_x)
  277. tmp_x = self.bn1(tmp_c1)
  278. tmp_y = self.bn2(tmp_c2)
  279. tmp_w = self.add(tmp_x, tmp_y)
  280. tmp_w = self.relu(tmp_w)
  281. tmp_c1 = self.conv2d3(tmp_w)
  282. tmp_c2 = self.conv2d4(tmp_w)
  283. output = self.add(tmp_c1, tmp_c2)
  284. return output
  285. @security_off_wrap
  286. def test_ir_fusion_inplace_bn_conv_conv():
  287. clear_files()
  288. context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
  289. input_np = np.random.uniform(0.0, 255.0,
  290. size=[4, 4, 4, 4]).astype(np.float32)
  291. label = np.ones([4, 4, 4, 4]).astype(np.float32)
  292. net = InplaceNet()
  293. loss = SoftmaxCrossEntropyWithLogits(sparse=False)
  294. opt = Momentum(learning_rate=0.01, momentum=0.9,
  295. params=filter(lambda x: x.requires_grad, net.get_parameters()))
  296. net = amp.build_train_network(net, opt, loss, level="O2",
  297. keep_batchnorm_fp32=False)
  298. net.set_train()
  299. net(Tensor(input_np), Tensor(label))
  300. find_accum = find_files("./hwopt*cudnn_inplace*ir",
  301. "inplace_algo: accumulation")
  302. find_cover = find_files("./hwopt*cudnn_inplace*ir",
  303. "inplace_algo: cover")
  304. assert find_accum == '1'
  305. assert find_cover == '1'
  306. def clean_all_ir_files(folder_path):
  307. if os.path.exists(folder_path):
  308. for file_name in os.listdir(folder_path):
  309. if file_name.endswith('.ir') or file_name.endswith('.dot') or \
  310. file_name.endswith('.dat'):
  311. os.remove(os.path.join(folder_path, file_name))
  312. def find_newest_validateir_file(folder_path):
  313. ckpt_files = map(lambda f: os.path.join(folder_path, f),
  314. filter(lambda f: re.match(r'\d+_validate_\d+.ir', f),
  315. os.listdir(folder_path)))
  316. return max(ckpt_files, key=os.path.getctime)
  317. def read_file():
  318. filename = find_newest_validateir_file('./')
  319. with open((os.path.join(filename)), 'r') as f:
  320. content = f.read()
  321. clean_all_ir_files('./')
  322. return content
  323. class Add(Cell):
  324. def __init__(self):
  325. super().__init__()
  326. self.add = P.Add()
  327. def construct(self, x, y):
  328. return self.add(x, y)
class MixControlNet(Cell):
    """Net mixing while/if control flow with parameter side effects.

    Args:
        in_channel (int): channel count for the conv/bn layers.
        x (int): loop bound seed; construct() iterates while x < 20.
    """

    def __init__(self, in_channel, x):
        super().__init__()
        self.biasadd = P.BiasAdd()
        self.equal = P.Equal()
        self.addn = P.AddN()
        self.conv = Conv2d(in_channels=in_channel, out_channels=in_channel,
                           kernel_size=1, stride=1, has_bias=False,
                           weight_init='ones', pad_mode='same')
        self.bn = BatchNorm2d(num_features=in_channel)
        self.assignadd = P.AssignAdd()
        self.assign = P.Assign()
        self.relu = ReLU()
        self.mean = P.ReduceMean(keep_dims=False)
        # Random bias parameter, updated in-place inside construct().
        self.bias = Parameter(
            Tensor(np.random.randint(2, size=(3,)).astype((np.float32))),
            name="bias")
        self.bias2 = Parameter(Tensor(np.ones([3]).astype(np.float32)),
                               name="bias2")
        self.parameterupdate = ParameterUpdate(self.bias)
        self.value = Tensor(np.random.randn(*(3,)), ms.float32)
        self.x = x

    def construct(self, input_x):
        x = self.x
        z = self.x
        out = self.biasadd(input_x, self.bias)
        while x < 20:
            # Overwrite self.bias with bias2 and use the returned value.
            update = self.parameterupdate(self.bias2)
            out = self.biasadd(out, update)
            if x < 10:
                out = self.addn((input_x, out))
                # Inner while nested in the if branch.
                while z < 20:
                    out = self.conv(out)
                    z = z + 1
            if x < 20:
                out = self.biasadd(out, self.bias)
                if x % 2 == 0:
                    # Side effect in the even branch: bias += value.
                    self.assignadd(self.bias, self.value)
                    out = self.biasadd(out, self.bias)
                    out = self.bn(out)
                else:
                    out = self.conv(out)
            x = x + 1
        out = self.addn((out, out))
        # Reduce over the spatial dimensions (H, W).
        out = self.mean(out, (2, 3))
        return out
  376. def use_build_train_network_controlflow_check_cast_num(network, level, input_x,
  377. label, cast_num,
  378. sparse=False,
  379. loss_flag=True,
  380. **kwargs):
  381. opt = Momentum(learning_rate=0.0001, momentum=0.009,
  382. params=network.trainable_params())
  383. loss = None
  384. if loss_flag:
  385. loss = SoftmaxCrossEntropyWithLogits(sparse=sparse, reduction='mean')
  386. train_network = ms.amp.build_train_network(network, opt, loss, level=level,
  387. **kwargs)
  388. out_me = train_network(input_x, label)
  389. if context.get_context("mode") == 0:
  390. content = read_file()
  391. castnum = re.findall('Cast', content)
  392. assert len(castnum) == cast_num
  393. return out_me
  394. @security_off_wrap
  395. def test_auto_mixed_precision_controlflow_auto():
  396. context.set_context(mode=context.PYNATIVE_MODE, save_graphs=True)
  397. net = MixControlNet(3, 5)
  398. input_x = Tensor(
  399. np.random.randint(2, size=(1, 3, 2, 2)).astype((np.float32)))
  400. label = Tensor(np.zeros([1, 3]).astype(np.float32))
  401. if ms.context.get_context("device_target") == "Ascend":
  402. cast_num = 77
  403. if ms.context.get_context("device_target") == "GPU":
  404. cast_num = 73
  405. use_build_train_network_controlflow_check_cast_num(net, "auto", input_x,
  406. label, cast_num)
  407. @security_off_wrap
  408. def test_updatestate_between_assigns():
  409. class UpdateState_Assigns(Cell):
  410. def __init__(self):
  411. super().__init__()
  412. self.para1 = Parameter(Tensor(1, dtype=ms.int32), name='para1')
  413. self.para2 = Parameter(Tensor(3, dtype=ms.int32), name='para2')
  414. def construct(self, value1, value2):
  415. self.para1 = value1
  416. self.para2 = value2
  417. return self.para2
  418. context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
  419. input_x = Tensor(10, dtype=ms.int32)
  420. input_y = Tensor(30, dtype=ms.int32)
  421. expect = Tensor(30, dtype=ms.int32)
  422. net = UpdateState_Assigns()
  423. out = net(input_x, input_y)
  424. np.testing.assert_array_equal(out.asnumpy(), expect.asnumpy())
  425. if ms.context.get_context('mode') == 0:
  426. content = read_file()
  427. updatestate_num = re.findall('UpdateState', content)
  428. assert len(updatestate_num) == 1
  429. @security_off_wrap
  430. def test_updatestate_between_maketuple_assign():
  431. class UpdateState_MakeTuple_Assign(Cell):
  432. def __init__(self):
  433. super().__init__()
  434. self.para1 = Parameter(Tensor(1, dtype=ms.int32), name='para1')
  435. self.para2 = Parameter(Tensor(3, dtype=ms.int32), name='para2')
  436. self.para3 = Parameter(Tensor(5, dtype=ms.int32), name='para3')
  437. def construct(self, value1, value2, value3):
  438. (self.para1, self.para2) = (value1, value2)
  439. self.para3 = value3
  440. return self.para3
  441. context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
  442. input_x = Tensor(10, dtype=ms.int32)
  443. input_y = Tensor(30, dtype=ms.int32)
  444. input_z = Tensor(50, dtype=ms.int32)
  445. expect = Tensor(50, dtype=ms.int32)
  446. net = UpdateState_MakeTuple_Assign()
  447. out = net(input_x, input_y, input_z)
  448. np.testing.assert_array_equal(out.asnumpy(), expect.asnumpy())
  449. if ms.context.get_context('mode') == 0:
  450. content = read_file()
  451. updatestate_num = re.findall('UpdateState', content)
  452. assert len(updatestate_num) == 1
  453. @security_off_wrap
  454. def test_updatestate_between_assign_maketuple():
  455. class UpdateState_Assign_MakeTuple(Cell):
  456. def __init__(self):
  457. super().__init__()
  458. self.para1 = Parameter(Tensor(1, dtype=ms.int32), name='para1')
  459. self.para2 = Parameter(Tensor(3, dtype=ms.int32), name='para2')
  460. self.para3 = Parameter(Tensor(5, dtype=ms.int32), name='para3')
  461. def construct(self, value1, value2, value3):
  462. self.para1 = value1
  463. (self.para2, self.para3) = (value2, value3)
  464. return self.para3
  465. context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
  466. input_x = Tensor(10, dtype=ms.int32)
  467. input_y = Tensor(30, dtype=ms.int32)
  468. input_z = Tensor(50, dtype=ms.int32)
  469. expect = Tensor(50, dtype=ms.int32)
  470. net = UpdateState_Assign_MakeTuple()
  471. out = net(input_x, input_y, input_z)
  472. np.testing.assert_array_equal(out.asnumpy(), expect.asnumpy())
  473. if ms.context.get_context('mode') == 0:
  474. content = read_file()
  475. updatestate_num = re.findall('UpdateState', content)
  476. assert len(updatestate_num) == 1
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_cycle_parameter_binding():
    """
    Feature: Auto-monad side-effect finder.
    Description: Auto-monad should work properly when cycle parameter binding existed.
    Expectation: Normal output, no core dump.
    """
    class MyActor(Cell):
        # Identity cell; exists only so the list below holds distinct sub-cells.
        def construct(self, inputs):
            return inputs

    class MyCell(Cell):
        def __init__(self, actor_list):
            super().__init__()
            self.zero = Tensor(0, ms.int32)
            self.actor_list = actor_list

        def construct(self, state):
            duration = self.zero
            while duration < 2:
                # Indexing the cell list with a tensor loop variable creates
                # the cyclic parameter binding under test.
                for n in msnp.arange(3):
                    samples = (state[n])
                    x = self.actor_list[n](samples)
                    print(x)
                duration += 1
            return duration

    actor_list = [MyActor(), MyActor(), MyActor()]
    net = MyCell(actor_list)
    state = Tensor(np.ones((3, 3)), ms.float32)
    out = net(state)
    np.testing.assert_allclose(out.asnumpy(), 2)