You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

pack.cc 3.4 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "ops/pack.h"
  17. namespace mindspore {
  18. namespace ops {
  19. namespace {
  20. std::vector<int64_t> _get_pack_shape(std::vector<BaseShapePtr> x_shapes, std::vector<TypePtr> x_types, int64_t axis,
  21. std::string name) {
  22. CheckAndConvertUtils::CheckInteger("len of input_x", (int64_t)x_shapes.size(), kGreaterEqual, 1, name);
  23. CheckAndConvertUtils::CheckSubClass("input_x[0]", x_types[0], {TypeIdToType(kObjectTypeTensorType)}, name);
  24. auto output_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape[0]", x_shapes[0], name);
  25. int64_t rank_base = output_shape.size();
  26. int64_t N = x_shapes.size();
  27. // CheckAndConvertUtils::CheckInRange("axis", axis, kIncludeBoth, {-rank_base-1, rank_base}, name);
  28. if (axis < 0) {
  29. axis = axis + rank_base + 1;
  30. }
  31. for (int64_t i = 1; i < N; i++) {
  32. auto type = x_types[i]->cast<TensorTypePtr>()->element();
  33. MS_EXCEPTION_IF_NULL(type);
  34. auto type0 = x_types[0]->cast<TensorTypePtr>()->element();
  35. MS_EXCEPTION_IF_NULL(type0);
  36. CheckAndConvertUtils::Check("x_type[" + std::to_string(i) + "]", type->type_id(), kEqual, "base", type0->type_id(),
  37. name);
  38. auto shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape" + std::to_string(i), x_shapes[i], name);
  39. if (shape != output_shape) {
  40. MS_EXCEPTION(ValueError) << "For '" + name + "' element " + std::to_string(i) +
  41. "shape in input can't pack with first element.";
  42. }
  43. }
  44. output_shape.insert(output_shape.begin() + axis, N);
  45. return output_shape;
  46. }
  47. } // namespace
  48. void Pack::set_axis(const int64_t &axis) { AddAttr(kAxis, MakeValue(axis)); }
  49. int64_t Pack::get_axis() const {
  50. auto value_ptr = this->GetAttr(kAxis);
  51. return GetValue<int64_t>(value_ptr);
  52. }
  53. void Pack::Init(const int64_t &axis) { this->set_axis(axis); }
  54. AbstractBasePtr PackInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
  55. const std::vector<AbstractBasePtr> &input_args) {
  56. MS_EXCEPTION_IF_NULL(primitive);
  57. auto pack_prim = primitive->cast<PrimPackPtr>();
  58. MS_EXCEPTION_IF_NULL(pack_prim);
  59. auto prim_name = pack_prim->name();
  60. auto x_shapes = input_args[0]->BuildShape()->cast<abstract::TupleShapePtr>()->shape();
  61. auto x_types = input_args[0]->BuildType()->cast<TuplePtr>()->elements();
  62. auto all_shape = _get_pack_shape(x_shapes, x_types, pack_prim->get_axis(), prim_name);
  63. auto tensor_type = x_types[0]->cast<TensorTypePtr>();
  64. MS_EXCEPTION_IF_NULL(tensor_type);
  65. auto data_type = tensor_type->element();
  66. MS_EXCEPTION_IF_NULL(data_type);
  67. return std::make_shared<abstract::AbstractTensor>(data_type, all_shape);
  68. }
  69. REGISTER_PRIMITIVE_C(kNamePack, Pack);
  70. } // namespace ops
  71. } // namespace mindspore