
prim_nn.cc 19 kB

/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "pipeline/static_analysis/prim.h"

#include "operator/ops.h"
#include "pipeline/static_analysis/utils.h"
#include "pipeline/static_analysis/param_validator.h"

namespace mindspore {
namespace abstract {

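// Shape inference for Pooling. The spatial output size follows the standard
// sliding-window formula out = (in + 2 * pad - window) / stride + 1. As an
// illustrative calculation (not taken from the source): a 32x32 input with
// window = 3, stride = 2 and pad_mode "same" gives padding = (3 - 1) / 2 = 1
// and h_out = (32 + 2 - 3) / 2 + 1 = 15 + 1 = 16.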
AbstractBasePtr InferImplPooling(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                 const AbstractBasePtrList &args_spec_list) {
  // Inputs: a tensor.
  const std::string op_name = primitive->name();
  CheckArgsSize(op_name, args_spec_list, 1);
  AbstractTensorPtr input_tensor = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
  (void)CheckTensorDType(input_tensor, {kFloat16, kFloat32}, "Input 0 of Pooling should be %s");

  ShapePtr input_shape = dyn_cast<Shape>(input_tensor->GetShapeTrack());  // NCHW
  MS_EXCEPTION_IF_NULL(input_shape);
  if (input_shape->shape().size() != 4) {
    MS_LOG(EXCEPTION) << "Pooling input should be a 4-D tensor.";
  }
  int h_input = input_shape->shape()[2];
  int w_input = input_shape->shape()[3];

  int window = primitive->GetAttr("window")->cast<Int32ImmPtr>()->value();
  int stride = primitive->GetAttr("stride")->cast<Int32ImmPtr>()->value();
  int padding = primitive->GetAttr("pad")->cast<Int32ImmPtr>()->value();
  int nan_opt = primitive->GetAttr("nan_opt")->cast<Int32ImmPtr>()->value();
  int data_mode = primitive->GetAttr("data_mode")->cast<Int32ImmPtr>()->value();
  int ceil_mode = primitive->GetAttr("ceil_mode")->cast<Int32ImmPtr>()->value();

  if (stride <= 0) {
    MS_LOG(EXCEPTION) << "Invalid stride value: " << stride << ", should be greater than 0";
  }
  if (nan_opt != 0) {
    MS_LOG(EXCEPTION) << "Invalid nan_opt value: " << nan_opt << ", should be 0";
  }
  if (data_mode != 1) {
    MS_LOG(EXCEPTION) << "Invalid data_mode value: " << data_mode << ", should be 1";
  }
  if (ceil_mode != 0) {
    MS_LOG(EXCEPTION) << "Invalid ceil_mode value: " << ceil_mode << ", should be 0";
  }

  std::set<std::string> available_pad_mode{"pad", "same", "valid"};
  auto pad_mode_ptr = primitive->GetAttr("pad_mode");
  if ((pad_mode_ptr != nullptr) && pad_mode_ptr->isa<StringImm>()) {
    auto pad_mode = pad_mode_ptr->cast<StringImmPtr>()->value();
    if (available_pad_mode.find(pad_mode) == available_pad_mode.end()) {
      MS_LOG(EXCEPTION) << "Unsupported pad mode: " << pad_mode << ", should be one of pad, same, valid";
    }
    if (pad_mode == "valid") {
      padding = 0;
    } else if (pad_mode == "same") {
      padding = (window - 1) / 2;
    }
  }

  std::set<std::string> available_mode{"max", "avg"};
  auto mode_ptr = primitive->GetAttr("mode");
  if ((mode_ptr != nullptr) && mode_ptr->isa<StringImm>()) {
    auto mode = mode_ptr->cast<StringImmPtr>()->value();
    if (available_mode.find(mode) == available_mode.end()) {
      MS_LOG(EXCEPTION) << "Unsupported pooling mode: " << mode << ".";
    }
  }

  int h_out = ((h_input + 2 * padding - (window - 1) - 1) / stride) + 1;
  int w_out = ((w_input + 2 * padding - (window - 1) - 1) / stride) + 1;
  std::vector<int> shape_out = {input_shape->shape()[0], input_shape->shape()[1], h_out, w_out};
  AbstractBasePtr ret = input_tensor->Broaden();
  ret->set_shape(std::make_shared<Shape>(shape_out));
  return ret;
}

AbstractBasePtr InferImplPoolingGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                     const AbstractBasePtrList &args_spec_list) {
  // Inputs: three tensors(y, dy, x).
  const std::string op_name = primitive->name();
  CheckArgsSize(op_name, args_spec_list, 3);
  auto out_y = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
  auto d_out = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
  auto input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 2);
  (void)CheckTensorsDTypeSame({out_y, d_out, input_x}, {kInt, kUInt, kFloat},
                              op_name + " evaluator: the three inputs should be %s");

  // The gradient has the same shape as the original input x.
  AbstractBasePtr ret = d_out->Broaden();
  auto x_shape = dyn_cast<Shape>(args_spec_list[2]->GetShapeTrack());
  MS_EXCEPTION_IF_NULL(x_shape);
  ret->set_shape(x_shape);
  return ret;
}

void FusedBatchNormCheckDim(const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list) {
  // Check dimensions: x should have rank > 1, the other args rank 1.
  const std::string op_name = primitive->name();
  for (std::size_t i = 0; i < args_spec_list.size(); ++i) {
    AbstractTensorPtr arg = CheckArg<AbstractTensor>(op_name, args_spec_list, i);
    ShapePtr arg_shape = dyn_cast<Shape>(arg->GetShapeTrack());
    if (arg_shape == nullptr) {
      MS_LOG(EXCEPTION) << op_name << " type of args[" << i << "] should be Shape, but " << arg->ToString();
    }
    if (i == 0) {
      if (arg_shape->shape().size() < 2) {
        MS_LOG(EXCEPTION) << op_name << " shape of args[" << i
                          << "] should be TensorShape with dimension greater than 1, but shape: "
                          << arg_shape->ToString();
      }
      continue;
    }
    if (arg_shape->shape().size() != 1) {
      MS_LOG(EXCEPTION) << op_name << " shape of args[" << i
                        << "] should be TensorShape with dimension: 1, but shape: " << arg_shape->ToString();
    }
  }
}

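// Shape inference for FusedBatchNorm. Input x must have rank >= 2 with the
// channel dimension on axis 1; the four rank-1 parameters (gamma, beta, mean,
// variance) must have length equal to x.shape[1]. Illustrative example (not
// from the source): with x of shape [32, 64, 7, 7], gamma, beta, mean and
// variance must all have shape [64].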
AbstractBasePtr InferImplFusedBatchNorm(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                        const AbstractBasePtrList &args_spec_list) {
  // Inputs: five tensors(x, gamma, beta, mean, variance).
  const std::string op_name = primitive->name();
  CheckArgsSize(op_name, args_spec_list, 5);
  MS_EXCEPTION_IF_NULL(args_spec_list[0]);
  MS_LOG(DEBUG) << "InferImplFusedBatchNorm args0:" << args_spec_list[0]->ToString()
                << ", arg1:" << args_spec_list[1]->ToString();
  FusedBatchNormCheckDim(primitive, args_spec_list);

  auto input = args_spec_list[0];
  auto input_shape = dyn_cast<Shape>(input->GetShapeTrack());
  MS_EXCEPTION_IF_NULL(input_shape);
  const auto &input_shape_list = input_shape->shape();
  if (input_shape_list.size() < 2) {
    MS_LOG(EXCEPTION) << "Input shape size should be >= 2.";
  }
  for (size_t i = 1; i < args_spec_list.size(); ++i) {
    auto arg_shape = dyn_cast<Shape>(args_spec_list[i]->GetShapeTrack());
    MS_EXCEPTION_IF_NULL(arg_shape);
    const auto &arg_shape_list = arg_shape->shape();
    if (arg_shape_list.size() < 1) {
      MS_LOG(EXCEPTION) << "Arg shape size should be >= 1.";
    }
    if (arg_shape_list[0] != input_shape_list[1]) {
      MS_LOG(EXCEPTION) << op_name << " size of tensor param[" << i << "](which is " << arg_shape_list[0]
                        << ") should match the second dimension of tensor param[0](which is "
                        << input_shape_list[1] << ").";
    }
  }

  auto input_tensor = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
  (void)CheckTensorDType(input_tensor, {kFloat16, kFloat32}, "param 0 of FusedBatchNorm should be %s");
  AbstractTensorPtrList tensorPtrList = std::vector<AbstractTensorPtr>();
  for (size_t i = 1; i < args_spec_list.size(); ++i) {
    auto param = CheckArg<AbstractTensor>(op_name, args_spec_list, i);
    tensorPtrList.push_back(param);
  }
  (void)CheckTensorsDTypeSame(tensorPtrList, {kFloat16, kFloat32}, "param 1 to 4 of FusedBatchNorm should be %s");

  // Check attribute validity: epsilon must be in (0, 1], momentum in [0, 1].
  auto epsilon_value = primitive->GetAttr("epsilon");
  auto momentum_value = primitive->GetAttr("momentum");
  MS_EXCEPTION_IF_NULL(epsilon_value);
  MS_EXCEPTION_IF_NULL(momentum_value);
  if (!epsilon_value->isa<FP32Imm>() || !momentum_value->isa<FP32Imm>()) {
    MS_LOG(EXCEPTION) << "expected epsilon and momentum to be float, but: epsilon: " << epsilon_value->ToString()
                      << ", momentum: " << momentum_value->ToString();
  }
  auto epsilon = epsilon_value->cast<FP32ImmPtr>()->value();
  auto momentum = momentum_value->cast<FP32ImmPtr>()->value();
  if (epsilon > 1.0f || epsilon <= 0.0f) {
    MS_LOG(EXCEPTION) << "expected epsilon to be greater than 0 and less than or equal to 1, but epsilon: "
                      << epsilon;
  }
  if (momentum > 1.0f || momentum < 0.0f) {
    MS_LOG(EXCEPTION) << "expected momentum to be greater than or equal to 0 and less than or equal to 1, "
                         "but momentum: "
                      << momentum;
  }

  // Outputs: y, running_mean, running_variance, save_mean, save_inv_variance.
  AbstractBasePtr y = input->Broaden();
  AbstractBasePtr other = args_spec_list[1]->Broaden();
  MS_LOG(DEBUG) << "output y: " << y->ToString() << ", other: " << other->ToString();
  AbstractBasePtrList elements = {y, other, other, other, other};
  return std::make_shared<AbstractTuple>(elements);
}

AbstractBasePtr InferImplFusedBatchNormGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                            const AbstractBasePtrList &args_spec_list) {
  // Inputs: five tensors(y_backprop, x, scale, save_mean, save_inv_variance).
  MS_EXCEPTION_IF_NULL(args_spec_list[1]);
  MS_EXCEPTION_IF_NULL(args_spec_list[2]);
  MS_EXCEPTION_IF_NULL(args_spec_list[3]);
  CheckArgsSize(primitive->name(), args_spec_list, 5);
  auto dx = args_spec_list[1]->Broaden();
  auto dscale = args_spec_list[2]->Broaden();
  auto dbias = args_spec_list[3]->Broaden();
  AbstractBasePtrList rets = {dx, dscale, dbias};
  return std::make_shared<AbstractTuple>(rets);
}

AbstractBasePtr InferImplReluGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                  const AbstractBasePtrList &args_spec_list) {
  // Inputs: two tensors(y_backprop, x).
  CheckArgsSize(primitive->name(), args_spec_list, 2);
  return args_spec_list[1]->Broaden();
}

AbstractBasePtr InferImplConv2DBackpropInput(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                             const AbstractBasePtrList &args_spec_list) {
  // Inputs: three tensors(doutput, input, filters).
  CheckArgsSize(primitive->name(), args_spec_list, 3);
  return args_spec_list[1]->Broaden();
}

AbstractBasePtr InferImplConv2DBackpropFilter(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                              const AbstractBasePtrList &args_spec_list) {
  // Inputs: three tensors(inputs, filter, doutput).
  CheckArgsSize(primitive->name(), args_spec_list, 3);
  return args_spec_list[2]->Broaden();
}

AbstractBasePtr InferImplBiasAddGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                     const AbstractBasePtrList &args_spec_list) {
  // Inputs: at least one tensor(y_backprop).
  // Outputs: dbias.
  if (args_spec_list.empty()) {
    MS_LOG(EXCEPTION) << primitive->name() << " evaluator requires at least 1 parameter, while the input size is "
                      << args_spec_list.size() << ".";
  }
  MS_EXCEPTION_IF_NULL(args_spec_list[0]);
  ShapePtr shape_y = dyn_cast<Shape>(args_spec_list[0]->GetShapeTrack());
  MS_EXCEPTION_IF_NULL(shape_y);
  std::vector<int> y_dims = shape_y->shape();
  if (y_dims.size() < 2) {
    MS_LOG(EXCEPTION) << primitive->name() << " input y backprop dim should be >= 2, while " << y_dims.size() << ".";
  }
  // dbias has the length of the channel dimension y_dims[1].
  std::vector<int> bias_dims = {y_dims[1]};
  ShapePtr ret_shape = std::make_shared<Shape>(bias_dims);
  AbstractBasePtr ret = args_spec_list[0]->Broaden();
  ret->set_shape(ret_shape);
  return ret;
}

AbstractBasePtr InferImplRelu(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                              const AbstractBasePtrList &args_spec_list) {
  // Inputs: a tensor.
  CheckArgsSize(primitive->name(), args_spec_list, 1);
  return args_spec_list[0]->Broaden();
}

AbstractBasePtr InferImplZerosLike(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                   const AbstractBasePtrList &args_spec_list) {
  // Inputs: a tensor.
  CheckArgsSize(primitive->name(), args_spec_list, 1);
  return args_spec_list[0]->Broaden();
}

AbstractBasePtr InferImplFakeBprop(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                   const AbstractBasePtrList &args_spec_list) {
  // Inputs: a tensor.
  CheckArgsSize(primitive->name(), args_spec_list, 1);
  return args_spec_list[0]->Broaden();
}

AbstractBasePtr InferImplBpropCut(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                  const AbstractBasePtrList &args_spec_list) {
  // Inputs: the forward inputs followed by (out, dout); the last two are dropped
  // and the broadened forward inputs are returned as a tuple.
  AbstractBasePtrList args_list;
  for (size_t i = 0; i < args_spec_list.size() - 2; i++) {
    args_list.push_back(args_spec_list[i]->Broaden());
  }
  return std::make_shared<AbstractTuple>(args_list);
}

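// Shape inference for LayerNorm. begin_params_axis selects where gamma/beta
// start: they must match x_shape[begin_params_axis:]. begin_norm_axis selects
// the axes reduced for mean/variance, whose shape keeps the leading dims and
// sets the normalized dims to 1. Illustrative example (not from the source):
// x of shape [2, 3, 4, 5] with begin_params_axis = 2 needs gamma and beta of
// shape [4, 5]; begin_norm_axis = 1 gives mean/variance of shape [2, 1, 1, 1].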
AbstractBasePtr InferImplLayerNorm(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                   const AbstractBasePtrList &args_spec_list) {
  // Inputs: three tensors(x, gamma, beta).
  // Outputs: y, mean, variance.
  const std::string op_name = primitive->name();
  CheckArgsSize(op_name, args_spec_list, 3);
  auto input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
  auto input_shape = input_x->shape();
  auto const &input_shape_list = input_shape->shape();
  const size_t input_rank = input_shape_list.size();
  if (input_rank == 0) {
    MS_LOG(EXCEPTION) << "input_rank should not be zero";
  }

  // begin_norm_axis and begin_params_axis should be smaller than the rank of input_x and >= -1.
  ValuePtr bna_ptr = primitive->GetAttr("begin_norm_axis");
  int begin_norm_axis = CheckAxis(op_name, bna_ptr, -1, SizeToInt(input_rank) - 1);
  ValuePtr bpa_ptr = primitive->GetAttr("begin_params_axis");
  int begin_params_axis = CheckAxis(op_name, bpa_ptr, -1, SizeToInt(input_rank) - 1);
  begin_params_axis = GetPositiveAxis(begin_params_axis, input_rank);

  // The beta and gamma shape should be x_shape[begin_params_axis:].
  auto tensor = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
  auto gamma = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
  auto beta = CheckArg<AbstractTensor>(op_name, args_spec_list, 2);
  (void)CheckTensorDType(tensor, {kFloat16, kFloat32}, "input 0 of LayerNorm should be %s");
  (void)CheckTensorDType(gamma, {kFloat16, kFloat32}, "input 1 of LayerNorm should be %s");
  (void)CheckTensorDType(beta, {kFloat16, kFloat32}, "input 2 of LayerNorm should be %s");
  auto gamma_shape = dyn_cast<Shape>(gamma->BuildShape());
  auto beta_shape = dyn_cast<Shape>(beta->BuildShape());
  MS_EXCEPTION_IF_NULL(gamma_shape);
  MS_EXCEPTION_IF_NULL(beta_shape);

  auto const &gamma_shape_list = gamma_shape->shape();
  auto const &beta_shape_list = beta_shape->shape();
  if (gamma_shape_list.empty() || beta_shape_list.empty()) {
    MS_LOG(EXCEPTION) << "LayerNorm evaluator gamma or beta is an AbstractScalar, which is not supported.";
  }

  size_t begin_params_axis_u = IntToSize(begin_params_axis);
  if ((begin_params_axis_u > input_shape_list.size()) ||
      (gamma_shape_list.size() + begin_params_axis_u < input_shape_list.size()) ||
      (beta_shape_list.size() + begin_params_axis_u < input_shape_list.size())) {
    MS_LOG(EXCEPTION) << "Gamma and beta shape have the wrong size.";
  }
  for (size_t i = begin_params_axis_u; i < input_shape_list.size(); ++i) {
    size_t gamma_beta_shape_dim = i - begin_params_axis_u;
    if ((gamma_shape_list[gamma_beta_shape_dim] != input_shape_list[i]) ||
        (beta_shape_list[gamma_beta_shape_dim] != input_shape_list[i])) {
      MS_LOG(EXCEPTION) << "Gamma or beta shape does not match input shape, input_shape=" << input_shape->ToString()
                        << ", gamma_shape=" << gamma_shape->ToString() << ", beta_shape=" << beta_shape->ToString();
    }
  }

  // mean and variance keep the input shape, with the normalized axes set to 1.
  auto mean_var_shape_value = input_shape->shape();
  if (begin_norm_axis == -1) {
    mean_var_shape_value[input_rank - 1] = 1;
  } else {
    for (size_t i = IntToSize(begin_norm_axis); i < input_rank; ++i) {
      mean_var_shape_value[i] = 1;
    }
  }
  auto mean = input_x->Broaden();
  mean->set_shape(std::make_shared<Shape>(mean_var_shape_value));
  auto var = input_x->Broaden();
  var->set_shape(std::make_shared<Shape>(mean_var_shape_value));
  AbstractBasePtrList args_list({input_x->Broaden(), mean, var});
  return std::make_shared<AbstractTuple>(args_list);
}

AbstractBasePtr InferImplLayerNormGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                       const AbstractBasePtrList &args_spec_list) {
  // Inputs: five tensors(y_backprop, x, variance, mean, gamma).
  // Outputs: x_backprop, gamma_backprop, beta_backprop.
  CheckArgsSize(primitive->name(), args_spec_list, 5);
  auto x_backprop = args_spec_list[0]->Broaden();
  auto gamma_backprop = args_spec_list[4]->Broaden();
  auto beta_backprop = args_spec_list[4]->Broaden();
  AbstractBasePtrList args_list({x_backprop, gamma_backprop, beta_backprop});
  return std::make_shared<AbstractTuple>(args_list);
}

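// Shape inference for DropoutGenMask. The mask is a uint8 tensor sized so
// that every element of x gets one bit, rounded up to whole 128-bit (16-byte)
// blocks. Illustrative example (not from the source): x_shape (2, 5, 8) has
// 80 elements, which rounds up to one 128-bit block, so the mask shape is
// [16].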
AbstractBasePtr InferImplDropoutGenMask(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                        const AbstractBasePtrList &args_spec_list) {
  // Inputs: a tuple and a tensor.
  // Outputs: mask.
  const std::string op_name = primitive->name();
  CheckArgsSize(op_name, args_spec_list, 2);
  AbstractTuplePtr x_shape = CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
  AbstractTensorPtr keep_prob = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);

  TypePtr prob_type = keep_prob->element()->BuildType();
  if ((prob_type->type_id() != kNumberTypeFloat16) && (prob_type->type_id() != kNumberTypeFloat32)) {
    MS_LOG(EXCEPTION) << op_name << " keep_prob type should be float16 or float32, but " << prob_type->ToString()
                      << ".";
  }

  // Multiply out the element count, guarding against integer overflow.
  auto x_shape_data = x_shape->elements();
  int count = 1;
  for (std::size_t i = 0; i < x_shape->size(); ++i) {
    auto value_track = x_shape_data[i]->GetValueTrack();
    MS_EXCEPTION_IF_NULL(value_track);
    if (!value_track->isa<Int32Imm>()) {
      MS_LOG(EXCEPTION) << "DropOutGenMask input x_shape elements is not int32, but " << value_track->ToString()
                        << ".";
    }
    int e_value = GetValue<int>(value_track);
    if (e_value <= 0) {
      MS_LOG(EXCEPTION) << "DropOutGenMask: each element of x_shape should be > 0";
    }
    if (std::numeric_limits<int>::max() / count / e_value < 1) {
      MS_LOG(EXCEPTION) << "integer multiply integer overflow";
    }
    count = count * e_value;
  }

  // Convert to a bit mask, rounding up to whole 128-bit (16-byte) blocks.
  int n128s = count / 128;
  if ((count % 128) != 0) {
    n128s++;
  }
  int bytes_count = n128s * 16;
  std::vector<int> shape_y{bytes_count};

  primitive->set_attr("T", kInt32);
  return std::make_shared<AbstractTensor>(std::make_shared<AbstractScalar>(kAnyValue, kUInt8),
                                          std::make_shared<Shape>(shape_y));
}

}  // namespace abstract
}  // namespace mindspore