/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "ops/squeeze.h"

#include <algorithm>
#include <iterator>
#include <memory>
#include <vector>

namespace mindspore {
namespace ops {
// Initializes the primitive by storing `axis` as its kAxis attribute (delegates to set_axis).
void Squeeze::Init(const std::vector<int64_t> &axis) { set_axis(axis); }

// Stores the list of axes to squeeze as the kAxis attribute.
void Squeeze::set_axis(const std::vector<int64_t> &axis) { (void)AddAttr(kAxis, MakeValue(axis)); }

// Returns the kAxis attribute as a vector of axis indices.
std::vector<int64_t> Squeeze::get_axis() const { return GetValue<std::vector<int64_t>>(GetAttr(kAxis)); }

namespace {
// Computes the output shape of Squeeze from the shape of input_args[0].
//
// - If the kAxis attribute is empty, every dimension whose size is 1 is removed.
// - Otherwise each listed axis must lie in [-rank, rank - 1] (negative values count
//   from the end) and must name a dimension of size 1. Those dimensions are removed
//   and all other dimensions are kept in their original order.
//
// Precondition: input_args is non-empty and input_args[0] is non-null. SqueezeInfer
// guarantees this by running InferType (which validates input_args) first.
// Raises ValueError if a selected axis does not have size 1; out-of-range axes are
// rejected by CheckAndConvertUtils::CheckInRange.
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  auto op_name = primitive->name();
  auto axis = GetValue<std::vector<int64_t>>(primitive->GetAttr(kAxis));
  auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
  auto len = SizeToLong(in_shape.size());
  std::vector<int64_t> infer_shape;

  if (axis.empty()) {
    // No explicit axes: drop every size-1 dimension.
    (void)std::copy_if(in_shape.begin(), in_shape.end(), std::back_inserter(infer_shape),
                       [](int64_t value) { return value != 1; });
    return std::make_shared<abstract::Shape>(infer_shape);
  }

  for (const auto item : axis) {
    // Valid axes are the closed range [-len, len - 1]. An upper bound of len + 1
    // would let item == len or len + 1 through and index past the end of in_shape.
    CheckAndConvertUtils::CheckInRange("axis_or_elememt", item, kIncludeBoth, {-len, len - 1}, op_name);
    auto idx = item >= 0 ? item : len + item;
    if (in_shape[LongToSize(idx)] != 1L) {
      MS_EXCEPTION(ValueError) << "Cannot select an axis to squeeze out which has size not equal to one.";
    }
  }

  // Keep dimension i unless it was named by its non-negative index or its negative alias (i - len).
  for (int64_t i = 0; i < len; ++i) {
    const bool squeezed = std::find(axis.begin(), axis.end(), i) != axis.end() ||
                          std::find(axis.begin(), axis.end(), i - len) != axis.end();
    if (!squeezed) {
      infer_shape.push_back(in_shape[LongToSize(i)]);
    }
  }
  return std::make_shared<abstract::Shape>(infer_shape);
}

// Returns the type built from the first input (input_args[0]->BuildType()).
// Raises if input_args is empty or contains a null entry.
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
  if (input_args.empty()) {
    MS_LOG(EXCEPTION) << "Squeeze expects at least one input, but got none.";
  }
  if (std::any_of(input_args.begin(), input_args.end(),
                  [](const AbstractBasePtr &arg) { return arg == nullptr; })) {
    MS_LOG(EXCEPTION) << "nullptr";
  }
  return input_args[0]->BuildType();
}
}  // namespace

// Infers the output abstract (type and shape) of Squeeze.
AbstractBasePtr SqueezeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                             const std::vector<AbstractBasePtr> &input_args) {
  // InferType must run before InferShape: it validates input_args, and InferShape
  // dereferences input_args[0]. Function-argument evaluation order is unspecified,
  // so the two calls are kept as separate, sequenced statements.
  auto infer_type = InferType(primitive, input_args);
  auto infer_shape = InferShape(primitive, input_args);
  return std::make_shared<abstract::AbstractTensor>(infer_type, infer_shape->shape());
}
REGISTER_PRIMITIVE_C(kNameSqueeze, Squeeze);
}  // namespace ops
}  // namespace mindspore