From: @lianliguang Reviewed-by: @zh_qh,@ginfung Signed-off-by: @zh_qhpull/15283/MERGE
| @@ -53,22 +53,10 @@ AbstractBasePtr InferImplReluGrad(const AnalysisEnginePtr &, const PrimitivePtr | |||||
| const AbstractBasePtrList &args_spec_list); | const AbstractBasePtrList &args_spec_list); | ||||
| AbstractBasePtr InferImplConv2D(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr InferImplConv2D(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const AbstractBasePtrList &args_spec_list); | const AbstractBasePtrList &args_spec_list); | ||||
| AbstractBasePtr InferImplConv2DBackpropInput(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list); | |||||
| AbstractBasePtr InferImplConv2DBackpropFilter(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list); | |||||
| AbstractBasePtr InferImplBiasAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr InferImplBiasAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const AbstractBasePtrList &args_spec_list); | const AbstractBasePtrList &args_spec_list); | ||||
| AbstractBasePtr InferImplBiasAddGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr InferImplBiasAddGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const AbstractBasePtrList &args_spec_list); | const AbstractBasePtrList &args_spec_list); | ||||
| AbstractBasePtr InferImplGeLU(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list); | |||||
| AbstractBasePtr InferImplGeLUGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list); | |||||
| AbstractBasePtr InferImplFastGeLU(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list); | |||||
| AbstractBasePtr InferImplFastGeLUGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list); | |||||
| AbstractBasePtr InferImplRelu(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr InferImplRelu(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const AbstractBasePtrList &args_spec_list); | const AbstractBasePtrList &args_spec_list); | ||||
| AbstractBasePtr InferImplZerosLike(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr InferImplZerosLike(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| @@ -153,10 +141,6 @@ AbstractBasePtr InferImplListLen(const AnalysisEnginePtr &, const PrimitivePtr & | |||||
| const AbstractBasePtrList &args_spec_list); | const AbstractBasePtrList &args_spec_list); | ||||
| AbstractBasePtr InferImplArrayLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr InferImplArrayLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const AbstractBasePtrList &args_spec_list); | const AbstractBasePtrList &args_spec_list); | ||||
| AbstractBasePtr InferImplGenShapeIndex(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list); | |||||
| AbstractBasePtr InferImplGenInverseIndex(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list); | |||||
| AbstractBasePtr InferImplIdentity(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr InferImplIdentity(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const AbstractBasePtrList &args_spec_list); | const AbstractBasePtrList &args_spec_list); | ||||
| @@ -174,8 +158,6 @@ AbstractBasePtr InferImplGetRefKey(const AnalysisEnginePtr &, const PrimitivePtr | |||||
| const AbstractBasePtrList &args_spec_list); | const AbstractBasePtrList &args_spec_list); | ||||
| AbstractBasePtr InferImplGetRefValue(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr InferImplGetRefValue(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const AbstractBasePtrList &args_spec_list); | const AbstractBasePtrList &args_spec_list); | ||||
| AbstractBasePtr InferImplGetRefOrigin(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list); | |||||
| AbstractBasePtr InferImplStateSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr InferImplStateSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const AbstractBasePtrList &args_spec_list); | const AbstractBasePtrList &args_spec_list); | ||||
| AbstractBasePtr InferImplDepend(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr InferImplDepend(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| @@ -302,8 +284,6 @@ AbstractBasePtr InferImplLess(const AnalysisEnginePtr &, const PrimitivePtr &pri | |||||
| const AbstractBasePtrList &args_spec_list); | const AbstractBasePtrList &args_spec_list); | ||||
| AbstractBasePtr InferImplArgMaxWithValue(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr InferImplArgMaxWithValue(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const AbstractBasePtrList &args_spec_list); | const AbstractBasePtrList &args_spec_list); | ||||
| AbstractBasePtr InferImplSparseSoftmaxCrossEntropyWithLogits(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list); | |||||
| AbstractBasePtr InferImplLoad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr InferImplLoad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const AbstractBasePtrList &args_spec_list); | const AbstractBasePtrList &args_spec_list); | ||||
| AbstractBasePtr InferImplAssign(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr InferImplAssign(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| @@ -140,23 +140,6 @@ void FusedBatchNormCheckDim(const PrimitivePtr &primitive, const AbstractBasePtr | |||||
| } | } | ||||
| } | } | ||||
| AbstractBasePtr InferImplSparseSoftmaxCrossEntropyWithLogits(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list) { | |||||
| MS_EXCEPTION_IF_NULL(primitive); | |||||
| auto is_grad = GetValue<bool>(primitive->GetAttr("is_grad")); | |||||
| CheckArgsSize(primitive->name(), args_spec_list, 2); | |||||
| std::shared_ptr<BaseShape> shape = std::make_shared<abstract::Shape>(std::vector<int64_t>{}); | |||||
| MS_EXCEPTION_IF_NULL(args_spec_list[0]); | |||||
| if (is_grad) { | |||||
| shape = args_spec_list[0]->BuildShape(); | |||||
| } | |||||
| auto type = args_spec_list[0]->BuildType(); | |||||
| MS_EXCEPTION_IF_NULL(type); | |||||
| auto type_tensor = type->cast<TensorTypePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(type_tensor); | |||||
| return std::make_shared<abstract::AbstractTensor>(type_tensor->element(), shape); | |||||
| } | |||||
| AbstractBasePtr InferImplBatchNorm(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr InferImplBatchNorm(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const AbstractBasePtrList &args_spec_list) { | const AbstractBasePtrList &args_spec_list) { | ||||
| // Inputs: five tensors(x, gamma, beta, mean, variance). | // Inputs: five tensors(x, gamma, beta, mean, variance). | ||||
| @@ -30,9 +30,7 @@ namespace ops { | |||||
| namespace { | namespace { | ||||
| abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto abs_prim = primitive->cast<PrimAbsPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(abs_prim); | |||||
| auto prim_name = abs_prim->name(); | |||||
| auto prim_name = primitive->name(); | |||||
| for (const auto &item : input_args) { | for (const auto &item : input_args) { | ||||
| MS_EXCEPTION_IF_NULL(item); | MS_EXCEPTION_IF_NULL(item); | ||||
| } | } | ||||
| @@ -23,9 +23,7 @@ namespace ops { | |||||
| namespace { | namespace { | ||||
| abstract::AbstractBasePtr AdamInfer(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | abstract::AbstractBasePtr AdamInfer(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto Adam_prim = primitive->cast<PrimAdamPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(Adam_prim); | |||||
| auto prim_name = Adam_prim->name(); | |||||
| auto prim_name = primitive->name(); | |||||
| // infer shape | // infer shape | ||||
| auto var_shape = CheckAndConvertUtils::ConvertShapePtrToShape("var_shape", input_args[0]->GetShapeTrack(), prim_name); | auto var_shape = CheckAndConvertUtils::ConvertShapePtrToShape("var_shape", input_args[0]->GetShapeTrack(), prim_name); | ||||
| @@ -27,9 +27,7 @@ namespace ops { | |||||
| namespace { | namespace { | ||||
| abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto add_prim = primitive->cast<PrimAddPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(add_prim); | |||||
| auto prim_name = add_prim->name(); | |||||
| auto prim_name = primitive->name(); | |||||
| return BroadCastInferShape(prim_name, input_args); | return BroadCastInferShape(prim_name, input_args); | ||||
| } | } | ||||
| @@ -22,9 +22,7 @@ namespace { | |||||
| abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto prim_name = primitive->name(); | auto prim_name = primitive->name(); | ||||
| auto prim = primitive->cast<PrimArgMaxPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(prim); | |||||
| auto axis = prim->get_axis(); | |||||
| auto axis = GetValue<int64_t>(primitive->GetAttr(kAxis)); | |||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | ||||
| auto x_rank = SizeToLong(x_shape.size()); | auto x_rank = SizeToLong(x_shape.size()); | ||||
| CheckAndConvertUtils::CheckInRange<int64_t>("argmax axis", axis, kIncludeLeft, {-x_rank, x_rank}, prim_name); | CheckAndConvertUtils::CheckInRange<int64_t>("argmax axis", axis, kIncludeLeft, {-x_rank, x_rank}, prim_name); | ||||
| @@ -27,10 +27,7 @@ void ArgMin::Init(const int64_t axis, const TypeId output_type) { | |||||
| void ArgMin::set_axis(const int64_t axis) { this->AddAttr(kAxis, MakeValue(axis)); } | void ArgMin::set_axis(const int64_t axis) { this->AddAttr(kAxis, MakeValue(axis)); } | ||||
| void ArgMin::set_output_type(const TypeId output_type) { this->AddAttr(kOutputType, TypeIdToType(output_type)); } | void ArgMin::set_output_type(const TypeId output_type) { this->AddAttr(kOutputType, TypeIdToType(output_type)); } | ||||
| int64_t ArgMin::get_axis() const { | |||||
| auto value_ptr = GetAttr(kAxis); | |||||
| return GetValue<int64_t>(value_ptr); | |||||
| } | |||||
| int64_t ArgMin::get_axis() const { return GetValue<int64_t>(GetAttr(kAxis)); } | |||||
| TypeId ArgMin::get_output_type() const { | TypeId ArgMin::get_output_type() const { | ||||
| auto type_ptr = GetAttr(kOutputType)->cast<TensorTypePtr>()->element(); | auto type_ptr = GetAttr(kOutputType)->cast<TensorTypePtr>()->element(); | ||||
| @@ -40,13 +37,11 @@ TypeId ArgMin::get_output_type() const { | |||||
| AbstractBasePtr ArgMinInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr ArgMinInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const std::vector<AbstractBasePtr> &input_args) { | const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto argmin_prim = primitive->cast<PrimArgMin>(); | |||||
| MS_EXCEPTION_IF_NULL(argmin_prim); | |||||
| auto prim_name = argmin_prim->name(); | |||||
| auto prim_name = primitive->name(); | |||||
| CheckAndConvertUtils::CheckInteger("arg_min_infer", input_args.size(), kEqual, 1, prim_name); | CheckAndConvertUtils::CheckInteger("arg_min_infer", input_args.size(), kEqual, 1, prim_name); | ||||
| // Infer shape | // Infer shape | ||||
| auto axis = argmin_prim->get_axis(); | |||||
| auto axis = GetValue<int64_t>(primitive->GetAttr(kAxis)); | |||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | ||||
| auto x_rank = SizeToLong(x_shape.size()); | auto x_rank = SizeToLong(x_shape.size()); | ||||
| CheckAndConvertUtils::CheckInRange<int64_t>("axis", axis, kIncludeLeft, {-x_rank, x_rank}, prim_name); | CheckAndConvertUtils::CheckInRange<int64_t>("axis", axis, kIncludeLeft, {-x_rank, x_rank}, prim_name); | ||||
| @@ -25,9 +25,7 @@ namespace ops { | |||||
| AbstractBasePtr AsinInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr AsinInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const std::vector<AbstractBasePtr> &input_args) { | const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto asin_prim = primitive->cast<PrimAsinPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(asin_prim); | |||||
| auto prim_name = asin_prim->name(); | |||||
| auto prim_name = primitive->name(); | |||||
| CheckAndConvertUtils::CheckInteger("Asin_infer", input_args.size(), kEqual, 1, prim_name); | CheckAndConvertUtils::CheckInteger("Asin_infer", input_args.size(), kEqual, 1, prim_name); | ||||
| // Infer Shape | // Infer Shape | ||||
| @@ -37,9 +37,7 @@ int64_t Assert::get_summarize() const { | |||||
| AbstractBasePtr AssertInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr AssertInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const std::vector<AbstractBasePtr> &input_args) { | const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto Assert_prim = primitive->cast<PrimAssertPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(Assert_prim); | |||||
| auto op_name = Assert_prim->name(); | |||||
| auto op_name = primitive->name(); | |||||
| TypePtr condition; | TypePtr condition; | ||||
| if (!(input_args[0]->BuildType()->type_id() == kObjectTypeTensorType)) { | if (!(input_args[0]->BuildType()->type_id() == kObjectTypeTensorType)) { | ||||
| auto condition_value = GetValue<std::vector<bool>>(input_args[0]->BuildValue()); | auto condition_value = GetValue<std::vector<bool>>(input_args[0]->BuildValue()); | ||||
| @@ -25,9 +25,7 @@ namespace ops { | |||||
| namespace { | namespace { | ||||
| abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto assignadd_prim = primitive->cast<PrimAssignAddPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(assignadd_prim); | |||||
| auto prim_name = assignadd_prim->name(); | |||||
| auto prim_name = primitive->name(); | |||||
| auto value_shape = | auto value_shape = | ||||
| CheckAndConvertUtils::ConvertShapePtrToShape("value_shape", input_args[1]->BuildShape(), prim_name); | CheckAndConvertUtils::ConvertShapePtrToShape("value_shape", input_args[1]->BuildShape(), prim_name); | ||||
| return std::make_shared<abstract::Shape>(value_shape); | return std::make_shared<abstract::Shape>(value_shape); | ||||
| @@ -23,9 +23,7 @@ namespace ops { | |||||
| AbstractBasePtr AtanInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr AtanInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const std::vector<AbstractBasePtr> &input_args) { | const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto atan_prim = primitive->cast<PrimAtanPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(atan_prim); | |||||
| auto prim_name = atan_prim->name(); | |||||
| auto prim_name = primitive->name(); | |||||
| CheckAndConvertUtils::CheckInteger("Atan_infer", input_args.size(), kEqual, 1, prim_name); | CheckAndConvertUtils::CheckInteger("Atan_infer", input_args.size(), kEqual, 1, prim_name); | ||||
| // Infer Shape | // Infer Shape | ||||
| @@ -30,25 +30,25 @@ namespace { | |||||
| abstract::ShapePtr AudioSpectrogramInferShape(const PrimitivePtr &primitive, | abstract::ShapePtr AudioSpectrogramInferShape(const PrimitivePtr &primitive, | ||||
| const std::vector<AbstractBasePtr> &input_args) { | const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto audio_spectrogram_prim = primitive->cast<PrimAudioSpectrogramPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(audio_spectrogram_prim); | |||||
| auto prim_name = audio_spectrogram_prim->name(); | |||||
| auto prim_name = primitive->name(); | |||||
| auto input_shape = | auto input_shape = | ||||
| CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), prim_name); | CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), prim_name); | ||||
| if (input_shape.size() != 2) { | if (input_shape.size() != 2) { | ||||
| MS_LOG(ERROR) << "input shape is error, which need to be 2 dimensions"; | MS_LOG(ERROR) << "input shape is error, which need to be 2 dimensions"; | ||||
| } | } | ||||
| if (audio_spectrogram_prim->get_window_size() < 2) { | |||||
| MS_LOG(ERROR) << "window size is too short, now is " << audio_spectrogram_prim->get_window_size(); | |||||
| auto window_size = GetValue<int64_t>(primitive->GetAttr(kWindowSize)); | |||||
| if (window_size < 2) { | |||||
| MS_LOG(ERROR) << "window size is too short, now is " << window_size; | |||||
| } | } | ||||
| if (audio_spectrogram_prim->get_stride() < 1) { | |||||
| MS_LOG(ERROR) << "stride must be positive, now is " << audio_spectrogram_prim->get_stride(); | |||||
| auto stride_size = GetValue<int64_t>(primitive->GetAttr(kStride)); | |||||
| if (stride_size < 1) { | |||||
| MS_LOG(ERROR) << "stride must be positive, now is " << stride_size; | |||||
| } | } | ||||
| std::vector<int64_t> infer_shape; | std::vector<int64_t> infer_shape; | ||||
| infer_shape.push_back(input_shape[1]); | infer_shape.push_back(input_shape[1]); | ||||
| int64_t sample_sub_window = input_shape[0] - audio_spectrogram_prim->get_window_size(); | |||||
| infer_shape.push_back(sample_sub_window < 0 ? 0 : 1 + sample_sub_window / audio_spectrogram_prim->get_stride()); | |||||
| int64_t fft_length = audio_spectrogram_prim->GetFftLength(audio_spectrogram_prim->get_window_size()); | |||||
| int64_t sample_sub_window = input_shape[0] - window_size; | |||||
| infer_shape.push_back(sample_sub_window < 0 ? 0 : 1 + sample_sub_window / stride_size); | |||||
| int64_t fft_length = GetFftLength(window_size); | |||||
| infer_shape.push_back(fft_length / 2 + 1); | infer_shape.push_back(fft_length / 2 + 1); | ||||
| MS_LOG(ERROR) << infer_shape; | MS_LOG(ERROR) << infer_shape; | ||||
| return std::make_shared<abstract::Shape>(infer_shape); | return std::make_shared<abstract::Shape>(infer_shape); | ||||
| @@ -81,7 +81,7 @@ int64_t AudioSpectrogram::get_stride() const { | |||||
| return GetValue<int64_t>(value_ptr); | return GetValue<int64_t>(value_ptr); | ||||
| } | } | ||||
| int64_t AudioSpectrogram::Log2Ceil(int64_t length) { | |||||
| int64_t Log2Ceil(int64_t length) { | |||||
| if (length == 0) { | if (length == 0) { | ||||
| return -1; | return -1; | ||||
| } | } | ||||
| @@ -97,7 +97,7 @@ int64_t AudioSpectrogram::Log2Ceil(int64_t length) { | |||||
| return length == (length & ~(unsigned int)(length - 1)) ? floor : floor + 1; | return length == (length & ~(unsigned int)(length - 1)) ? floor : floor + 1; | ||||
| } | } | ||||
| int64_t AudioSpectrogram::GetFftLength(int64_t length) { | |||||
| int64_t GetFftLength(int64_t length) { | |||||
| int64_t shift = Log2Ceil(length); | int64_t shift = Log2Ceil(length); | ||||
| return 1 << (unsigned int)shift; | return 1 << (unsigned int)shift; | ||||
| } | } | ||||
| @@ -27,6 +27,8 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ops { | namespace ops { | ||||
| constexpr auto kNameAudioSpectrogram = "AudioSpectrogram"; | constexpr auto kNameAudioSpectrogram = "AudioSpectrogram"; | ||||
| int64_t Log2Ceil(int64_t length); | |||||
| int64_t GetFftLength(int64_t length); | |||||
| class AudioSpectrogram : public PrimitiveC { | class AudioSpectrogram : public PrimitiveC { | ||||
| public: | public: | ||||
| AudioSpectrogram() : PrimitiveC(kNameAudioSpectrogram) {} | AudioSpectrogram() : PrimitiveC(kNameAudioSpectrogram) {} | ||||
| @@ -39,8 +41,6 @@ class AudioSpectrogram : public PrimitiveC { | |||||
| int64_t get_window_size() const; | int64_t get_window_size() const; | ||||
| int64_t get_stride() const; | int64_t get_stride() const; | ||||
| bool get_mag_square() const; | bool get_mag_square() const; | ||||
| int64_t Log2Ceil(int64_t length); | |||||
| int64_t GetFftLength(int64_t length); | |||||
| }; | }; | ||||
| AbstractBasePtr AudioSpectrogramInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr AudioSpectrogramInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const std::vector<AbstractBasePtr> &input_args); | const std::vector<AbstractBasePtr> &input_args); | ||||
| @@ -31,37 +31,25 @@ void AvgPool::set_pad_mode(const PadMode &pad_mode) { | |||||
| this->AddAttr(kPadMode, MakeValue(swi)); | this->AddAttr(kPadMode, MakeValue(swi)); | ||||
| } | } | ||||
| PadMode AvgPool::get_pad_mode() const { | |||||
| auto value_ptr = GetAttr(kPadMode); | |||||
| return PadMode(GetValue<int64_t>(value_ptr)); | |||||
| } | |||||
| PadMode AvgPool::get_pad_mode() const { return PadMode(GetValue<int64_t>(GetAttr(kPadMode))); } | |||||
| void AvgPool::set_kernel_size(const std::vector<int64_t> &kernel_size) { | void AvgPool::set_kernel_size(const std::vector<int64_t> &kernel_size) { | ||||
| this->AddAttr(kKernelSize, | this->AddAttr(kKernelSize, | ||||
| MakeValue(CheckAndConvertUtils::CheckPositiveVector(kKernelSize, kernel_size, this->name()))); | MakeValue(CheckAndConvertUtils::CheckPositiveVector(kKernelSize, kernel_size, this->name()))); | ||||
| } | } | ||||
| std::vector<int64_t> AvgPool::get_kernel_size() const { | |||||
| auto value_ptr = GetAttr(kKernelSize); | |||||
| return GetValue<std::vector<int64_t>>(value_ptr); | |||||
| } | |||||
| std::vector<int64_t> AvgPool::get_kernel_size() const { return GetValue<std::vector<int64_t>>(GetAttr(kKernelSize)); } | |||||
| void AvgPool::set_strides(const std::vector<int64_t> &strides) { | void AvgPool::set_strides(const std::vector<int64_t> &strides) { | ||||
| this->AddAttr(kStrides, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kStrides, strides, this->name()))); | this->AddAttr(kStrides, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kStrides, strides, this->name()))); | ||||
| } | } | ||||
| std::vector<int64_t> AvgPool::get_strides() const { | |||||
| auto value_ptr = GetAttr(kStrides); | |||||
| return GetValue<std::vector<int64_t>>(value_ptr); | |||||
| } | |||||
| std::vector<int64_t> AvgPool::get_strides() const { return GetValue<std::vector<int64_t>>(GetAttr(kStrides)); } | |||||
| void AvgPool::set_format(const Format &format) { | void AvgPool::set_format(const Format &format) { | ||||
| int64_t f = format; | int64_t f = format; | ||||
| this->AddAttr(kFormat, MakeValue(f)); | this->AddAttr(kFormat, MakeValue(f)); | ||||
| } | } | ||||
| Format AvgPool::get_format() const { | |||||
| auto value_ptr = GetAttr(kFormat); | |||||
| return Format(GetValue<int64_t>(value_ptr)); | |||||
| } | |||||
| Format AvgPool::get_format() const { return Format(GetValue<int64_t>(GetAttr(kFormat))); } | |||||
| void AvgPool::set_pad(const std::vector<int64_t> &pad) { this->AddAttr(kPad, MakeValue(pad)); } | void AvgPool::set_pad(const std::vector<int64_t> &pad) { this->AddAttr(kPad, MakeValue(pad)); } | ||||
| @@ -93,22 +81,20 @@ void AvgPool::Init(const std::vector<int64_t> &kernel_size, const std::vector<in | |||||
| namespace { | namespace { | ||||
| abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto pool_prim = primitive->cast<PrimAvgPoolPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(pool_prim); | |||||
| auto op_name = pool_prim->name(); | |||||
| auto op_name = primitive->name(); | |||||
| auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), op_name); | auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), op_name); | ||||
| if (pool_prim->get_format() == NHWC) { | |||||
| auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat))); | |||||
| if (format == NHWC) { | |||||
| in_shape = {in_shape[0], in_shape[3], in_shape[1], in_shape[2]}; | in_shape = {in_shape[0], in_shape[3], in_shape[1], in_shape[2]}; | ||||
| } | } | ||||
| CheckAndConvertUtils::CheckInteger("x_rank", in_shape.size(), kEqual, 4, op_name); | CheckAndConvertUtils::CheckInteger("x_rank", in_shape.size(), kEqual, 4, op_name); | ||||
| auto kernel_size = pool_prim->get_kernel_size(); | |||||
| auto pad_mode = pool_prim->get_pad_mode(); | |||||
| auto kernel_size = GetValue<std::vector<int64_t>>(primitive->GetAttr(kKernelSize)); | |||||
| auto pad_mode = PadMode(GetValue<int64_t>(primitive->GetAttr(kPadMode))); | |||||
| auto batch = in_shape[0]; | auto batch = in_shape[0]; | ||||
| auto channel = in_shape[1]; | auto channel = in_shape[1]; | ||||
| auto in_h = in_shape[2]; | auto in_h = in_shape[2]; | ||||
| auto in_w = in_shape[3]; | auto in_w = in_shape[3]; | ||||
| auto strides = pool_prim->get_strides(); | |||||
| auto strides = GetValue<std::vector<int64_t>>(primitive->GetAttr(kStrides)); | |||||
| auto kernel_h = kernel_size[2]; | auto kernel_h = kernel_size[2]; | ||||
| auto kernel_w = kernel_size[3]; | auto kernel_w = kernel_size[3]; | ||||
| auto stride_h = strides[2]; | auto stride_h = strides[2]; | ||||
| @@ -123,7 +109,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A | |||||
| out_w = ceil(in_w / stride_w); | out_w = ceil(in_w / stride_w); | ||||
| } | } | ||||
| std::vector<int64_t> out_shape = {batch, channel, out_h, out_w}; | std::vector<int64_t> out_shape = {batch, channel, out_h, out_w}; | ||||
| if (pool_prim->get_format() == NHWC) { | |||||
| if (format == NHWC) { | |||||
| out_shape = {batch, out_h, out_w, channel}; | out_shape = {batch, out_h, out_w, channel}; | ||||
| } | } | ||||
| if (std::any_of(out_shape.begin(), out_shape.end(), [](int64_t a) { return a <= 0; })) { | if (std::any_of(out_shape.begin(), out_shape.end(), [](int64_t a) { return a <= 0; })) { | ||||
| @@ -72,13 +72,12 @@ AbstractBasePtr BatchNormInfer(const abstract::AnalysisEnginePtr &, const Primit | |||||
| const std::vector<AbstractBasePtr> &input_args) { | const std::vector<AbstractBasePtr> &input_args) { | ||||
| // Infer shape | // Infer shape | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto batch_prim = primitive->cast<PrimBatchNormPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(batch_prim); | |||||
| auto prim_name = batch_prim->name(); | |||||
| auto prim_name = primitive->name(); | |||||
| CheckAndConvertUtils::CheckInteger("batch_norm_infer", input_args.size(), kEqual, 5, prim_name); | CheckAndConvertUtils::CheckInteger("batch_norm_infer", input_args.size(), kEqual, 5, prim_name); | ||||
| auto input_x = CheckAndConvertUtils::ConvertShapePtrToShape("input_x", input_args[0]->BuildShape(), prim_name); | auto input_x = CheckAndConvertUtils::ConvertShapePtrToShape("input_x", input_args[0]->BuildShape(), prim_name); | ||||
| if (batch_prim->get_format() == NHWC) { | |||||
| auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat))); | |||||
| if (format == NHWC) { | |||||
| input_x = {input_x[0], input_x[3], input_x[1], input_x[2]}; | input_x = {input_x[0], input_x[3], input_x[1], input_x[2]}; | ||||
| } | } | ||||
| auto scale = CheckAndConvertUtils::ConvertShapePtrToShape("scale", input_args[1]->BuildShape(), prim_name); | auto scale = CheckAndConvertUtils::ConvertShapePtrToShape("scale", input_args[1]->BuildShape(), prim_name); | ||||
| @@ -87,7 +86,7 @@ AbstractBasePtr BatchNormInfer(const abstract::AnalysisEnginePtr &, const Primit | |||||
| auto variance = CheckAndConvertUtils::ConvertShapePtrToShape("variance", input_args[4]->BuildShape(), prim_name); | auto variance = CheckAndConvertUtils::ConvertShapePtrToShape("variance", input_args[4]->BuildShape(), prim_name); | ||||
| std::vector<int64_t> input_shape_norm; | std::vector<int64_t> input_shape_norm; | ||||
| if (batch_prim->get_format() == NCHW) { | |||||
| if (format == NCHW) { | |||||
| input_shape_norm = | input_shape_norm = | ||||
| CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), prim_name); | CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), prim_name); | ||||
| } else { | } else { | ||||
| @@ -100,7 +99,8 @@ AbstractBasePtr BatchNormInfer(const abstract::AnalysisEnginePtr &, const Primit | |||||
| CheckAndConvertUtils::Check("scale shape", scale, kEqual, "bias shape", bias, prim_name, TypeError); | CheckAndConvertUtils::Check("scale shape", scale, kEqual, "bias shape", bias, prim_name, TypeError); | ||||
| CheckAndConvertUtils::Check("scale shape[0]", scale[0], kEqual, "input_x channel", input_shape_norm[1], prim_name, | CheckAndConvertUtils::Check("scale shape[0]", scale[0], kEqual, "input_x channel", input_shape_norm[1], prim_name, | ||||
| TypeError); | TypeError); | ||||
| if (!batch_prim->get_is_training()) { | |||||
| if (!GetValue<bool>(primitive->GetAttr(kIsTraining))) { | |||||
| CheckAndConvertUtils::CheckInteger("mean rank", mean.size(), kEqual, 1, prim_name); | CheckAndConvertUtils::CheckInteger("mean rank", mean.size(), kEqual, 1, prim_name); | ||||
| CheckAndConvertUtils::Check("mean shape", mean, kEqual, "variance shape", variance, prim_name, TypeError); | CheckAndConvertUtils::Check("mean shape", mean, kEqual, "variance shape", variance, prim_name, TypeError); | ||||
| CheckAndConvertUtils::Check("mean shape", mean, kEqual, "scale shape", scale, prim_name, TypeError); | CheckAndConvertUtils::Check("mean shape", mean, kEqual, "scale shape", scale, prim_name, TypeError); | ||||
| @@ -126,7 +126,7 @@ AbstractBasePtr BatchNormInfer(const abstract::AnalysisEnginePtr &, const Primit | |||||
| auto output1 = std::make_shared<abstract::AbstractTensor>(scale_type, scale); | auto output1 = std::make_shared<abstract::AbstractTensor>(scale_type, scale); | ||||
| auto output2 = std::make_shared<abstract::AbstractTensor>(bias_type, scale); | auto output2 = std::make_shared<abstract::AbstractTensor>(bias_type, scale); | ||||
| auto output3 = std::make_shared<abstract::AbstractTensor>(input_x_type, scale); | auto output3 = std::make_shared<abstract::AbstractTensor>(input_x_type, scale); | ||||
| if (batch_prim->get_format() == NHWC) { | |||||
| if (format == NHWC) { | |||||
| output2 = std::make_shared<abstract::AbstractTensor>(scale_type, scale); | output2 = std::make_shared<abstract::AbstractTensor>(scale_type, scale); | ||||
| output3 = std::make_shared<abstract::AbstractTensor>(bias_type, scale); | output3 = std::make_shared<abstract::AbstractTensor>(bias_type, scale); | ||||
| output1 = std::make_shared<abstract::AbstractTensor>(input_x_type, scale); | output1 = std::make_shared<abstract::AbstractTensor>(input_x_type, scale); | ||||
| @@ -67,9 +67,7 @@ int64_t BatchNormFold::get_freeze_bn() const { | |||||
| AbstractBasePtr BatchNormFoldInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr BatchNormFoldInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const std::vector<AbstractBasePtr> &input_args) { | const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto BatchNormFold_prim = primitive->cast<PrimBatchNormFoldPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(BatchNormFold_prim); | |||||
| auto op_name = BatchNormFold_prim->name(); | |||||
| auto op_name = primitive->name(); | |||||
| auto mean_shape = CheckAndConvertUtils::ConvertShapePtrToShape("mean_shape", input_args[1]->BuildShape(), op_name); | auto mean_shape = CheckAndConvertUtils::ConvertShapePtrToShape("mean_shape", input_args[1]->BuildShape(), op_name); | ||||
| auto variance_shape = | auto variance_shape = | ||||
| CheckAndConvertUtils::ConvertShapePtrToShape("variance_shape", input_args[2]->BuildShape(), op_name); | CheckAndConvertUtils::ConvertShapePtrToShape("variance_shape", input_args[2]->BuildShape(), op_name); | ||||
| @@ -47,9 +47,7 @@ std::vector<std::vector<int64_t>> BatchToSpace::get_crops() const { | |||||
| AbstractBasePtr BatchToSpaceInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr BatchToSpaceInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const std::vector<AbstractBasePtr> &input_args) { | const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto prim = primitive->cast<PrimBatchToSpacePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(prim); | |||||
| auto prim_name = prim->name(); | |||||
| auto prim_name = primitive->name(); | |||||
| CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 1, prim_name); | CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 1, prim_name); | ||||
| for (const auto &item : input_args) { | for (const auto &item : input_args) { | ||||
| MS_EXCEPTION_IF_NULL(item); | MS_EXCEPTION_IF_NULL(item); | ||||
| @@ -59,8 +57,8 @@ AbstractBasePtr BatchToSpaceInfer(const abstract::AnalysisEnginePtr &, const Pri | |||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | ||||
| CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kEqual, 4, prim_name); | CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kEqual, 4, prim_name); | ||||
| auto block_size = prim->get_block_size(); | |||||
| auto crops = prim->get_crops(); | |||||
| auto block_size = GetValue<std::vector<int64_t>>(primitive->GetAttr(kBlockSize)); | |||||
| auto crops = GetValue<std::vector<std::vector<int64_t>>>(primitive->GetAttr(kCrops)); | |||||
| auto out_shape = x_shape; | auto out_shape = x_shape; | ||||
| for (size_t i = 0; i < 2; ++i) { | for (size_t i = 0; i < 2; ++i) { | ||||
| auto x_block_prod = out_shape[i + 2] * block_size[i]; | auto x_block_prod = out_shape[i + 2] * block_size[i]; | ||||
| @@ -28,16 +28,14 @@ namespace ops { | |||||
| namespace { | namespace { | ||||
| abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto batch_prim = primitive->cast<PrimBatchToSpaceNDPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(batch_prim); | |||||
| auto prim_name = batch_prim->name(); | |||||
| auto prim_name = primitive->name(); | |||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | ||||
| CheckAndConvertUtils::CheckInteger("input_x rank", x_shape.size(), kEqual, 4, prim_name); | CheckAndConvertUtils::CheckInteger("input_x rank", x_shape.size(), kEqual, 4, prim_name); | ||||
| auto out_shape = x_shape; | auto out_shape = x_shape; | ||||
| int64_t block_shape_prod = 1; | int64_t block_shape_prod = 1; | ||||
| int64_t offset = 2; | int64_t offset = 2; | ||||
| auto block_shape = batch_prim->get_block_shape(); | |||||
| auto crops = batch_prim->get_crops(); | |||||
| auto block_shape = GetValue<std::vector<int64_t>>(primitive->GetAttr(kBlockShape)); | |||||
| auto crops = GetValue<std::vector<std::vector<int64_t>>>(primitive->GetAttr(kCrops)); | |||||
| int64_t size = block_shape.size(); | int64_t size = block_shape.size(); | ||||
| for (int64_t i = 0; i < size; i++) { | for (int64_t i = 0; i < size; i++) { | ||||
| block_shape_prod = block_shape_prod * block_shape[i]; | block_shape_prod = block_shape_prod * block_shape[i]; | ||||
| @@ -32,9 +32,7 @@ namespace { | |||||
| abstract::ShapePtr BinaryCrossEntroyInferShape(const PrimitivePtr &primitive, | abstract::ShapePtr BinaryCrossEntroyInferShape(const PrimitivePtr &primitive, | ||||
| const std::vector<AbstractBasePtr> &input_args) { | const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto binary_cross_entropy_prim = primitive->cast<PrimBinaryCrossEntropyPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(binary_cross_entropy_prim); | |||||
| auto prim_name = binary_cross_entropy_prim->name(); | |||||
| auto prim_name = primitive->name(); | |||||
| CheckAndConvertUtils::CheckInRange("binary_cross_entropy_infer", input_args.size(), kIncludeBoth, {2, 3}, prim_name); | CheckAndConvertUtils::CheckInRange("binary_cross_entropy_infer", input_args.size(), kIncludeBoth, {2, 3}, prim_name); | ||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | ||||
| auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShape("y_shape", input_args[1]->BuildShape(), prim_name); | auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShape("y_shape", input_args[1]->BuildShape(), prim_name); | ||||
| @@ -45,8 +43,8 @@ abstract::ShapePtr BinaryCrossEntroyInferShape(const PrimitivePtr &primitive, | |||||
| if (weight_shape.size() < 1) { | if (weight_shape.size() < 1) { | ||||
| CheckAndConvertUtils::Check("x shape", y_shape, kEqual, "weight shape", weight_shape, prim_name); | CheckAndConvertUtils::Check("x shape", y_shape, kEqual, "weight shape", weight_shape, prim_name); | ||||
| } | } | ||||
| if (binary_cross_entropy_prim->get_reduction() != REDUCTION_SUM && | |||||
| binary_cross_entropy_prim->get_reduction() != MEAN) { | |||||
| auto reduction = Reduction(GetValue<int64_t>(primitive->GetAttr(kReduction))); | |||||
| if (reduction != REDUCTION_SUM && reduction != MEAN) { | |||||
| infer_shape = {x_shape.begin(), infer_shape.end()}; | infer_shape = {x_shape.begin(), infer_shape.end()}; | ||||
| } | } | ||||
| return std::make_shared<abstract::Shape>(infer_shape); | return std::make_shared<abstract::Shape>(infer_shape); | ||||
| @@ -45,9 +45,7 @@ std::string Broadcast::get_group() const { | |||||
| AbstractBasePtr BroadcastInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr BroadcastInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const std::vector<AbstractBasePtr> &input_args) { | const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto broadcast_prim = primitive->cast<PrimBroadcast>(); | |||||
| MS_EXCEPTION_IF_NULL(broadcast_prim); | |||||
| auto prim_name = broadcast_prim->name(); | |||||
| auto prim_name = primitive->name(); | |||||
| for (const auto &item : input_args) { | for (const auto &item : input_args) { | ||||
| MS_EXCEPTION_IF_NULL(item); | MS_EXCEPTION_IF_NULL(item); | ||||
| } | } | ||||
| @@ -32,9 +32,7 @@ void Concat::set_axis(const int64_t axis) { this->AddAttr(kAxis, MakeValue(axis) | |||||
| AbstractBasePtr ConcatInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr ConcatInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const std::vector<AbstractBasePtr> &input_args) { | const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto prim = primitive->cast<PrimConcatPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(prim); | |||||
| auto prim_name = prim->name(); | |||||
| auto prim_name = primitive->name(); | |||||
| CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 1, prim_name); | CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 1, prim_name); | ||||
| for (const auto &item : input_args) { | for (const auto &item : input_args) { | ||||
| MS_EXCEPTION_IF_NULL(item); | MS_EXCEPTION_IF_NULL(item); | ||||
| @@ -48,7 +46,7 @@ AbstractBasePtr ConcatInfer(const abstract::AnalysisEnginePtr &, const Primitive | |||||
| auto element0_shape = | auto element0_shape = | ||||
| CheckAndConvertUtils::ConvertShapePtrToShape("element0 shape", element0->BuildShape(), prim_name); | CheckAndConvertUtils::ConvertShapePtrToShape("element0 shape", element0->BuildShape(), prim_name); | ||||
| auto element0_rank = SizeToLong(element0_shape.size()); | auto element0_rank = SizeToLong(element0_shape.size()); | ||||
| auto axis = prim->get_axis(); | |||||
| auto axis = GetValue<int64_t>(primitive->GetAttr(kAxis)); | |||||
| CheckAndConvertUtils::CheckInRange<int64_t>("Concat axis", axis, kIncludeBoth, {-element0_rank - 1, element0_rank}, | CheckAndConvertUtils::CheckInRange<int64_t>("Concat axis", axis, kIncludeBoth, {-element0_rank - 1, element0_rank}, | ||||
| prim_name); | prim_name); | ||||
| axis = axis < 0 ? axis + element0_rank : axis; | axis = axis < 0 ? axis + element0_rank : axis; | ||||
| @@ -31,9 +31,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A | |||||
| TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto constant_prim = primitive->cast<PrimConstantOfShapePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(constant_prim); | |||||
| auto data_type = TypeId(constant_prim->get_data_type()); | |||||
| auto data_type = TypeId(GetValue<int64_t>(primitive->GetAttr(kDataType))); | |||||
| return TypeIdToType(data_type); | return TypeIdToType(data_type); | ||||
| } | } | ||||
| } // namespace | } // namespace | ||||
| @@ -28,9 +28,7 @@ namespace { | |||||
| abstract::ShapePtr Conv2dTransposeInferShape(const PrimitivePtr &primitive, | abstract::ShapePtr Conv2dTransposeInferShape(const PrimitivePtr &primitive, | ||||
| const std::vector<AbstractBasePtr> &input_args) { | const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto conv2d_transpose_prim = primitive->cast<PrimConv2dTransposePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(conv2d_transpose_prim); | |||||
| auto prim_name = conv2d_transpose_prim->name(); | |||||
| auto prim_name = primitive->name(); | |||||
| auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[3]->BuildShape(), prim_name); | auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[3]->BuildShape(), prim_name); | ||||
| return std::make_shared<abstract::Shape>(input_shape); | return std::make_shared<abstract::Shape>(input_shape); | ||||
| } | } | ||||
| @@ -24,9 +24,7 @@ namespace ops { | |||||
| namespace { | namespace { | ||||
| abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto cos_prim = primitive->cast<PrimCos>(); | |||||
| MS_EXCEPTION_IF_NULL(cos_prim); | |||||
| auto prim_name = cos_prim->name(); | |||||
| auto prim_name = primitive->name(); | |||||
| for (const auto &item : input_args) { | for (const auto &item : input_args) { | ||||
| MS_EXCEPTION_IF_NULL(item); | MS_EXCEPTION_IF_NULL(item); | ||||
| } | } | ||||
| @@ -43,9 +43,7 @@ std::vector<int64_t> Crop::get_offsets() const { | |||||
| AbstractBasePtr CropInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr CropInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const std::vector<AbstractBasePtr> &input_args) { | const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto crop_prim = primitive->cast<PrimCrop>(); | |||||
| MS_EXCEPTION_IF_NULL(crop_prim); | |||||
| auto prim_name = crop_prim->name(); | |||||
| auto prim_name = primitive->name(); | |||||
| CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 2, prim_name); | CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 2, prim_name); | ||||
| for (const auto &item : input_args) { | for (const auto &item : input_args) { | ||||
| MS_EXCEPTION_IF_NULL(item); | MS_EXCEPTION_IF_NULL(item); | ||||
| @@ -24,9 +24,7 @@ namespace ops { | |||||
| AbstractBasePtr CustomExtractFeaturesInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr CustomExtractFeaturesInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const std::vector<AbstractBasePtr> &input_args) { | const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto extract_prim = primitive->cast<PrimCustomExtractFeaturesPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(extract_prim); | |||||
| auto prim_name = extract_prim->name(); | |||||
| auto prim_name = primitive->name(); | |||||
| MS_EXCEPTION_IF_NULL(input_args[0]); | MS_EXCEPTION_IF_NULL(input_args[0]); | ||||
| // auto input = input_args[0]; | // auto input = input_args[0]; | ||||
| @@ -24,13 +24,8 @@ namespace { | |||||
| abstract::ShapePtr CustomNormalizeInferShape(const PrimitivePtr &primitive, | abstract::ShapePtr CustomNormalizeInferShape(const PrimitivePtr &primitive, | ||||
| const std::vector<AbstractBasePtr> &input_args) { | const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto custom_normalize_prim = primitive->cast<PrimCustomNormalizePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(custom_normalize_prim); | |||||
| // auto prim_name = custom_normalize_prim->name(); | |||||
| MS_EXCEPTION_IF_NULL(input_args[0]); | MS_EXCEPTION_IF_NULL(input_args[0]); | ||||
| MS_EXCEPTION_IF_NULL(input_args[0]->BuildShape()); | MS_EXCEPTION_IF_NULL(input_args[0]->BuildShape()); | ||||
| // auto input_shape = | |||||
| // CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), prim_name); | |||||
| if (input_args[0]->BuildValue()->cast<tensor::TensorPtr>()->data_c() == nullptr) { | if (input_args[0]->BuildValue()->cast<tensor::TensorPtr>()->data_c() == nullptr) { | ||||
| MS_LOG(ERROR) << "Do infer shape in runtime."; | MS_LOG(ERROR) << "Do infer shape in runtime."; | ||||
| } | } | ||||
| @@ -45,13 +45,11 @@ float CustomPredict::get_weight_threshold() const { | |||||
| AbstractBasePtr CustomPredictInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr CustomPredictInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const std::vector<AbstractBasePtr> &input_args) { | const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto CustomPredict_prim = primitive->cast<PrimCustomPredictPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(CustomPredict_prim); | |||||
| for (const auto &input : input_args) { | for (const auto &input : input_args) { | ||||
| MS_EXCEPTION_IF_NULL(input); | MS_EXCEPTION_IF_NULL(input); | ||||
| } | } | ||||
| std::vector<int64_t> shape; | std::vector<int64_t> shape; | ||||
| shape.push_back(CustomPredict_prim->get_output_num()); | |||||
| shape.push_back(GetValue<int64_t>(primitive->GetAttr(kOutputNum))); | |||||
| auto output0 = std::make_shared<abstract::AbstractTensor>(kInt32, shape); | auto output0 = std::make_shared<abstract::AbstractTensor>(kInt32, shape); | ||||
| auto output1 = std::make_shared<abstract::AbstractTensor>(kFloat32, shape); | auto output1 = std::make_shared<abstract::AbstractTensor>(kFloat32, shape); | ||||
| @@ -30,19 +30,13 @@ void DepthToSpace::set_block_size(const int64_t block_size) { | |||||
| this->AddAttr(kBlockSize, MakeValue(block_size)); | this->AddAttr(kBlockSize, MakeValue(block_size)); | ||||
| } | } | ||||
| int64_t DepthToSpace::get_block_size() const { | |||||
| auto value_ptr = GetAttr(kBlockSize); | |||||
| return GetValue<int64_t>(value_ptr); | |||||
| } | |||||
| int64_t DepthToSpace::get_block_size() const { return GetValue<int64_t>(GetAttr(kBlockSize)); } | |||||
| void DepthToSpace::set_format(const Format &format) { | void DepthToSpace::set_format(const Format &format) { | ||||
| int64_t f = format; | int64_t f = format; | ||||
| this->AddAttr(kFormat, MakeValue(f)); | this->AddAttr(kFormat, MakeValue(f)); | ||||
| } | } | ||||
| Format DepthToSpace::get_format() const { | |||||
| auto value_ptr = GetAttr(kFormat); | |||||
| return Format(GetValue<int64_t>(value_ptr)); | |||||
| } | |||||
| Format DepthToSpace::get_format() const { return Format(GetValue<int64_t>(GetAttr(kFormat))); } | |||||
| void DepthToSpace::Init(const int64_t block_size, const Format &format) { | void DepthToSpace::Init(const int64_t block_size, const Format &format) { | ||||
| this->set_block_size(block_size); | this->set_block_size(block_size); | ||||
| @@ -52,9 +46,7 @@ void DepthToSpace::Init(const int64_t block_size, const Format &format) { | |||||
| AbstractBasePtr DepthToSpaceInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr DepthToSpaceInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const std::vector<AbstractBasePtr> &input_args) { | const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto prim = primitive->cast<PrimDepthToSpacePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(prim); | |||||
| auto prim_name = prim->name(); | |||||
| auto prim_name = primitive->name(); | |||||
| CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 1, prim_name); | CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 1, prim_name); | ||||
| for (const auto &item : input_args) { | for (const auto &item : input_args) { | ||||
| MS_EXCEPTION_IF_NULL(item); | MS_EXCEPTION_IF_NULL(item); | ||||
| @@ -63,18 +55,19 @@ AbstractBasePtr DepthToSpaceInfer(const abstract::AnalysisEnginePtr &, const Pri | |||||
| MS_EXCEPTION_IF_NULL(input_x); | MS_EXCEPTION_IF_NULL(input_x); | ||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | ||||
| if (prim->get_format() == NHWC) { | |||||
| auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat))); | |||||
| if (format == NHWC) { | |||||
| x_shape = {x_shape[0], x_shape[3], x_shape[1], x_shape[2]}; | x_shape = {x_shape[0], x_shape[3], x_shape[1], x_shape[2]}; | ||||
| } | } | ||||
| CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kEqual, 4, prim_name); | CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kEqual, 4, prim_name); | ||||
| int64_t block_size = prim->get_block_size(); | |||||
| int64_t block_size = GetValue<int64_t>(primitive->GetAttr(kBlockSize)); | |||||
| CheckAndConvertUtils::CheckInteger("x_shape[1] % (block_size*block_size)", x_shape[1] % (block_size * block_size), | CheckAndConvertUtils::CheckInteger("x_shape[1] % (block_size*block_size)", x_shape[1] % (block_size * block_size), | ||||
| kEqual, 0, prim_name); | kEqual, 0, prim_name); | ||||
| auto out_shape = x_shape; | auto out_shape = x_shape; | ||||
| out_shape[1] /= block_size * block_size; | out_shape[1] /= block_size * block_size; | ||||
| out_shape[2] *= block_size; | out_shape[2] *= block_size; | ||||
| out_shape[3] *= block_size; | out_shape[3] *= block_size; | ||||
| if (prim->get_format() == NHWC) { | |||||
| if (format == NHWC) { | |||||
| out_shape = {out_shape[0], out_shape[2], out_shape[3], out_shape[1]}; | out_shape = {out_shape[0], out_shape[2], out_shape[3], out_shape[1]}; | ||||
| } | } | ||||
| auto ret = input_x->Broaden(); | auto ret = input_x->Broaden(); | ||||
| @@ -65,25 +65,14 @@ void DepthWiseConv2D::Init(const int64_t channel_multiplier, const std::vector<i | |||||
| } | } | ||||
| std::vector<int64_t> DepthWiseConv2D::get_kernel_size() const { | std::vector<int64_t> DepthWiseConv2D::get_kernel_size() const { | ||||
| auto value_ptr = GetAttr(kKernelSize); | |||||
| return GetValue<std::vector<int64_t>>(value_ptr); | |||||
| } | |||||
| std::vector<int64_t> DepthWiseConv2D::get_stride() const { | |||||
| auto value_ptr = GetAttr(kStride); | |||||
| return GetValue<std::vector<int64_t>>(value_ptr); | |||||
| return GetValue<std::vector<int64_t>>(GetAttr(kKernelSize)); | |||||
| } | } | ||||
| std::vector<int64_t> DepthWiseConv2D::get_stride() const { return GetValue<std::vector<int64_t>>(GetAttr(kStride)); } | |||||
| std::vector<int64_t> DepthWiseConv2D::get_dilation() const { | std::vector<int64_t> DepthWiseConv2D::get_dilation() const { | ||||
| auto value_ptr = GetAttr(kDilation); | |||||
| return GetValue<std::vector<int64_t>>(value_ptr); | |||||
| } | |||||
| PadMode DepthWiseConv2D::get_pad_mode() const { | |||||
| auto value_ptr = this->GetAttr(kPadMode); | |||||
| return PadMode(GetValue<int64_t>(value_ptr)); | |||||
| } | |||||
| std::vector<int64_t> DepthWiseConv2D::get_pad() const { | |||||
| auto value_ptr = this->GetAttr(kPad); | |||||
| return GetValue<std::vector<int64_t>>(value_ptr); | |||||
| return GetValue<std::vector<int64_t>>(GetAttr(kDilation)); | |||||
| } | } | ||||
| PadMode DepthWiseConv2D::get_pad_mode() const { return PadMode(GetValue<int64_t>(GetAttr(kPadMode))); } | |||||
| std::vector<int64_t> DepthWiseConv2D::get_pad() const { return GetValue<std::vector<int64_t>>(GetAttr(kPad)); } | |||||
| std::vector<int64_t> DepthWiseConv2D::get_pads() const { | std::vector<int64_t> DepthWiseConv2D::get_pads() const { | ||||
| auto value_ptr = this->GetAttr(kPads); | auto value_ptr = this->GetAttr(kPads); | ||||
| @@ -99,10 +88,7 @@ int64_t DepthWiseConv2D::get_group() const { | |||||
| auto value_ptr = this->GetAttr(kGroup); | auto value_ptr = this->GetAttr(kGroup); | ||||
| return GetValue<int64_t>(value_ptr); | return GetValue<int64_t>(value_ptr); | ||||
| } | } | ||||
| int64_t DepthWiseConv2D::get_out_channel() const { | |||||
| auto value_ptr = this->GetAttr(kOutChannel); | |||||
| return GetValue<int64_t>(value_ptr); | |||||
| } | |||||
| int64_t DepthWiseConv2D::get_out_channel() const { return GetValue<int64_t>(GetAttr(kOutChannel)); } | |||||
| void DepthWiseConv2D::set_kernel_size(const std::vector<int64_t> &kernel_size) { | void DepthWiseConv2D::set_kernel_size(const std::vector<int64_t> &kernel_size) { | ||||
| this->AddAttr(kKernelSize, MakeValue(kernel_size)); | this->AddAttr(kKernelSize, MakeValue(kernel_size)); | ||||
| @@ -126,33 +112,29 @@ void DepthWiseConv2D::set_format(const Format &format) { | |||||
| this->AddAttr(kFormat, MakeValue(f)); | this->AddAttr(kFormat, MakeValue(f)); | ||||
| } | } | ||||
| Format DepthWiseConv2D::get_format() const { | |||||
| auto value_ptr = GetAttr(kFormat); | |||||
| return Format(GetValue<int64_t>(value_ptr)); | |||||
| } | |||||
| Format DepthWiseConv2D::get_format() const { return Format(GetValue<int64_t>(GetAttr(kFormat))); } | |||||
| abstract::ShapePtr DepthWiseConv2DInferShape(const PrimitivePtr &primitive, | abstract::ShapePtr DepthWiseConv2DInferShape(const PrimitivePtr &primitive, | ||||
| const std::vector<AbstractBasePtr> &input_args) { | const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto conv_prim = primitive->cast<PrimDepthWiseConv2DPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(conv_prim); | |||||
| auto prim_name = conv_prim->name(); | |||||
| auto prim_name = primitive->name(); | |||||
| CheckAndConvertUtils::CheckInRange<size_t>("conv2d_Infer", input_args.size(), kIncludeBoth, {2, 3}, prim_name); | CheckAndConvertUtils::CheckInRange<size_t>("conv2d_Infer", input_args.size(), kIncludeBoth, {2, 3}, prim_name); | ||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), prim_name); | auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), prim_name); | ||||
| auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShape("w_shape", input_args[1]->GetShapeTrack(), prim_name); | auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShape("w_shape", input_args[1]->GetShapeTrack(), prim_name); | ||||
| if (conv_prim->get_format() == NHWC) { | |||||
| auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat))); | |||||
| if (format == NHWC) { | |||||
| x_shape = {x_shape[0], x_shape[3], x_shape[1], x_shape[2]}; | x_shape = {x_shape[0], x_shape[3], x_shape[1], x_shape[2]}; | ||||
| w_shape = {w_shape[0], w_shape[3], w_shape[1], w_shape[2]}; | w_shape = {w_shape[0], w_shape[3], w_shape[1], w_shape[2]}; | ||||
| } | } | ||||
| CheckAndConvertUtils::CheckInteger("weight_rank", w_shape.size(), kEqual, 4, prim_name); | CheckAndConvertUtils::CheckInteger("weight_rank", w_shape.size(), kEqual, 4, prim_name); | ||||
| CheckAndConvertUtils::CheckInteger("x_rank", x_shape.size(), kEqual, 4, prim_name); | CheckAndConvertUtils::CheckInteger("x_rank", x_shape.size(), kEqual, 4, prim_name); | ||||
| CheckAndConvertUtils::Check("x_shape[1]", x_shape[1], kEqual, "w_shape[1]", w_shape[1], conv_prim->name()); | |||||
| auto out_channel = conv_prim->get_out_channel(); | |||||
| CheckAndConvertUtils::Check("x_shape[1]", x_shape[1], kEqual, "w_shape[1]", w_shape[1], prim_name); | |||||
| auto out_channel = GetValue<int64_t>(primitive->GetAttr(kOutChannel)); | |||||
| std::vector<int64_t> temp_w; | std::vector<int64_t> temp_w; | ||||
| std::copy(w_shape.begin() + 2, w_shape.end(), std::back_inserter(temp_w)); | std::copy(w_shape.begin() + 2, w_shape.end(), std::back_inserter(temp_w)); | ||||
| CheckAndConvertUtils::Check("kernel_size", conv_prim->get_kernel_size(), kEqual, "w_shape[2:4]", temp_w, | |||||
| conv_prim->name()); | |||||
| CheckAndConvertUtils::Check("kernel_size", GetValue<std::vector<int64_t>>(primitive->GetAttr(kKernelSize)), kEqual, | |||||
| "w_shape[2:4]", temp_w, prim_name); | |||||
| auto kernel_size_n = w_shape[0]; | auto kernel_size_n = w_shape[0]; | ||||
| if (kernel_size_n != 1) { | if (kernel_size_n != 1) { | ||||
| @@ -160,8 +142,8 @@ abstract::ShapePtr DepthWiseConv2DInferShape(const PrimitivePtr &primitive, | |||||
| } | } | ||||
| auto kernel_size_h = w_shape[2]; | auto kernel_size_h = w_shape[2]; | ||||
| auto kernel_size_w = w_shape[3]; | auto kernel_size_w = w_shape[3]; | ||||
| auto stride = conv_prim->get_stride(); | |||||
| auto dilation = conv_prim->get_dilation(); | |||||
| auto stride = GetValue<std::vector<int64_t>>(primitive->GetAttr(kStride)); | |||||
| auto dilation = GetValue<std::vector<int64_t>>(primitive->GetAttr(kDilation)); | |||||
| auto stride_h = stride[2]; | auto stride_h = stride[2]; | ||||
| auto stride_w = stride[3]; | auto stride_w = stride[3]; | ||||
| auto dilation_h = dilation[2]; | auto dilation_h = dilation[2]; | ||||
| @@ -169,7 +151,7 @@ abstract::ShapePtr DepthWiseConv2DInferShape(const PrimitivePtr &primitive, | |||||
| int64_t h_out = -1; | int64_t h_out = -1; | ||||
| int64_t w_out = -1; | int64_t w_out = -1; | ||||
| std::vector<int64_t> pad_list(4, 0); | std::vector<int64_t> pad_list(4, 0); | ||||
| auto pad_mode = conv_prim->get_pad_mode(); | |||||
| auto pad_mode = PadMode(GetValue<int64_t>(primitive->GetAttr(kPadMode))); | |||||
| if (pad_mode == VALID) { | if (pad_mode == VALID) { | ||||
| h_out = ceil((x_shape[2] - dilation_h * (kernel_size_h - 1)) / stride_h); | h_out = ceil((x_shape[2] - dilation_h * (kernel_size_h - 1)) / stride_h); | ||||
| w_out = ceil((x_shape[3] - dilation_w * (kernel_size_w - 1)) / stride_w); | w_out = ceil((x_shape[3] - dilation_w * (kernel_size_w - 1)) / stride_w); | ||||
| @@ -187,20 +169,21 @@ abstract::ShapePtr DepthWiseConv2DInferShape(const PrimitivePtr &primitive, | |||||
| pad_list.emplace_back(pad_left); | pad_list.emplace_back(pad_left); | ||||
| pad_list.emplace_back(pad_needed_h - pad_left); | pad_list.emplace_back(pad_needed_h - pad_left); | ||||
| } else if (pad_mode == PAD) { | } else if (pad_mode == PAD) { | ||||
| std::copy(conv_prim->get_pad().begin(), conv_prim->get_pad().end(), std::back_inserter(pad_list)); | |||||
| auto pad_top = conv_prim->get_pad()[0]; | |||||
| auto pad_bottom = conv_prim->get_pad()[1]; | |||||
| auto pad_right = conv_prim->get_pad()[2]; | |||||
| auto pad_left = conv_prim->get_pad()[3]; | |||||
| auto pads = GetValue<std::vector<int64_t>>(primitive->GetAttr(kPad)); | |||||
| std::copy(pads.begin(), pads.end(), std::back_inserter(pad_list)); | |||||
| auto pad_top = pads[0]; | |||||
| auto pad_bottom = pads[1]; | |||||
| auto pad_right = pads[2]; | |||||
| auto pad_left = pads[3]; | |||||
| h_out = 1 + (x_shape[2] + pad_top + pad_bottom - kernel_size_h - (kernel_size_h - 1) * (dilation_h - 1)) / stride_h; | h_out = 1 + (x_shape[2] + pad_top + pad_bottom - kernel_size_h - (kernel_size_h - 1) * (dilation_h - 1)) / stride_h; | ||||
| w_out = 1 + (x_shape[3] + pad_left + pad_right - kernel_size_w - (kernel_size_w - 1) * (dilation_w - 1)) / stride_w; | w_out = 1 + (x_shape[3] + pad_left + pad_right - kernel_size_w - (kernel_size_w - 1) * (dilation_w - 1)) / stride_w; | ||||
| h_out = floor(h_out); | h_out = floor(h_out); | ||||
| w_out = floor(w_out); | w_out = floor(w_out); | ||||
| } | } | ||||
| conv_prim->set_pads(pad_list); | |||||
| primitive->AddAttr(kPads, MakeValue(pad_list)); | |||||
| std::vector<int64_t> out_shape = {x_shape[0], out_channel * x_shape[1], h_out, w_out}; | std::vector<int64_t> out_shape = {x_shape[0], out_channel * x_shape[1], h_out, w_out}; | ||||
| if (conv_prim->get_format() == NHWC) { | |||||
| if (format == NHWC) { | |||||
| out_shape = {x_shape[0], h_out, w_out, out_channel * x_shape[1]}; | out_shape = {x_shape[0], h_out, w_out, out_channel * x_shape[1]}; | ||||
| } | } | ||||
| return std::make_shared<abstract::Shape>(out_shape); | return std::make_shared<abstract::Shape>(out_shape); | ||||
| @@ -68,10 +68,7 @@ float DetectionPostProcess::get_nms_score_threshold() const { | |||||
| void DetectionPostProcess::set_max_detections(const int64_t MaxDetections) { | void DetectionPostProcess::set_max_detections(const int64_t MaxDetections) { | ||||
| this->AddAttr(kMaxDetections, MakeValue(MaxDetections)); | this->AddAttr(kMaxDetections, MakeValue(MaxDetections)); | ||||
| } | } | ||||
| int64_t DetectionPostProcess::get_max_detections() const { | |||||
| auto value_ptr = this->GetAttr(kMaxDetections); | |||||
| return GetValue<int64_t>(value_ptr); | |||||
| } | |||||
| int64_t DetectionPostProcess::get_max_detections() const { return GetValue<int64_t>(GetAttr(kMaxDetections)); } | |||||
| void DetectionPostProcess::set_detections_per_class(const int64_t DetectionsPerClass) { | void DetectionPostProcess::set_detections_per_class(const int64_t DetectionsPerClass) { | ||||
| this->AddAttr(kDetectionsPerClass, MakeValue(DetectionsPerClass)); | this->AddAttr(kDetectionsPerClass, MakeValue(DetectionsPerClass)); | ||||
| @@ -85,17 +82,13 @@ void DetectionPostProcess::set_max_classes_per_detection(const int64_t MaxClasse | |||||
| this->AddAttr(kMaxClassesPerDetection, MakeValue(MaxClassesPerDetection)); | this->AddAttr(kMaxClassesPerDetection, MakeValue(MaxClassesPerDetection)); | ||||
| } | } | ||||
| int64_t DetectionPostProcess::get_max_classes_per_detection() const { | int64_t DetectionPostProcess::get_max_classes_per_detection() const { | ||||
| auto value_ptr = this->GetAttr(kMaxClassesPerDetection); | |||||
| return GetValue<int64_t>(value_ptr); | |||||
| return GetValue<int64_t>(GetAttr(kMaxClassesPerDetection)); | |||||
| } | } | ||||
| void DetectionPostProcess::set_num_classes(const int64_t NumClasses) { | void DetectionPostProcess::set_num_classes(const int64_t NumClasses) { | ||||
| this->AddAttr(kNumClasses, MakeValue(NumClasses)); | this->AddAttr(kNumClasses, MakeValue(NumClasses)); | ||||
| } | } | ||||
| int64_t DetectionPostProcess::get_num_classes() const { | |||||
| auto value_ptr = this->GetAttr(kNumClasses); | |||||
| return GetValue<int64_t>(value_ptr); | |||||
| } | |||||
| int64_t DetectionPostProcess::get_num_classes() const { return GetValue<int64_t>(GetAttr(kNumClasses)); } | |||||
| void DetectionPostProcess::set_use_regular_nms(const bool UseRegularNms) { | void DetectionPostProcess::set_use_regular_nms(const bool UseRegularNms) { | ||||
| this->AddAttr(kUseRegularNms, MakeValue(UseRegularNms)); | this->AddAttr(kUseRegularNms, MakeValue(UseRegularNms)); | ||||
| } | } | ||||
| @@ -115,16 +108,11 @@ void DetectionPostProcess::set_format(const Format &format) { | |||||
| int64_t f = format; | int64_t f = format; | ||||
| this->AddAttr(kFormat, MakeValue(f)); | this->AddAttr(kFormat, MakeValue(f)); | ||||
| } | } | ||||
| Format DetectionPostProcess::get_format() const { | |||||
| auto value_ptr = this->GetAttr(kFormat); | |||||
| return Format(GetValue<int64_t>(value_ptr)); | |||||
| } | |||||
| Format DetectionPostProcess::get_format() const { return Format(GetValue<int64_t>(GetAttr(kFormat))); } | |||||
| AbstractBasePtr DetectionPostProcessInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr DetectionPostProcessInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const std::vector<AbstractBasePtr> &input_args) { | const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto detection_prim = primitive->cast<PrimDetectionPostProcessPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(detection_prim); | |||||
| auto prim_name = detection_prim->name(); | |||||
| auto prim_name = primitive->name(); | |||||
| CheckAndConvertUtils::CheckInteger("detection_post_process_infer", input_args.size(), kEqual, 3, prim_name); | CheckAndConvertUtils::CheckInteger("detection_post_process_infer", input_args.size(), kEqual, 3, prim_name); | ||||
| MS_EXCEPTION_IF_NULL(input_args[0]); | MS_EXCEPTION_IF_NULL(input_args[0]); | ||||
| MS_EXCEPTION_IF_NULL(input_args[1]); | MS_EXCEPTION_IF_NULL(input_args[1]); | ||||
| @@ -135,12 +123,13 @@ AbstractBasePtr DetectionPostProcessInfer(const abstract::AnalysisEnginePtr &, c | |||||
| auto boxes_shape = CheckAndConvertUtils::ConvertShapePtrToShape("boxes_shape", boxes->BuildShape(), prim_name); | auto boxes_shape = CheckAndConvertUtils::ConvertShapePtrToShape("boxes_shape", boxes->BuildShape(), prim_name); | ||||
| auto scores_shape = CheckAndConvertUtils::ConvertShapePtrToShape("scores_shape", scores->BuildShape(), prim_name); | auto scores_shape = CheckAndConvertUtils::ConvertShapePtrToShape("scores_shape", scores->BuildShape(), prim_name); | ||||
| auto anchors_shape = CheckAndConvertUtils::ConvertShapePtrToShape("anchors_shape", anchors->BuildShape(), prim_name); | auto anchors_shape = CheckAndConvertUtils::ConvertShapePtrToShape("anchors_shape", anchors->BuildShape(), prim_name); | ||||
| if (detection_prim->get_format() == NHWC) { | |||||
| auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat))); | |||||
| if (format == NHWC) { | |||||
| boxes_shape = {boxes_shape[0], boxes_shape[3], boxes_shape[1], boxes_shape[2]}; | boxes_shape = {boxes_shape[0], boxes_shape[3], boxes_shape[1], boxes_shape[2]}; | ||||
| scores_shape = {scores_shape[0], scores_shape[3], scores_shape[1], scores_shape[2]}; | scores_shape = {scores_shape[0], scores_shape[3], scores_shape[1], scores_shape[2]}; | ||||
| anchors_shape = {anchors_shape[0], anchors_shape[3], anchors_shape[1], anchors_shape[2]}; | anchors_shape = {anchors_shape[0], anchors_shape[3], anchors_shape[1], anchors_shape[2]}; | ||||
| } | } | ||||
| auto num_classes = detection_prim->get_num_classes(); | |||||
| auto num_classes = GetValue<int64_t>(primitive->GetAttr(kNumClasses)); | |||||
| CheckAndConvertUtils::CheckInRange("scores_shape[2]", scores_shape[2], kIncludeBoth, {num_classes, num_classes + 1}, | CheckAndConvertUtils::CheckInRange("scores_shape[2]", scores_shape[2], kIncludeBoth, {num_classes, num_classes + 1}, | ||||
| prim_name); | prim_name); | ||||
| CheckAndConvertUtils::Check("boxes_shape[1]", boxes_shape[1], kEqual, "scores_shape[1]", scores_shape[1], prim_name, | CheckAndConvertUtils::Check("boxes_shape[1]", boxes_shape[1], kEqual, "scores_shape[1]", scores_shape[1], prim_name, | ||||
| @@ -149,8 +138,8 @@ AbstractBasePtr DetectionPostProcessInfer(const abstract::AnalysisEnginePtr &, c | |||||
| ValueError); | ValueError); | ||||
| // Infer shape | // Infer shape | ||||
| auto max_detections = detection_prim->get_max_detections(); | |||||
| auto max_classes_per_detection = detection_prim->get_max_classes_per_detection(); | |||||
| auto max_detections = GetValue<int64_t>(primitive->GetAttr(kMaxDetections)); | |||||
| auto max_classes_per_detection = GetValue<int64_t>(primitive->GetAttr(kMaxClassesPerDetection)); | |||||
| auto num_detected_boxes = max_detections * max_classes_per_detection; | auto num_detected_boxes = max_detections * max_classes_per_detection; | ||||
| std::vector<int64_t> output_boxes_shape = {1, num_detected_boxes, 4}; | std::vector<int64_t> output_boxes_shape = {1, num_detected_boxes, 4}; | ||||
| std::vector<int64_t> output_class_shape = {1, num_detected_boxes}; | std::vector<int64_t> output_class_shape = {1, num_detected_boxes}; | ||||
| @@ -163,7 +152,7 @@ AbstractBasePtr DetectionPostProcessInfer(const abstract::AnalysisEnginePtr &, c | |||||
| auto output1 = std::make_shared<abstract::AbstractTensor>(output_type, output_class_shape); | auto output1 = std::make_shared<abstract::AbstractTensor>(output_type, output_class_shape); | ||||
| auto output2 = std::make_shared<abstract::AbstractTensor>(output_type, output_num_shape); | auto output2 = std::make_shared<abstract::AbstractTensor>(output_type, output_num_shape); | ||||
| AbstractBasePtrList output = {output0, output1, output1, output2}; | AbstractBasePtrList output = {output0, output1, output1, output2}; | ||||
| if (detection_prim->get_format() == NHWC) { | |||||
| if (format == NHWC) { | |||||
| output = {output0, output1, output2, output1}; | output = {output0, output1, output2, output1}; | ||||
| } | } | ||||
| return std::make_shared<abstract::AbstractTuple>(output); | return std::make_shared<abstract::AbstractTuple>(output); | ||||
| @@ -28,9 +28,7 @@ namespace ops { | |||||
| namespace { | namespace { | ||||
| abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto div_prim = primitive->cast<PrimDivPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(div_prim); | |||||
| auto prim_name = div_prim->name(); | |||||
| auto prim_name = primitive->name(); | |||||
| return BroadCastInferShape(prim_name, input_args); | return BroadCastInferShape(prim_name, input_args); | ||||
| } | } | ||||
| @@ -39,9 +39,7 @@ float Dropout::get_keep_prob() const { | |||||
| AbstractBasePtr DropoutInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr DropoutInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const std::vector<AbstractBasePtr> &input_args) { | const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto dropout_prim = primitive->cast<PrimDropoutPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(dropout_prim); | |||||
| auto prim_name = dropout_prim->name(); | |||||
| auto prim_name = primitive->name(); | |||||
| CheckAndConvertUtils::CheckInteger("dropout_infer", input_args.size(), kEqual, 1, prim_name); | CheckAndConvertUtils::CheckInteger("dropout_infer", input_args.size(), kEqual, 1, prim_name); | ||||
| // Infer shape | // Infer shape | ||||
| @@ -31,9 +31,7 @@ namespace ops { | |||||
| namespace { | namespace { | ||||
| abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto elu_prim = primitive->cast<PrimElu>(); | |||||
| MS_EXCEPTION_IF_NULL(elu_prim); | |||||
| auto op_name = elu_prim->name(); | |||||
| auto op_name = primitive->name(); | |||||
| for (const auto &item : input_args) { | for (const auto &item : input_args) { | ||||
| MS_EXCEPTION_IF_NULL(item); | MS_EXCEPTION_IF_NULL(item); | ||||
| } | } | ||||
| @@ -29,9 +29,7 @@ namespace ops { | |||||
| namespace { | namespace { | ||||
| abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto equal_prim = primitive->cast<PrimEqualPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(equal_prim); | |||||
| auto op_name = equal_prim->name(); | |||||
| auto op_name = primitive->name(); | |||||
| return BroadCastInferShape(op_name, input_args); | return BroadCastInferShape(op_name, input_args); | ||||
| } | } | ||||
| @@ -30,9 +30,7 @@ namespace ops { | |||||
| AbstractBasePtr ExpandDimsInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr ExpandDimsInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const std::vector<AbstractBasePtr> &input_args) { | const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto expand_dims_prim = primitive->cast<PrimExpandDims>(); | |||||
| MS_EXCEPTION_IF_NULL(expand_dims_prim); | |||||
| auto prim_name = expand_dims_prim->name(); | |||||
| auto prim_name = primitive->name(); | |||||
| CheckAndConvertUtils::CheckInteger("input numbers", input_args.size(), kEqual, 2, prim_name); | CheckAndConvertUtils::CheckInteger("input numbers", input_args.size(), kEqual, 2, prim_name); | ||||
| for (const auto &item : input_args) { | for (const auto &item : input_args) { | ||||
| MS_EXCEPTION_IF_NULL(item); | MS_EXCEPTION_IF_NULL(item); | ||||
| @@ -28,9 +28,7 @@ namespace ops { | |||||
| namespace { | namespace { | ||||
| abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto fake_prim = primitive->cast<PrimFakeQuantWithMinMaxVarsPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(fake_prim); | |||||
| auto prim_name = fake_prim->name(); | |||||
| auto prim_name = primitive->name(); | |||||
| auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | ||||
| auto min_shape = CheckAndConvertUtils::ConvertShapePtrToShape("min_shape", input_args[1]->BuildShape(), prim_name); | auto min_shape = CheckAndConvertUtils::ConvertShapePtrToShape("min_shape", input_args[1]->BuildShape(), prim_name); | ||||
| auto max_shape = CheckAndConvertUtils::ConvertShapePtrToShape("max_shape", input_args[2]->BuildShape(), prim_name); | auto max_shape = CheckAndConvertUtils::ConvertShapePtrToShape("max_shape", input_args[2]->BuildShape(), prim_name); | ||||
| @@ -43,9 +43,7 @@ AbstractBasePtr FakeQuantWithMinMaxVarsPerChannelInfer(const abstract::AnalysisE | |||||
| const PrimitivePtr &primitive, | const PrimitivePtr &primitive, | ||||
| const std::vector<AbstractBasePtr> &input_args) { | const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto FakeQuantWithMinMaxVarsPerChannel_prim = primitive->cast<PrimFakeQuantWithMinMaxVarsPerChannelPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(FakeQuantWithMinMaxVarsPerChannel_prim); | |||||
| auto op_name = FakeQuantWithMinMaxVarsPerChannel_prim->name(); | |||||
| auto op_name = primitive->name(); | |||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), op_name); | auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), op_name); | ||||
| auto min_shape = CheckAndConvertUtils::ConvertShapePtrToShape("min_shape", input_args[1]->BuildShape(), op_name); | auto min_shape = CheckAndConvertUtils::ConvertShapePtrToShape("min_shape", input_args[1]->BuildShape(), op_name); | ||||
| auto max_shape = CheckAndConvertUtils::ConvertShapePtrToShape("max_shape", input_args[2]->BuildShape(), op_name); | auto max_shape = CheckAndConvertUtils::ConvertShapePtrToShape("max_shape", input_args[2]->BuildShape(), op_name); | ||||
| @@ -24,9 +24,7 @@ namespace ops { | |||||
| namespace { | namespace { | ||||
| abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto FftImag_prim = primitive->cast<PrimFftImagPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(FftImag_prim); | |||||
| auto prim_name = FftImag_prim->name(); | |||||
| auto prim_name = primitive->name(); | |||||
| auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("in_shape", input_args[0]->BuildShape(), prim_name); | auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("in_shape", input_args[0]->BuildShape(), prim_name); | ||||
| in_shape.pop_back(); | in_shape.pop_back(); | ||||
| return std::make_shared<abstract::Shape>(in_shape); | return std::make_shared<abstract::Shape>(in_shape); | ||||
| @@ -23,9 +23,7 @@ namespace ops { | |||||
| namespace { | namespace { | ||||
| abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto flatten_prim = primitive->cast<PrimFlattenPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(flatten_prim); | |||||
| auto prim_name = flatten_prim->name(); | |||||
| auto prim_name = primitive->name(); | |||||
| CheckAndConvertUtils::CheckInteger("input args size", input_args.size(), kGreaterEqual, 1, prim_name); | CheckAndConvertUtils::CheckInteger("input args size", input_args.size(), kGreaterEqual, 1, prim_name); | ||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | ||||
| auto prod = 1; | auto prod = 1; | ||||
| @@ -28,9 +28,7 @@ namespace ops { | |||||
| namespace { | namespace { | ||||
| abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto floor_prim = primitive->cast<PrimFLoorPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(floor_prim); | |||||
| auto prim_name = floor_prim->name(); | |||||
| auto prim_name = primitive->name(); | |||||
| for (const auto &item : input_args) { | for (const auto &item : input_args) { | ||||
| MS_EXCEPTION_IF_NULL(item); | MS_EXCEPTION_IF_NULL(item); | ||||
| } | } | ||||
| @@ -39,9 +39,7 @@ void AddFusion::Init(const ActivationType activation_type) { this->set_activatio | |||||
| namespace { | namespace { | ||||
| abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto add_prim = primitive->cast<PrimAddFusionPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(add_prim); | |||||
| auto op_name = add_prim->name(); | |||||
| auto op_name = primitive->name(); | |||||
| return BroadCastInferShape(op_name, input_args); | return BroadCastInferShape(op_name, input_args); | ||||
| } | } | ||||
| @@ -52,22 +52,21 @@ ActivationType AvgPoolFusion::get_activation_type() const { | |||||
| namespace { | namespace { | ||||
| abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto pool_prim = primitive->cast<PrimAvgPoolFusionPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(pool_prim); | |||||
| auto op_name = pool_prim->name(); | |||||
| auto op_name = primitive->name(); | |||||
| auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), op_name); | auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), op_name); | ||||
| if (pool_prim->get_format() == NHWC) { | |||||
| auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat))); | |||||
| if (format == NHWC) { | |||||
| in_shape = {in_shape[0], in_shape[3], in_shape[1], in_shape[2]}; | in_shape = {in_shape[0], in_shape[3], in_shape[1], in_shape[2]}; | ||||
| } | } | ||||
| CheckAndConvertUtils::CheckInteger("x_rank", in_shape.size(), kEqual, 4, op_name); | CheckAndConvertUtils::CheckInteger("x_rank", in_shape.size(), kEqual, 4, op_name); | ||||
| auto kernel_size = pool_prim->get_kernel_size(); | |||||
| auto pad_mode = pool_prim->get_pad_mode(); | |||||
| auto kernel_size = GetValue<std::vector<int64_t>>(primitive->GetAttr(kKernelSize)); | |||||
| auto pad_mode = PadMode(GetValue<int64_t>(primitive->GetAttr(kPadMode))); | |||||
| auto batch = in_shape[0]; | auto batch = in_shape[0]; | ||||
| auto channel = in_shape[1]; | auto channel = in_shape[1]; | ||||
| auto in_h = in_shape[2]; | auto in_h = in_shape[2]; | ||||
| auto in_w = in_shape[3]; | auto in_w = in_shape[3]; | ||||
| auto strides = pool_prim->get_strides(); | |||||
| auto strides = GetValue<std::vector<int64_t>>(primitive->GetAttr(kStrides)); | |||||
| auto kernel_h = kernel_size[2]; | auto kernel_h = kernel_size[2]; | ||||
| auto kernel_w = kernel_size[3]; | auto kernel_w = kernel_size[3]; | ||||
| auto stride_h = strides[2]; | auto stride_h = strides[2]; | ||||
| @@ -82,7 +81,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A | |||||
| out_w = ceil(in_w / stride_w); | out_w = ceil(in_w / stride_w); | ||||
| } | } | ||||
| std::vector<int64_t> out_shape = {batch, channel, out_h, out_w}; | std::vector<int64_t> out_shape = {batch, channel, out_h, out_w}; | ||||
| if (pool_prim->get_format() == NHWC) { | |||||
| if (format == NHWC) { | |||||
| out_shape = {batch, out_h, out_w, channel}; | out_shape = {batch, out_h, out_w, channel}; | ||||
| } | } | ||||
| if (std::any_of(out_shape.begin(), out_shape.end(), [](int64_t a) { return a <= 0; })) { | if (std::any_of(out_shape.begin(), out_shape.end(), [](int64_t a) { return a <= 0; })) { | ||||
| @@ -21,22 +21,13 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ops { | namespace ops { | ||||
| void FullConnection::set_has_bias(const bool has_bias) { this->AddAttr(kHasBias, MakeValue(has_bias)); } | void FullConnection::set_has_bias(const bool has_bias) { this->AddAttr(kHasBias, MakeValue(has_bias)); } | ||||
| bool FullConnection::get_has_bias() const { | |||||
| auto value_ptr = GetAttr(kHasBias); | |||||
| return GetValue<bool>(value_ptr); | |||||
| } | |||||
| bool FullConnection::get_has_bias() const { return GetValue<bool>(GetAttr(kHasBias)); } | |||||
| void FullConnection::set_axis(const int64_t axis) { this->AddAttr(kAxis, MakeValue(axis)); } | void FullConnection::set_axis(const int64_t axis) { this->AddAttr(kAxis, MakeValue(axis)); } | ||||
| int64_t FullConnection::get_axis() const { | |||||
| auto value_ptr = GetAttr(kAxis); | |||||
| return GetValue<int64_t>(value_ptr); | |||||
| } | |||||
| int64_t FullConnection::get_axis() const { return GetValue<int64_t>(GetAttr(kAxis)); } | |||||
| void FullConnection::set_use_axis(const bool use_axis) { this->AddAttr(kUseAxis, MakeValue(use_axis)); } | void FullConnection::set_use_axis(const bool use_axis) { this->AddAttr(kUseAxis, MakeValue(use_axis)); } | ||||
| bool FullConnection::get_use_axis() const { | |||||
| auto value_ptr = GetAttr(kUseAxis); | |||||
| return GetValue<bool>(value_ptr); | |||||
| } | |||||
| bool FullConnection::get_use_axis() const { return GetValue<bool>(GetAttr(kUseAxis)); } | |||||
| void FullConnection::set_activation_type(const ActivationType &activation_type) { | void FullConnection::set_activation_type(const ActivationType &activation_type) { | ||||
| int64_t swi; | int64_t swi; | ||||
| @@ -57,26 +48,26 @@ void FullConnection::Init(const bool has_bias, const int64_t axis, const bool us | |||||
| AbstractBasePtr FullConnectionInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr FullConnectionInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const std::vector<AbstractBasePtr> &input_args) { | const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto full_prim = primitive->cast<PrimFullConnectionPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(full_prim); | |||||
| auto prim_name = full_prim->name(); | |||||
| auto prim_name = primitive->name(); | |||||
| MS_EXCEPTION_IF_NULL(input_args[0]); | MS_EXCEPTION_IF_NULL(input_args[0]); | ||||
| MS_EXCEPTION_IF_NULL(input_args[1]); | MS_EXCEPTION_IF_NULL(input_args[1]); | ||||
| auto input0 = input_args[0]; | auto input0 = input_args[0]; | ||||
| auto input1 = input_args[1]; | auto input1 = input_args[1]; | ||||
| auto input0_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input0_shape", input0->BuildShape(), prim_name); | auto input0_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input0_shape", input0->BuildShape(), prim_name); | ||||
| auto input1_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input1_shape", input1->BuildShape(), prim_name); | auto input1_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input1_shape", input1->BuildShape(), prim_name); | ||||
| auto prim_axis = full_prim->get_axis(); | |||||
| if (full_prim->get_has_bias()) { | |||||
| auto prim_axis = GetValue<int64_t>(primitive->GetAttr(kAxis)); | |||||
| auto has_bias = GetValue<bool>(primitive->GetAttr(kHasBias)); | |||||
| if (has_bias) { | |||||
| CheckAndConvertUtils::CheckInteger("input_args.size()", input_args.size(), kEqual, 3, prim_name); | CheckAndConvertUtils::CheckInteger("input_args.size()", input_args.size(), kEqual, 3, prim_name); | ||||
| } else { | } else { | ||||
| CheckAndConvertUtils::CheckInteger("input_args.size()", input_args.size(), kEqual, 2, prim_name); | CheckAndConvertUtils::CheckInteger("input_args.size()", input_args.size(), kEqual, 2, prim_name); | ||||
| } | } | ||||
| if (full_prim->get_use_axis() && (prim_axis < 1 || prim_axis > (int64_t)input0_shape.size())) { | |||||
| auto use_axis = GetValue<bool>(primitive->GetAttr(kUseAxis)); | |||||
| if (use_axis && (prim_axis < 1 || prim_axis > (int64_t)input0_shape.size())) { | |||||
| MS_EXCEPTION(ValueError) << "Full Connection axis invalid"; | MS_EXCEPTION(ValueError) << "Full Connection axis invalid"; | ||||
| } | } | ||||
| int64_t new_k = 1; | int64_t new_k = 1; | ||||
| if (full_prim->get_use_axis()) { | |||||
| if (use_axis) { | |||||
| for (size_t t = prim_axis; t < input0_shape.size(); t++) { | for (size_t t = prim_axis; t < input0_shape.size(); t++) { | ||||
| new_k *= input0_shape[t]; | new_k *= input0_shape[t]; | ||||
| } | } | ||||
| @@ -86,7 +77,7 @@ AbstractBasePtr FullConnectionInfer(const abstract::AnalysisEnginePtr &, const P | |||||
| } else { | } else { | ||||
| new_k = input1_shape[1]; | new_k = input1_shape[1]; | ||||
| } | } | ||||
| if (full_prim->get_has_bias()) { | |||||
| if (has_bias) { | |||||
| auto input2_shape = | auto input2_shape = | ||||
| CheckAndConvertUtils::ConvertShapePtrToShape("input2_shape", input_args[2]->BuildShape(), prim_name); | CheckAndConvertUtils::ConvertShapePtrToShape("input2_shape", input_args[2]->BuildShape(), prim_name); | ||||
| if (input2_shape[0] != input1_shape[0]) { | if (input2_shape[0] != input1_shape[0]) { | ||||
| @@ -94,7 +85,7 @@ AbstractBasePtr FullConnectionInfer(const abstract::AnalysisEnginePtr &, const P | |||||
| } | } | ||||
| } | } | ||||
| std::vector<int64_t> out_shape = {(int64_t)input0_shape.size()}; | std::vector<int64_t> out_shape = {(int64_t)input0_shape.size()}; | ||||
| if (full_prim->get_use_axis()) { | |||||
| if (use_axis) { | |||||
| out_shape.resize(prim_axis + 1); | out_shape.resize(prim_axis + 1); | ||||
| out_shape[prim_axis] = input1_shape[0]; | out_shape[prim_axis] = input1_shape[0]; | ||||
| } else { | } else { | ||||
| @@ -52,22 +52,21 @@ ActivationType MaxPoolFusion::get_activation_type() const { | |||||
| namespace { | namespace { | ||||
| abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto pool_prim = primitive->cast<PrimMaxPoolFusionPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(pool_prim); | |||||
| auto op_name = pool_prim->name(); | |||||
| auto op_name = primitive->name(); | |||||
| auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), op_name); | auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), op_name); | ||||
| if (pool_prim->get_format() == NHWC) { | |||||
| auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat))); | |||||
| if (format == NHWC) { | |||||
| in_shape = {in_shape[0], in_shape[3], in_shape[1], in_shape[2]}; | in_shape = {in_shape[0], in_shape[3], in_shape[1], in_shape[2]}; | ||||
| } | } | ||||
| CheckAndConvertUtils::CheckInteger("x_rank", in_shape.size(), kEqual, 4, op_name); | CheckAndConvertUtils::CheckInteger("x_rank", in_shape.size(), kEqual, 4, op_name); | ||||
| auto kernel_size = pool_prim->get_kernel_size(); | |||||
| auto pad_mode = pool_prim->get_pad_mode(); | |||||
| auto kernel_size = GetValue<std::vector<int64_t>>(primitive->GetAttr(kKernelSize)); | |||||
| auto pad_mode = PadMode(GetValue<int64_t>(primitive->GetAttr(kPadMode))); | |||||
| auto batch = in_shape[0]; | auto batch = in_shape[0]; | ||||
| auto channel = in_shape[1]; | auto channel = in_shape[1]; | ||||
| auto in_h = in_shape[2]; | auto in_h = in_shape[2]; | ||||
| auto in_w = in_shape[3]; | auto in_w = in_shape[3]; | ||||
| auto strides = pool_prim->get_strides(); | |||||
| auto strides = GetValue<std::vector<int64_t>>(primitive->GetAttr(kStrides)); | |||||
| auto kernel_h = kernel_size[2]; | auto kernel_h = kernel_size[2]; | ||||
| auto kernel_w = kernel_size[3]; | auto kernel_w = kernel_size[3]; | ||||
| auto stride_h = strides[2]; | auto stride_h = strides[2]; | ||||
| @@ -82,7 +81,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A | |||||
| out_w = ceil(in_w / stride_w); | out_w = ceil(in_w / stride_w); | ||||
| } | } | ||||
| std::vector<int64_t> out_shape = {batch, channel, out_h, out_w}; | std::vector<int64_t> out_shape = {batch, channel, out_h, out_w}; | ||||
| if (pool_prim->get_format() == NHWC) { | |||||
| if (format == NHWC) { | |||||
| out_shape = {batch, out_h, out_w, channel}; | out_shape = {batch, out_h, out_w, channel}; | ||||
| } | } | ||||
| if (std::any_of(out_shape.begin(), out_shape.end(), [](int64_t a) { return a <= 0; })) { | if (std::any_of(out_shape.begin(), out_shape.end(), [](int64_t a) { return a <= 0; })) { | ||||
| @@ -37,9 +37,7 @@ float PowFusion::get_shift() const { return GetValue<float>(GetAttr(kShift)); } | |||||
| namespace { | namespace { | ||||
| abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto pow_prim = primitive->cast<PrimPowPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(pow_prim); | |||||
| auto op_name = pow_prim->name(); | |||||
| auto op_name = primitive->name(); | |||||
| return BroadCastInferShape(op_name, input_args); | return BroadCastInferShape(op_name, input_args); | ||||
| } | } | ||||
| @@ -32,9 +32,7 @@ std::vector<int64_t> SliceFusion::get_axes() const { | |||||
| AbstractBasePtr SliceFusionInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr SliceFusionInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const std::vector<AbstractBasePtr> &input_args) { | const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto SliceFusion_prim = primitive->cast<PrimSliceFusionPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(SliceFusion_prim); | |||||
| auto op_name = SliceFusion_prim->name(); | |||||
| auto op_name = primitive->name(); | |||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), op_name); | auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), op_name); | ||||
| auto x_shape_len = (int64_t)x_shape.size(); | auto x_shape_len = (int64_t)x_shape.size(); | ||||
| auto begin_v = input_args[1]->BuildValue(); | auto begin_v = input_args[1]->BuildValue(); | ||||
| @@ -27,9 +27,7 @@ namespace ops { | |||||
| namespace { | namespace { | ||||
| abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto gather_prim = primitive->cast<PrimGatherNd>(); | |||||
| MS_EXCEPTION_IF_NULL(gather_prim); | |||||
| auto prim_name = gather_prim->name(); | |||||
| auto prim_name = primitive->name(); | |||||
| CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 2, prim_name); | CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 2, prim_name); | ||||
| for (const auto &item : input_args) { | for (const auto &item : input_args) { | ||||
| MS_EXCEPTION_IF_NULL(item); | MS_EXCEPTION_IF_NULL(item); | ||||
| @@ -28,9 +28,7 @@ namespace ops { | |||||
| namespace { | namespace { | ||||
| abstract::ShapePtr GeLUInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | abstract::ShapePtr GeLUInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto gelu_prim = primitive->cast<PrimGeLUPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(gelu_prim); | |||||
| auto prim_name = gelu_prim->name(); | |||||
| auto prim_name = primitive->name(); | |||||
| auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input_x", input_args[0]->BuildShape(), prim_name); | auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input_x", input_args[0]->BuildShape(), prim_name); | ||||
| return std::make_shared<abstract::Shape>(input_shape); | return std::make_shared<abstract::Shape>(input_shape); | ||||
| } | } | ||||
| @@ -22,8 +22,6 @@ namespace ops { | |||||
| AbstractBasePtr AvgPoolGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr AvgPoolGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const std::vector<AbstractBasePtr> &input_args) { | const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto AvgPoolGrad_prim = primitive->cast<PrimAvgPoolGradPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(AvgPoolGrad_prim); | |||||
| MS_EXCEPTION_IF_NULL(input_args[0]->BuildValue()); | MS_EXCEPTION_IF_NULL(input_args[0]->BuildValue()); | ||||
| auto origin_input_shape = GetValue<std::vector<int64_t>>(input_args[0]->BuildValue()); | auto origin_input_shape = GetValue<std::vector<int64_t>>(input_args[0]->BuildValue()); | ||||
| auto tensor_type = input_args[1]->BuildType()->cast<TensorTypePtr>(); | auto tensor_type = input_args[1]->BuildType()->cast<TensorTypePtr>(); | ||||
| @@ -47,9 +47,7 @@ bool BatchNormGrad::get_is_training() const { | |||||
| AbstractBasePtr BatchNormGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr BatchNormGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const std::vector<AbstractBasePtr> &input_args) { | const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto BatchNormGrad_prim = primitive->cast<PrimBatchNormGradPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(BatchNormGrad_prim); | |||||
| auto op_name = BatchNormGrad_prim->name(); | |||||
| auto op_name = primitive->name(); | |||||
| MS_EXCEPTION_IF_NULL(input_args[1]); | MS_EXCEPTION_IF_NULL(input_args[1]); | ||||
| MS_EXCEPTION_IF_NULL(input_args[2]); | MS_EXCEPTION_IF_NULL(input_args[2]); | ||||
| MS_EXCEPTION_IF_NULL(input_args[3]); | MS_EXCEPTION_IF_NULL(input_args[3]); | ||||
| @@ -41,9 +41,7 @@ Format BiasAddGrad::get_format() const { | |||||
| AbstractBasePtr BiasAddGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr BiasAddGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const std::vector<AbstractBasePtr> &input_args) { | const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto bias_prim = primitive->cast<PrimBiasAddGradPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(bias_prim); | |||||
| auto prim_name = bias_prim->name(); | |||||
| auto prim_name = primitive->name(); | |||||
| CheckAndConvertUtils::CheckInteger("bias_grad_infer", input_args.size(), kEqual, 1, prim_name); | CheckAndConvertUtils::CheckInteger("bias_grad_infer", input_args.size(), kEqual, 1, prim_name); | ||||
| MS_EXCEPTION_IF_NULL(input_args[0]); | MS_EXCEPTION_IF_NULL(input_args[0]); | ||||
| @@ -26,9 +26,7 @@ namespace { | |||||
| abstract::ShapePtr BinaryCrossEntroyGradInferShape(const PrimitivePtr &primitive, | abstract::ShapePtr BinaryCrossEntroyGradInferShape(const PrimitivePtr &primitive, | ||||
| const std::vector<AbstractBasePtr> &input_args) { | const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto binary_cross_entropy_grad_prim = primitive->cast<PrimBinaryCrossEntropyGradPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(binary_cross_entropy_grad_prim); | |||||
| auto prim_name = binary_cross_entropy_grad_prim->name(); | |||||
| auto prim_name = primitive->name(); | |||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | ||||
| auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShape("y_shape", input_args[1]->BuildShape(), prim_name); | auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShape("y_shape", input_args[1]->BuildShape(), prim_name); | ||||
| auto weight_shape = | auto weight_shape = | ||||
| @@ -35,18 +35,14 @@ namespace { | |||||
| abstract::ShapePtr DropoutGradInferShape(const PrimitivePtr &primitive, | abstract::ShapePtr DropoutGradInferShape(const PrimitivePtr &primitive, | ||||
| const std::vector<AbstractBasePtr> &input_args) { | const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto DropoutGrad_prim = primitive->cast<PrimDropoutGradPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(DropoutGrad_prim); | |||||
| auto op_name = DropoutGrad_prim->name(); | |||||
| auto op_name = primitive->name(); | |||||
| auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), op_name); | auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), op_name); | ||||
| return std::make_shared<abstract::Shape>(in_shape); | return std::make_shared<abstract::Shape>(in_shape); | ||||
| } | } | ||||
| TypePtr DropoutGradInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { | TypePtr DropoutGradInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(prim); | MS_EXCEPTION_IF_NULL(prim); | ||||
| auto DropoutGrad_prim = prim->cast<PrimDropoutGradPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(DropoutGrad_prim); | |||||
| auto op_name = DropoutGrad_prim->name(); | |||||
| auto op_name = prim->name(); | |||||
| auto mask_dtype = input_args[1]->BuildType(); | auto mask_dtype = input_args[1]->BuildType(); | ||||
| auto dy_dtype = input_args[0]->BuildType(); | auto dy_dtype = input_args[0]->BuildType(); | ||||
| CheckAndConvertUtils::CheckTensorTypeValid("mask", mask_dtype, {kTensorType}, op_name); | CheckAndConvertUtils::CheckTensorTypeValid("mask", mask_dtype, {kTensorType}, op_name); | ||||
| @@ -114,8 +114,7 @@ void GroupConv2DGradInput::set_input_shape(const std::vector<int64_t> &input_sha | |||||
| } | } | ||||
| std::vector<int64_t> GroupConv2DGradInput::get_input_shape() const { | std::vector<int64_t> GroupConv2DGradInput::get_input_shape() const { | ||||
| auto value_ptr = GetAttr(kInputShape); | |||||
| return GetValue<std::vector<int64_t>>(value_ptr); | |||||
| return GetValue<std::vector<int64_t>>(GetAttr(kInputShape)); | |||||
| } | } | ||||
| void GroupConv2DGradInput::set_format(const Format &format) { | void GroupConv2DGradInput::set_format(const Format &format) { | ||||
| @@ -147,14 +146,12 @@ bool GroupConv2DGradInput::get_has_bias() const { | |||||
| AbstractBasePtr GroupConv2DGradInputInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr GroupConv2DGradInputInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const std::vector<AbstractBasePtr> &input_args) { | const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto group_prim = primitive->cast<PrimGroupConv2DGradInputPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(group_prim); | |||||
| auto prim_name = group_prim->name(); | |||||
| auto prim_name = primitive->name(); | |||||
| CheckAndConvertUtils::CheckInteger("group_conv_2D_infer", input_args.size(), kGreaterEqual, 2, prim_name); | CheckAndConvertUtils::CheckInteger("group_conv_2D_infer", input_args.size(), kGreaterEqual, 2, prim_name); | ||||
| MS_EXCEPTION_IF_NULL(input_args[0]); | MS_EXCEPTION_IF_NULL(input_args[0]); | ||||
| // Infer shape | // Infer shape | ||||
| auto shape = group_prim->get_input_shape(); | |||||
| auto shape = GetValue<std::vector<int64_t>>(primitive->GetAttr(kInputShape)); | |||||
| // Infer type | // Infer type | ||||
| auto type = input_args[0]->BuildType()->cast<TensorTypePtr>()->element(); | auto type = input_args[0]->BuildType()->cast<TensorTypePtr>()->element(); | ||||
| @@ -21,9 +21,8 @@ namespace mindspore { | |||||
| namespace ops { | namespace ops { | ||||
| AbstractBasePtr MaxPoolGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr MaxPoolGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const std::vector<AbstractBasePtr> &input_args) { | const std::vector<AbstractBasePtr> &input_args) { | ||||
| auto MaxPoolGrad_prim = primitive->cast<PrimMaxPoolGradPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(MaxPoolGrad_prim); | |||||
| auto op_name = MaxPoolGrad_prim->name(); | |||||
| MS_EXCEPTION_IF_NULL(primitive); | |||||
| auto op_name = primitive->name(); | |||||
| MS_EXCEPTION_IF_NULL(input_args[0]->BuildValue()); | MS_EXCEPTION_IF_NULL(input_args[0]->BuildValue()); | ||||
| auto x1_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x1_shape", input_args[0]->BuildShape(), op_name); | auto x1_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x1_shape", input_args[0]->BuildShape(), op_name); | ||||
| auto tensor_type = input_args[0]->BuildType()->cast<TensorTypePtr>(); | auto tensor_type = input_args[0]->BuildType()->cast<TensorTypePtr>(); | ||||
| @@ -30,9 +30,7 @@ AbstractBasePtr SigmoidCrossEntropyWithLogitsGradInfer(const abstract::AnalysisE | |||||
| const PrimitivePtr &primitive, | const PrimitivePtr &primitive, | ||||
| const std::vector<AbstractBasePtr> &input_args) { | const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto sigmoid_prim = primitive->cast<PrimSigmoidCrossEntropyWithLogitsGradPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(sigmoid_prim); | |||||
| auto prim_name = sigmoid_prim->name(); | |||||
| auto prim_name = primitive->name(); | |||||
| CheckAndConvertUtils::CheckInteger("sigmoid_cross_entropy_with_logits_grad_infer", input_args.size(), kEqual, 3, | CheckAndConvertUtils::CheckInteger("sigmoid_cross_entropy_with_logits_grad_infer", input_args.size(), kEqual, 3, | ||||
| prim_name); | prim_name); | ||||
| @@ -36,9 +36,7 @@ float SmoothL1LossGrad::get_beta() const { | |||||
| AbstractBasePtr SmoothL1LossGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr SmoothL1LossGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const std::vector<AbstractBasePtr> &input_args) { | const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto smooth_prim = primitive->cast<PrimSmoothL1LossGradPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(smooth_prim); | |||||
| auto prim_name = smooth_prim->name(); | |||||
| auto prim_name = primitive->name(); | |||||
| CheckAndConvertUtils::CheckInteger("smooth_l1_loss_grad_infer", input_args.size(), kEqual, 3, prim_name); | CheckAndConvertUtils::CheckInteger("smooth_l1_loss_grad_infer", input_args.size(), kEqual, 3, prim_name); | ||||
| // Infer shape | // Infer shape | ||||
| @@ -24,12 +24,10 @@ namespace ops { | |||||
| AbstractBasePtr HashtableLookupInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr HashtableLookupInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const std::vector<AbstractBasePtr> &input_args) { | const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto HashtableLookup_prim = primitive->cast<PrimHashtableLookupPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(HashtableLookup_prim); | |||||
| for (auto input : input_args) { | for (auto input : input_args) { | ||||
| MS_EXCEPTION_IF_NULL(input); | MS_EXCEPTION_IF_NULL(input); | ||||
| } | } | ||||
| auto op_name = HashtableLookup_prim->name(); | |||||
| auto op_name = primitive->name(); | |||||
| std::vector<int64_t> hits_shape; | std::vector<int64_t> hits_shape; | ||||
| auto input = CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), op_name); | auto input = CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), op_name); | ||||
| hits_shape.push_back(input[0]); | hits_shape.push_back(input[0]); | ||||
| @@ -29,10 +29,7 @@ void L2Normalize::set_axis(const std::vector<int64_t> &axis) { AddAttr(kAxis, Ma | |||||
| void L2Normalize::set_epsilon(const float epsilon) { AddAttr(kEpsilon, MakeValue(epsilon)); } | void L2Normalize::set_epsilon(const float epsilon) { AddAttr(kEpsilon, MakeValue(epsilon)); } | ||||
| std::vector<int64_t> L2Normalize::get_axis() const { | |||||
| auto value_ptr = GetAttr(kAxis); | |||||
| return GetValue<std::vector<int64_t>>(value_ptr); | |||||
| } | |||||
| std::vector<int64_t> L2Normalize::get_axis() const { return GetValue<std::vector<int64_t>>(GetAttr(kAxis)); } | |||||
| float L2Normalize::get_epsilon() const { | float L2Normalize::get_epsilon() const { | ||||
| auto value_ptr = GetAttr(kEpsilon); | auto value_ptr = GetAttr(kEpsilon); | ||||
| @@ -42,9 +39,7 @@ float L2Normalize::get_epsilon() const { | |||||
| AbstractBasePtr L2NormalizeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr L2NormalizeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const std::vector<AbstractBasePtr> &input_args) { | const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto prim = primitive->cast<PrimL2NormalizePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(prim); | |||||
| auto prim_name = prim->name(); | |||||
| auto prim_name = primitive->name(); | |||||
| CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 1, prim_name); | CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 1, prim_name); | ||||
| for (const auto &item : input_args) { | for (const auto &item : input_args) { | ||||
| MS_EXCEPTION_IF_NULL(item); | MS_EXCEPTION_IF_NULL(item); | ||||
| @@ -53,7 +48,7 @@ AbstractBasePtr L2NormalizeInfer(const abstract::AnalysisEnginePtr &, const Prim | |||||
| (void)CheckAndConvertUtils::CheckTensorTypeValid("input_x", input_args[0]->BuildType(), valid_types, prim_name); | (void)CheckAndConvertUtils::CheckTensorTypeValid("input_x", input_args[0]->BuildType(), valid_types, prim_name); | ||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | ||||
| auto x_rank = SizeToLong(x_shape.size()); | auto x_rank = SizeToLong(x_shape.size()); | ||||
| auto axiss = prim->get_axis(); | |||||
| auto axiss = GetValue<std::vector<int64_t>>(primitive->GetAttr(kAxis)); | |||||
| for (auto &axis : axiss) { | for (auto &axis : axiss) { | ||||
| CheckAndConvertUtils::CheckInRange<int64_t>("axis", axis, kIncludeLeft, {-x_rank, x_rank}, prim_name); | CheckAndConvertUtils::CheckInRange<int64_t>("axis", axis, kIncludeLeft, {-x_rank, x_rank}, prim_name); | ||||
| } | } | ||||
| @@ -27,9 +27,7 @@ namespace ops { | |||||
| namespace { | namespace { | ||||
| abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto less_prim = primitive->cast<PrimLessPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(less_prim); | |||||
| auto op_name = less_prim->name(); | |||||
| auto op_name = primitive->name(); | |||||
| return BroadCastInferShape(op_name, input_args); | return BroadCastInferShape(op_name, input_args); | ||||
| } | } | ||||
| @@ -28,9 +28,7 @@ namespace ops { | |||||
| namespace { | namespace { | ||||
| abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto equal_prim = primitive->cast<PrimLessEqualPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(equal_prim); | |||||
| auto op_name = equal_prim->name(); | |||||
| auto op_name = primitive->name(); | |||||
| return BroadCastInferShape(op_name, input_args); | return BroadCastInferShape(op_name, input_args); | ||||
| } | } | ||||
| @@ -28,9 +28,7 @@ namespace ops { | |||||
| namespace { | namespace { | ||||
| abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto logicaland_prim = primitive->cast<PrimLogicalAndPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(logicaland_prim); | |||||
| auto op_name = logicaland_prim->name(); | |||||
| auto op_name = primitive->name(); | |||||
| return BroadCastInferShape(op_name, input_args); | return BroadCastInferShape(op_name, input_args); | ||||
| } | } | ||||
| @@ -24,18 +24,14 @@ namespace ops { | |||||
| namespace { | namespace { | ||||
| abstract::ShapePtr LogicalNotInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | abstract::ShapePtr LogicalNotInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto LogicalNot_prim = primitive->cast<PrimLogicalNotPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(LogicalNot_prim); | |||||
| auto op_name = LogicalNot_prim->name(); | |||||
| auto op_name = primitive->name(); | |||||
| auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), op_name); | auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), op_name); | ||||
| return std::make_shared<abstract::Shape>(in_shape); | return std::make_shared<abstract::Shape>(in_shape); | ||||
| } | } | ||||
| TypePtr LogicalNotInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { | TypePtr LogicalNotInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(prim); | MS_EXCEPTION_IF_NULL(prim); | ||||
| auto LogicalNot_prim = prim->cast<PrimLogicalNotPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(LogicalNot_prim); | |||||
| auto op_name = LogicalNot_prim->name(); | |||||
| auto op_name = prim->name(); | |||||
| auto infer_dtype = input_args[0]->BuildType(); | auto infer_dtype = input_args[0]->BuildType(); | ||||
| std::set<TypePtr> local_bool = {kBool}; | std::set<TypePtr> local_bool = {kBool}; | ||||
| return CheckAndConvertUtils::CheckTensorTypeValid("x", infer_dtype, local_bool, op_name); | return CheckAndConvertUtils::CheckTensorTypeValid("x", infer_dtype, local_bool, op_name); | ||||
| @@ -29,9 +29,7 @@ namespace ops { | |||||
| namespace { | namespace { | ||||
| abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto logicalor_prim = primitive->cast<PrimLogicalOrPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(logicalor_prim); | |||||
| auto op_name = logicalor_prim->name(); | |||||
| auto op_name = primitive->name(); | |||||
| return BroadCastInferShape(op_name, input_args); | return BroadCastInferShape(op_name, input_args); | ||||
| } | } | ||||
| @@ -77,9 +77,7 @@ void LRN::Init(const int64_t depth_radius, const float bias, const float alpha, | |||||
| namespace { | namespace { | ||||
| abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto lrn_prim = primitive->cast<PrimLrn>(); | |||||
| MS_EXCEPTION_IF_NULL(lrn_prim); | |||||
| auto prim_name = lrn_prim->name(); | |||||
| auto prim_name = primitive->name(); | |||||
| auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | ||||
| CheckAndConvertUtils::CheckInteger("input shape", in_shape.size(), kEqual, 4, prim_name); | CheckAndConvertUtils::CheckInteger("input shape", in_shape.size(), kEqual, 4, prim_name); | ||||
| return std::make_shared<abstract::Shape>(in_shape); | return std::make_shared<abstract::Shape>(in_shape); | ||||
| @@ -26,20 +26,12 @@ void LshProjection::set_type(const LshProjectionType &type) { | |||||
| AddAttr(kType, MakeValue(swi)); | AddAttr(kType, MakeValue(swi)); | ||||
| } | } | ||||
| LshProjectionType LshProjection::get_type() const { | |||||
| auto value_ptr = GetAttr(kType); | |||||
| return LshProjectionType(GetValue<int64_t>(value_ptr)); | |||||
| } | |||||
| LshProjectionType LshProjection::get_type() const { return LshProjectionType(GetValue<int64_t>(GetAttr(kType))); } | |||||
| AbstractBasePtr LshProjectionInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr LshProjectionInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const std::vector<AbstractBasePtr> &input_args) { | const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto LshProjection_prim = primitive->cast<PrimLshProjectionPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(LshProjection_prim); | |||||
| // if (input_args.size() != 2 && input_args.size() != 3) { | |||||
| // MS_LOG(ERROR) << "inputs to LshProjection operator should be 2 or 3, but " << input_args.size() << " is given."; | |||||
| // } | |||||
| auto op_name = LshProjection_prim->name(); | |||||
| auto op_name = primitive->name(); | |||||
| auto input0 = CheckAndConvertUtils::ConvertShapePtrToShape("input0_shape", input_args[0]->BuildShape(), op_name); | auto input0 = CheckAndConvertUtils::ConvertShapePtrToShape("input0_shape", input_args[0]->BuildShape(), op_name); | ||||
| auto input1 = CheckAndConvertUtils::ConvertShapePtrToShape("input1_shape", input_args[1]->BuildShape(), op_name); | auto input1 = CheckAndConvertUtils::ConvertShapePtrToShape("input1_shape", input_args[1]->BuildShape(), op_name); | ||||
| CheckAndConvertUtils::CheckInteger("input0_shape", input0.size(), kEqual, 2, op_name); | CheckAndConvertUtils::CheckInteger("input0_shape", input0.size(), kEqual, 2, op_name); | ||||
| @@ -53,7 +45,7 @@ AbstractBasePtr LshProjectionInfer(const abstract::AnalysisEnginePtr &, const Pr | |||||
| } | } | ||||
| std::vector<int64_t> out_shape; | std::vector<int64_t> out_shape; | ||||
| switch ((int64_t)LshProjection_prim->get_type()) { | |||||
| switch ((int64_t)LshProjectionType(GetValue<int64_t>(primitive->GetAttr(kType)))) { | |||||
| case (int64_t)LshProjectionType::SPARSE: | case (int64_t)LshProjectionType::SPARSE: | ||||
| out_shape.push_back(input0[0]); | out_shape.push_back(input0[0]); | ||||
| break; | break; | ||||
| @@ -19,28 +19,34 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ops { | namespace ops { | ||||
| namespace { | namespace { | ||||
| int64_t get_good_ld(const int64_t dim, const int64_t type_size) { | |||||
| int64_t ld = ((dim + (64 / type_size) - 1) / (64 / type_size)) * (64 / type_size); | |||||
| if (ld * 256 == 0) { | |||||
| return ld + 64 / type_size; | |||||
| } | |||||
| return ld; | |||||
| } | |||||
| AbstractBasePtr LstmInfer(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | AbstractBasePtr LstmInfer(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | ||||
| // infer shape | // infer shape | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto lstm_prim = primitive->cast<PrimLstmPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(lstm_prim); | |||||
| auto prim_name = lstm_prim->name(); | |||||
| auto prim_name = primitive->name(); | |||||
| CheckAndConvertUtils::CheckInteger("lstm_prim_infer", input_args.size(), kEqual, 4, prim_name); | CheckAndConvertUtils::CheckInteger("lstm_prim_infer", input_args.size(), kEqual, 4, prim_name); | ||||
| auto x_input_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | auto x_input_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | ||||
| auto h_input_shape = CheckAndConvertUtils::ConvertShapePtrToShape("h_shape", input_args[1]->BuildShape(), prim_name); | auto h_input_shape = CheckAndConvertUtils::ConvertShapePtrToShape("h_shape", input_args[1]->BuildShape(), prim_name); | ||||
| auto c_input_shape = CheckAndConvertUtils::ConvertShapePtrToShape("c_shape", input_args[2]->BuildShape(), prim_name); | auto c_input_shape = CheckAndConvertUtils::ConvertShapePtrToShape("c_shape", input_args[2]->BuildShape(), prim_name); | ||||
| int64_t input_x_size = lstm_prim->get_input_size(); | |||||
| int64_t input_x_size = GetValue<int64_t>(primitive->GetAttr(kInput_size)); | |||||
| CheckAndConvertUtils::CheckInteger("x_shape.size()", x_input_shape.size(), kEqual, 3, prim_name); | CheckAndConvertUtils::CheckInteger("x_shape.size()", x_input_shape.size(), kEqual, 3, prim_name); | ||||
| CheckAndConvertUtils::CheckInteger("x_shape[2]", x_input_shape[2], kEqual, input_x_size, prim_name); | CheckAndConvertUtils::CheckInteger("x_shape[2]", x_input_shape[2], kEqual, input_x_size, prim_name); | ||||
| CheckAndConvertUtils::CheckInteger("h_shape.size()", h_input_shape.size(), kEqual, 3, prim_name); | CheckAndConvertUtils::CheckInteger("h_shape.size()", h_input_shape.size(), kEqual, 3, prim_name); | ||||
| CheckAndConvertUtils::Check("h_shape", h_input_shape, kEqual, "c_shape", c_input_shape, lstm_prim->name()); | |||||
| CheckAndConvertUtils::Check("h_shape", h_input_shape, kEqual, "c_shape", c_input_shape, prim_name); | |||||
| int64_t num_layers = lstm_prim->get_num_layers(); | |||||
| int64_t num_directions = lstm_prim->get_num_directions(); | |||||
| int64_t hidden_size = lstm_prim->get_hidden_size(); | |||||
| int64_t input_size = lstm_prim->get_input_size(); | |||||
| int64_t num_layers = GetValue<int64_t>(primitive->GetAttr(kNumLayers)); | |||||
| int64_t num_directions = GetValue<int64_t>(primitive->GetAttr(kNumDirections)); | |||||
| int64_t hidden_size = GetValue<int64_t>(primitive->GetAttr(kHidden_size)); | |||||
| int64_t input_size = input_x_size; | |||||
| CheckAndConvertUtils::CheckInteger("h_shape[0]", h_input_shape[0], kEqual, num_layers * num_directions, prim_name); | CheckAndConvertUtils::CheckInteger("h_shape[0]", h_input_shape[0], kEqual, num_layers * num_directions, prim_name); | ||||
| CheckAndConvertUtils::CheckInteger("h_shape[1]", h_input_shape[1], kEqual, x_input_shape[1], prim_name); | CheckAndConvertUtils::CheckInteger("h_shape[1]", h_input_shape[1], kEqual, x_input_shape[1], prim_name); | ||||
| CheckAndConvertUtils::CheckInteger("h_shape[2]", h_input_shape[2], kEqual, hidden_size, prim_name); | CheckAndConvertUtils::CheckInteger("h_shape[2]", h_input_shape[2], kEqual, hidden_size, prim_name); | ||||
| @@ -48,8 +54,8 @@ AbstractBasePtr LstmInfer(const PrimitivePtr &primitive, const std::vector<Abstr | |||||
| std::vector<int64_t> y_shape = {x_input_shape[0], x_input_shape[1], hidden_size * num_directions}; | std::vector<int64_t> y_shape = {x_input_shape[0], x_input_shape[1], hidden_size * num_directions}; | ||||
| int64_t type_size = 4; | int64_t type_size = 4; | ||||
| int64_t gates_ws_ld = lstm_prim->get_good_ld(hidden_size * 4, type_size); | |||||
| int64_t states_ws_ld = lstm_prim->get_good_ld(std::max(hidden_size, input_size), type_size); | |||||
| int64_t gates_ws_ld = get_good_ld(hidden_size * 4, type_size); | |||||
| int64_t states_ws_ld = get_good_ld(std::max(hidden_size, input_size), type_size); | |||||
| int64_t ws_gates_size = num_layers * num_directions * x_input_shape[0] * x_input_shape[1] * gates_ws_ld * type_size; | int64_t ws_gates_size = num_layers * num_directions * x_input_shape[0] * x_input_shape[1] * gates_ws_ld * type_size; | ||||
| int64_t ws_states_size = | int64_t ws_states_size = | ||||
| (num_layers + 1) * num_directions * (x_input_shape[0] + 1) * x_input_shape[1] * states_ws_ld * type_size; | (num_layers + 1) * num_directions * (x_input_shape[0] + 1) * x_input_shape[1] * states_ws_ld * type_size; | ||||
| @@ -99,26 +105,17 @@ void LSTM::set_input_size(const int64_t input_size) { | |||||
| CheckAndConvertUtils::CheckInteger(kInput_size, input_size, kGreaterThan, 0, this->name()); | CheckAndConvertUtils::CheckInteger(kInput_size, input_size, kGreaterThan, 0, this->name()); | ||||
| AddAttr(kInput_size, MakeValue(input_size)); | AddAttr(kInput_size, MakeValue(input_size)); | ||||
| } | } | ||||
| int64_t LSTM::get_input_size() const { | |||||
| auto value_ptr = this->GetAttr(kInput_size); | |||||
| return GetValue<int64_t>(value_ptr); | |||||
| } | |||||
| int64_t LSTM::get_input_size() const { return GetValue<int64_t>(GetAttr(kInput_size)); } | |||||
| void LSTM::set_hidden_size(const int64_t hidden_size) { | void LSTM::set_hidden_size(const int64_t hidden_size) { | ||||
| CheckAndConvertUtils::CheckInteger(kHidden_size, hidden_size, kGreaterThan, 0, this->name()); | CheckAndConvertUtils::CheckInteger(kHidden_size, hidden_size, kGreaterThan, 0, this->name()); | ||||
| AddAttr(kHidden_size, MakeValue(hidden_size)); | AddAttr(kHidden_size, MakeValue(hidden_size)); | ||||
| } | } | ||||
| int64_t LSTM::get_hidden_size() const { | |||||
| auto value_ptr = this->GetAttr(kHidden_size); | |||||
| return GetValue<int64_t>(value_ptr); | |||||
| } | |||||
| int64_t LSTM::get_hidden_size() const { return GetValue<int64_t>(GetAttr(kHidden_size)); } | |||||
| void LSTM::set_num_layers(const int64_t num_layers) { | void LSTM::set_num_layers(const int64_t num_layers) { | ||||
| CheckAndConvertUtils::CheckInteger(kNumLayers, num_layers, kGreaterThan, 0, this->name()); | CheckAndConvertUtils::CheckInteger(kNumLayers, num_layers, kGreaterThan, 0, this->name()); | ||||
| AddAttr(kNumLayers, MakeValue(num_layers)); | AddAttr(kNumLayers, MakeValue(num_layers)); | ||||
| } | } | ||||
| int64_t LSTM::get_num_layers() const { | |||||
| auto value_ptr = this->GetAttr(kNumLayers); | |||||
| return GetValue<int64_t>(value_ptr); | |||||
| } | |||||
| int64_t LSTM::get_num_layers() const { return GetValue<int64_t>(GetAttr(kNumLayers)); } | |||||
| void LSTM::set_has_bias(const bool has_bias) { AddAttr(kHasBias, MakeValue(has_bias)); } | void LSTM::set_has_bias(const bool has_bias) { AddAttr(kHasBias, MakeValue(has_bias)); } | ||||
| bool LSTM::get_has_bias() const { | bool LSTM::get_has_bias() const { | ||||
| auto value_ptr = this->GetAttr(kHasBias); | auto value_ptr = this->GetAttr(kHasBias); | ||||
| @@ -138,10 +135,7 @@ bool LSTM::get_bidirectional() const { | |||||
| return GetValue<bool>(value_ptr); | return GetValue<bool>(value_ptr); | ||||
| } | } | ||||
| void LSTM::set_num_directions(const int64_t num_directions) { AddAttr(kNumDirections, MakeValue(num_directions)); } | void LSTM::set_num_directions(const int64_t num_directions) { AddAttr(kNumDirections, MakeValue(num_directions)); } | ||||
| int64_t LSTM::get_num_directions() const { | |||||
| auto value_ptr = this->GetAttr(kNumDirections); | |||||
| return GetValue<int64_t>(value_ptr); | |||||
| } | |||||
| int64_t LSTM::get_num_directions() const { return GetValue<int64_t>(GetAttr(kNumDirections)); } | |||||
| void LSTM::set_zoneout_cell(float zoneout_cell) { AddAttr(kZoneoutCell, MakeValue(zoneout_cell)); } | void LSTM::set_zoneout_cell(float zoneout_cell) { AddAttr(kZoneoutCell, MakeValue(zoneout_cell)); } | ||||
| float LSTM::get_zoneout_cell() const { return GetValue<float>(this->GetAttr(kZoneoutCell)); } | float LSTM::get_zoneout_cell() const { return GetValue<float>(this->GetAttr(kZoneoutCell)); } | ||||
| @@ -167,14 +161,6 @@ void LSTM::Init(const int64_t input_size, const int64_t hidden_size, const int64 | |||||
| this->set_zoneout_hidden(zoneout_hidden); | this->set_zoneout_hidden(zoneout_hidden); | ||||
| } | } | ||||
| int64_t LSTM::get_good_ld(const int64_t dim, const int64_t type_size) { | |||||
| int64_t ld = ((dim + (64 / type_size) - 1) / (64 / type_size)) * (64 / type_size); | |||||
| if (ld * 256 == 0) { | |||||
| return ld + 64 / type_size; | |||||
| } | |||||
| return ld; | |||||
| } | |||||
| AbstractBasePtr LstmInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr LstmInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const std::vector<AbstractBasePtr> &input_args) { | const std::vector<AbstractBasePtr> &input_args) { | ||||
| return std::make_shared<abstract::AbstractTensor>(LstmInfer(primitive, input_args)); | return std::make_shared<abstract::AbstractTensor>(LstmInfer(primitive, input_args)); | ||||
| @@ -29,9 +29,7 @@ namespace ops { | |||||
| namespace { | namespace { | ||||
| abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto matrixdiag_prim = primitive->cast<PrimMatrixDiagPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(matrixdiag_prim); | |||||
| auto prim_name = matrixdiag_prim->name(); | |||||
| auto prim_name = primitive->name(); | |||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | ||||
| auto assist_shape = | auto assist_shape = | ||||
| CheckAndConvertUtils::ConvertShapePtrToShape("assist_shape", input_args[1]->BuildShape(), prim_name); | CheckAndConvertUtils::ConvertShapePtrToShape("assist_shape", input_args[1]->BuildShape(), prim_name); | ||||
| @@ -31,37 +31,25 @@ void MaxPool::set_pad_mode(const PadMode &pad_mode) { | |||||
| this->AddAttr(kPadMode, MakeValue(swi)); | this->AddAttr(kPadMode, MakeValue(swi)); | ||||
| } | } | ||||
| PadMode MaxPool::get_pad_mode() const { | |||||
| auto value_ptr = GetAttr(kPadMode); | |||||
| return PadMode(GetValue<int64_t>(value_ptr)); | |||||
| } | |||||
| PadMode MaxPool::get_pad_mode() const { return PadMode(GetValue<int64_t>(GetAttr(kPadMode))); } | |||||
| void MaxPool::set_kernel_size(const std::vector<int64_t> &kernel_size) { | void MaxPool::set_kernel_size(const std::vector<int64_t> &kernel_size) { | ||||
| this->AddAttr(kKernelSize, | this->AddAttr(kKernelSize, | ||||
| MakeValue(CheckAndConvertUtils::CheckPositiveVector(kKernelSize, kernel_size, this->name()))); | MakeValue(CheckAndConvertUtils::CheckPositiveVector(kKernelSize, kernel_size, this->name()))); | ||||
| } | } | ||||
| std::vector<int64_t> MaxPool::get_kernel_size() const { | |||||
| auto value_ptr = GetAttr(kKernelSize); | |||||
| return GetValue<std::vector<int64_t>>(value_ptr); | |||||
| } | |||||
| std::vector<int64_t> MaxPool::get_kernel_size() const { return GetValue<std::vector<int64_t>>(GetAttr(kKernelSize)); } | |||||
| void MaxPool::set_strides(const std::vector<int64_t> &strides) { | void MaxPool::set_strides(const std::vector<int64_t> &strides) { | ||||
| this->AddAttr(kStrides, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kStrides, strides, this->name()))); | this->AddAttr(kStrides, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kStrides, strides, this->name()))); | ||||
| } | } | ||||
| std::vector<int64_t> MaxPool::get_strides() const { | |||||
| auto value_ptr = GetAttr(kStrides); | |||||
| return GetValue<std::vector<int64_t>>(value_ptr); | |||||
| } | |||||
| std::vector<int64_t> MaxPool::get_strides() const { return GetValue<std::vector<int64_t>>(GetAttr(kStrides)); } | |||||
| void MaxPool::set_format(const Format &format) { | void MaxPool::set_format(const Format &format) { | ||||
| int64_t f = format; | int64_t f = format; | ||||
| this->AddAttr(kFormat, MakeValue(f)); | this->AddAttr(kFormat, MakeValue(f)); | ||||
| } | } | ||||
| Format MaxPool::get_format() const { | |||||
| auto value_ptr = GetAttr(kFormat); | |||||
| return Format(GetValue<int64_t>(value_ptr)); | |||||
| } | |||||
| Format MaxPool::get_format() const { return Format(GetValue<int64_t>(GetAttr(kFormat))); } | |||||
| void MaxPool::set_pad(const std::vector<int64_t> &pad) { this->AddAttr(kPad, MakeValue(pad)); } | void MaxPool::set_pad(const std::vector<int64_t> &pad) { this->AddAttr(kPad, MakeValue(pad)); } | ||||
| @@ -25,9 +25,7 @@ namespace ops { | |||||
| namespace { | namespace { | ||||
| abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto maximum_prim = primitive->cast<PrimMaximumPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(maximum_prim); | |||||
| auto op_name = maximum_prim->name(); | |||||
| auto op_name = primitive->name(); | |||||
| return BroadCastInferShape(op_name, input_args); | return BroadCastInferShape(op_name, input_args); | ||||
| } | } | ||||
| @@ -28,9 +28,7 @@ namespace ops { | |||||
| AbstractBasePtr MergeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr MergeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const std::vector<AbstractBasePtr> &input_args) { | const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto Merge_prim = primitive->cast<PrimMergePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(Merge_prim); | |||||
| auto op_name = Merge_prim->name(); | |||||
| auto op_name = primitive->name(); | |||||
| auto inputs_type = input_args[0]->BuildType()->cast<TuplePtr>()->elements(); | auto inputs_type = input_args[0]->BuildType()->cast<TuplePtr>()->elements(); | ||||
| auto inputs_shape = input_args[0]->BuildShape()->cast<abstract::TupleShapePtr>()->shape(); | auto inputs_shape = input_args[0]->BuildShape()->cast<abstract::TupleShapePtr>()->shape(); | ||||
| std::map<std::string, TypePtr> args; | std::map<std::string, TypePtr> args; | ||||
| @@ -24,16 +24,15 @@ namespace ops { | |||||
| namespace { | namespace { | ||||
| abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto mfcc_prim = primitive->cast<PrimMfccPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(mfcc_prim); | |||||
| auto prim_name = mfcc_prim->name(); | |||||
| auto prim_name = primitive->name(); | |||||
| auto first_input_shape = | auto first_input_shape = | ||||
| CheckAndConvertUtils::ConvertShapePtrToShape("first_input_shape", input_args[0]->BuildShape(), prim_name); | CheckAndConvertUtils::ConvertShapePtrToShape("first_input_shape", input_args[0]->BuildShape(), prim_name); | ||||
| auto second_input_shape = | auto second_input_shape = | ||||
| CheckAndConvertUtils::ConvertShapePtrToShape("second_input_shape", input_args[1]->BuildShape(), prim_name); | CheckAndConvertUtils::ConvertShapePtrToShape("second_input_shape", input_args[1]->BuildShape(), prim_name); | ||||
| CheckAndConvertUtils::CheckInteger("first input rank", first_input_shape.size(), kEqual, 3, prim_name); | CheckAndConvertUtils::CheckInteger("first input rank", first_input_shape.size(), kEqual, 3, prim_name); | ||||
| CheckAndConvertUtils::CheckInteger("second input rank", second_input_shape.size(), kEqual, 1, prim_name); | CheckAndConvertUtils::CheckInteger("second input rank", second_input_shape.size(), kEqual, 1, prim_name); | ||||
| std::vector<int64_t> out_shape = {first_input_shape[0], first_input_shape[1], mfcc_prim->get_dct_coeff_num()}; | |||||
| std::vector<int64_t> out_shape = {first_input_shape[0], first_input_shape[1], | |||||
| GetValue<int64_t>(primitive->GetAttr(kDctCoeffNum))}; | |||||
| return std::make_shared<abstract::Shape>(out_shape); | return std::make_shared<abstract::Shape>(out_shape); | ||||
| } | } | ||||
| @@ -83,10 +82,7 @@ int64_t Mfcc::get_filter_bank_channel_num() const { | |||||
| void Mfcc::set_dct_coeff_num(const int64_t dct_coeff_num) { this->AddAttr(kDctCoeffNum, MakeValue(dct_coeff_num)); } | void Mfcc::set_dct_coeff_num(const int64_t dct_coeff_num) { this->AddAttr(kDctCoeffNum, MakeValue(dct_coeff_num)); } | ||||
| int64_t Mfcc::get_dct_coeff_num() const { | |||||
| auto value_ptr = this->GetAttr(kDctCoeffNum); | |||||
| return GetValue<int64_t>(value_ptr); | |||||
| } | |||||
| int64_t Mfcc::get_dct_coeff_num() const { return GetValue<int64_t>(GetAttr(kDctCoeffNum)); } | |||||
| AbstractBasePtr MfccInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr MfccInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const std::vector<AbstractBasePtr> &input_args) { | const std::vector<AbstractBasePtr> &input_args) { | ||||
| @@ -31,9 +31,6 @@ void NonMaxSuppression::Init(const int64_t center_point_box) { this->set_center_ | |||||
| AbstractBasePtr NonMaxSuppressionInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr NonMaxSuppressionInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const std::vector<AbstractBasePtr> &input_args) { | const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | |||||
| auto non_max_suppression_prim = primitive->cast<PrimNonMaxSuppressionPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(non_max_suppression_prim); | |||||
| MS_LOG(INFO) << "NonMaxSuppression infer shape in runtime."; | MS_LOG(INFO) << "NonMaxSuppression infer shape in runtime."; | ||||
| return std::make_shared<abstract::AbstractTensor>(kInt32, std::vector<int64_t>{}); | return std::make_shared<abstract::AbstractTensor>(kInt32, std::vector<int64_t>{}); | ||||
| } | } | ||||
| @@ -25,17 +25,12 @@ namespace ops { | |||||
| void OneHot::Init(const int64_t axis) { this->set_axis(axis); } | void OneHot::Init(const int64_t axis) { this->set_axis(axis); } | ||||
| void OneHot::set_axis(const int64_t axis) { this->AddAttr(kAxis, MakeValue(axis)); } | void OneHot::set_axis(const int64_t axis) { this->AddAttr(kAxis, MakeValue(axis)); } | ||||
| int64_t OneHot::get_axis() const { | |||||
| auto value_ptr = this->GetAttr(kAxis); | |||||
| return GetValue<int64_t>(value_ptr); | |||||
| } | |||||
| int64_t OneHot::get_axis() const { return GetValue<int64_t>(GetAttr(kAxis)); } | |||||
| namespace { | namespace { | ||||
| abstract::ShapePtr OneHotInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | abstract::ShapePtr OneHotInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto OneHot_prim = primitive->cast<PrimOneHotPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(OneHot_prim); | |||||
| auto op_name = OneHot_prim->name(); | |||||
| int64_t axis = OneHot_prim->get_axis(); | |||||
| auto op_name = primitive->name(); | |||||
| int64_t axis = GetValue<int64_t>(primitive->GetAttr(kAxis)); | |||||
| auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), op_name); | auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), op_name); | ||||
| CheckAndConvertUtils::CheckInRange<int64_t>("axis", axis, kIncludeBoth, {-1, SizeToLong(in_shape.size())}, op_name); | CheckAndConvertUtils::CheckInRange<int64_t>("axis", axis, kIncludeBoth, {-1, SizeToLong(in_shape.size())}, op_name); | ||||
| auto depth_val = GetValue<int64_t>(input_args[1]->BuildValue()); | auto depth_val = GetValue<int64_t>(input_args[1]->BuildValue()); | ||||
| @@ -50,9 +45,7 @@ abstract::ShapePtr OneHotInferShape(const PrimitivePtr &primitive, const std::ve | |||||
| TypePtr OneHotInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { | TypePtr OneHotInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(prim); | MS_EXCEPTION_IF_NULL(prim); | ||||
| auto OneHot_prim = prim->cast<PrimOneHotPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(OneHot_prim); | |||||
| auto op_name = OneHot_prim->name(); | |||||
| auto op_name = prim->name(); | |||||
| CheckAndConvertUtils::CheckTensorTypeValid("indices", input_args[0]->BuildType(), {kInt32}, op_name); | CheckAndConvertUtils::CheckTensorTypeValid("indices", input_args[0]->BuildType(), {kInt32}, op_name); | ||||
| CheckAndConvertUtils::CheckTypeValid("depth", input_args[1]->BuildType(), {kInt8, kInt16, kInt32, kInt64}, op_name); | CheckAndConvertUtils::CheckTypeValid("depth", input_args[1]->BuildType(), {kInt8, kInt16, kInt32, kInt64}, op_name); | ||||
| std::map<std::string, TypePtr> args = {{"on_value", input_args[2]->BuildType()}, | std::map<std::string, TypePtr> args = {{"on_value", input_args[2]->BuildType()}, | ||||
| @@ -28,9 +28,7 @@ namespace ops { | |||||
| namespace { | namespace { | ||||
| abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto OnesLike_prim = primitive->cast<PrimOnesLikePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(OnesLike_prim); | |||||
| auto prim_name = OnesLike_prim->name(); | |||||
| auto prim_name = primitive->name(); | |||||
| auto input_shape = | auto input_shape = | ||||
| CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), prim_name); | CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), prim_name); | ||||
| return std::make_shared<abstract::Shape>(input_shape); | return std::make_shared<abstract::Shape>(input_shape); | ||||
| @@ -50,23 +50,18 @@ std::vector<int64_t> _get_pack_shape(std::vector<BaseShapePtr> x_shapes, std::ve | |||||
| void Pack::set_axis(const int64_t &axis) { AddAttr(kAxis, MakeValue(axis)); } | void Pack::set_axis(const int64_t &axis) { AddAttr(kAxis, MakeValue(axis)); } | ||||
| int64_t Pack::get_axis() const { | |||||
| auto value_ptr = this->GetAttr(kAxis); | |||||
| return GetValue<int64_t>(value_ptr); | |||||
| } | |||||
| int64_t Pack::get_axis() const { return GetValue<int64_t>(GetAttr(kAxis)); } | |||||
| void Pack::Init(const int64_t &axis) { this->set_axis(axis); } | void Pack::Init(const int64_t &axis) { this->set_axis(axis); } | ||||
| AbstractBasePtr PackInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr PackInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const std::vector<AbstractBasePtr> &input_args) { | const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto pack_prim = primitive->cast<PrimPackPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(pack_prim); | |||||
| auto prim_name = pack_prim->name(); | |||||
| auto prim_name = primitive->name(); | |||||
| auto x_shapes = input_args[0]->BuildShape()->cast<abstract::TupleShapePtr>()->shape(); | auto x_shapes = input_args[0]->BuildShape()->cast<abstract::TupleShapePtr>()->shape(); | ||||
| auto x_types = input_args[0]->BuildType()->cast<TuplePtr>()->elements(); | auto x_types = input_args[0]->BuildType()->cast<TuplePtr>()->elements(); | ||||
| auto all_shape = _get_pack_shape(x_shapes, x_types, pack_prim->get_axis(), prim_name); | |||||
| auto all_shape = _get_pack_shape(x_shapes, x_types, GetValue<int64_t>(primitive->GetAttr(kAxis)), prim_name); | |||||
| auto tensor_type = x_types[0]->cast<TensorTypePtr>(); | auto tensor_type = x_types[0]->cast<TensorTypePtr>(); | ||||
| MS_EXCEPTION_IF_NULL(tensor_type); | MS_EXCEPTION_IF_NULL(tensor_type); | ||||
| auto data_type = tensor_type->element(); | auto data_type = tensor_type->element(); | ||||
| @@ -23,10 +23,8 @@ namespace ops { | |||||
| namespace { | namespace { | ||||
| abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto pad_prim = primitive->cast<PrimPadPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(pad_prim); | |||||
| auto prim_name = pad_prim->name(); | |||||
| auto paddings_attr = pad_prim->get_paddings(); | |||||
| auto prim_name = primitive->name(); | |||||
| auto paddings_attr = GetValue<std::vector<std::vector<int64_t>>>(primitive->GetAttr(kPaddings)); | |||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), "Pad"); | auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), "Pad"); | ||||
| CheckAndConvertUtils::CheckInteger("paddings_size", paddings_attr.size(), kEqual, int64_t(2 * x_shape.size()), | CheckAndConvertUtils::CheckInteger("paddings_size", paddings_attr.size(), kEqual, int64_t(2 * x_shape.size()), | ||||
| prim_name); | prim_name); | ||||
| @@ -59,8 +57,7 @@ void Pad::set_paddings(const std::vector<std::vector<int64_t>> &paddings) { | |||||
| this->AddAttr(kPaddings, MakeValue(paddings)); | this->AddAttr(kPaddings, MakeValue(paddings)); | ||||
| } | } | ||||
| std::vector<std::vector<int64_t>> Pad::get_paddings() const { | std::vector<std::vector<int64_t>> Pad::get_paddings() const { | ||||
| auto value_ptr = GetAttr(kPaddings); | |||||
| return GetValue<std::vector<std::vector<int64_t>>>(value_ptr); | |||||
| return GetValue<std::vector<std::vector<int64_t>>>(GetAttr(kPaddings)); | |||||
| } | } | ||||
| AbstractBasePtr PadInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr PadInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const std::vector<AbstractBasePtr> &input_args) { | const std::vector<AbstractBasePtr> &input_args) { | ||||
| @@ -24,9 +24,7 @@ namespace ops { | |||||
| namespace { | namespace { | ||||
| abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto pow_prim = primitive->cast<PrimPowPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(pow_prim); | |||||
| auto op_name = pow_prim->name(); | |||||
| auto op_name = primitive->name(); | |||||
| return BroadCastInferShape(op_name, input_args); | return BroadCastInferShape(op_name, input_args); | ||||
| } | } | ||||
| @@ -24,10 +24,7 @@ namespace mindspore { | |||||
| namespace ops { | namespace ops { | ||||
| void PriorBox::set_min_sizes(const std::vector<int64_t> &min_sizes) { this->AddAttr(kMinSizes, MakeValue(min_sizes)); } | void PriorBox::set_min_sizes(const std::vector<int64_t> &min_sizes) { this->AddAttr(kMinSizes, MakeValue(min_sizes)); } | ||||
| std::vector<int64_t> PriorBox::get_min_sizes() const { | |||||
| auto value_ptr = GetAttr(kMinSizes); | |||||
| return GetValue<std::vector<int64_t>>(value_ptr); | |||||
| } | |||||
| std::vector<int64_t> PriorBox::get_min_sizes() const { return GetValue<std::vector<int64_t>>(GetAttr(kMinSizes)); } | |||||
| void PriorBox::set_max_sizes(const std::vector<int64_t> &max_sizes) { this->AddAttr(kMaxSizes, MakeValue(max_sizes)); } | void PriorBox::set_max_sizes(const std::vector<int64_t> &max_sizes) { this->AddAttr(kMaxSizes, MakeValue(max_sizes)); } | ||||
| @@ -40,10 +37,7 @@ void PriorBox::set_aspect_ratios(const std::vector<float> &aspect_ratios) { | |||||
| this->AddAttr(kAspectRatios, MakeValue(aspect_ratios)); | this->AddAttr(kAspectRatios, MakeValue(aspect_ratios)); | ||||
| } | } | ||||
| std::vector<float> PriorBox::get_aspect_ratios() const { | |||||
| auto value_ptr = GetAttr(kAspectRatios); | |||||
| return GetValue<std::vector<float>>(value_ptr); | |||||
| } | |||||
| std::vector<float> PriorBox::get_aspect_ratios() const { return GetValue<std::vector<float>>(GetAttr(kAspectRatios)); } | |||||
| void PriorBox::set_variances(const std::vector<float> &variances) { this->AddAttr(kVariances, MakeValue(variances)); } | void PriorBox::set_variances(const std::vector<float> &variances) { this->AddAttr(kVariances, MakeValue(variances)); } | ||||
| @@ -89,10 +83,7 @@ bool PriorBox::get_clip() const { | |||||
| void PriorBox::set_flip(const bool flip) { this->AddAttr(kFlip, MakeValue(flip)); } | void PriorBox::set_flip(const bool flip) { this->AddAttr(kFlip, MakeValue(flip)); } | ||||
| bool PriorBox::get_flip() const { | |||||
| auto value_ptr = GetAttr(kFlip); | |||||
| return GetValue<bool>(value_ptr); | |||||
| } | |||||
| bool PriorBox::get_flip() const { return GetValue<bool>(GetAttr(kFlip)); } | |||||
| void PriorBox::set_offset(const float offset) { this->AddAttr(kOffset, MakeValue(offset)); } | void PriorBox::set_offset(const float offset) { this->AddAttr(kOffset, MakeValue(offset)); } | ||||
| @@ -121,25 +112,23 @@ void PriorBox::Init(const std::vector<int64_t> &min_sizes, const std::vector<int | |||||
| AbstractBasePtr PriorBoxInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr PriorBoxInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const std::vector<AbstractBasePtr> &input_args) { | const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto PriorBox_prim = primitive->cast<PrimPriorBoxPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(PriorBox_prim); | |||||
| auto op_name = PriorBox_prim->name(); | |||||
| auto op_name = primitive->name(); | |||||
| MS_EXCEPTION_IF_NULL(input_args[0]); | MS_EXCEPTION_IF_NULL(input_args[0]); | ||||
| std::vector<float> different_aspect_ratios{1.0f}; | std::vector<float> different_aspect_ratios{1.0f}; | ||||
| auto aspect_ratios = PriorBox_prim->get_aspect_ratios(); | |||||
| auto aspect_ratios = GetValue<std::vector<float>>(primitive->GetAttr(kAspectRatios)); | |||||
| for (int64_t i = 0; i < (int64_t)aspect_ratios.size(); i++) { | for (int64_t i = 0; i < (int64_t)aspect_ratios.size(); i++) { | ||||
| float ratio = aspect_ratios[i]; | float ratio = aspect_ratios[i]; | ||||
| bool exist = std::any_of(different_aspect_ratios.begin(), different_aspect_ratios.end(), | bool exist = std::any_of(different_aspect_ratios.begin(), different_aspect_ratios.end(), | ||||
| [&](float v) { return abs(ratio - v) < 1e-6; }); | [&](float v) { return abs(ratio - v) < 1e-6; }); | ||||
| if (!exist) { | if (!exist) { | ||||
| different_aspect_ratios.emplace_back(ratio); | different_aspect_ratios.emplace_back(ratio); | ||||
| if (PriorBox_prim->get_flip()) { | |||||
| if (GetValue<bool>(primitive->GetAttr(kFlip))) { | |||||
| different_aspect_ratios.emplace_back(1.0f / ratio); | different_aspect_ratios.emplace_back(1.0f / ratio); | ||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| int64_t num_priors_box = | |||||
| PriorBox_prim->get_min_sizes().size() * different_aspect_ratios.size() + PriorBox_prim->get_max_sizes().size(); | |||||
| auto min_sizes = GetValue<std::vector<int64_t>>(primitive->GetAttr(kMinSizes)); | |||||
| int64_t num_priors_box = min_sizes.size() * different_aspect_ratios.size() + min_sizes.size(); | |||||
| auto input = CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), op_name); | auto input = CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), op_name); | ||||
| int64_t h = input[0] * input[1] * num_priors_box * 4; | int64_t h = input[0] * input[1] * num_priors_box * 4; | ||||
| std::vector<int64_t> output_shape{1, h, 1, 2}; | std::vector<int64_t> output_shape{1, h, 1, 2}; | ||||
| @@ -24,10 +24,7 @@ int64_t QuantDTypeCast::get_src_t() const { | |||||
| return GetValue<int64_t>(value_ptr); | return GetValue<int64_t>(value_ptr); | ||||
| } | } | ||||
| void QuantDTypeCast::set_dst_t(const int64_t dst_t) { AddAttr(kDstT, MakeValue(dst_t)); } | void QuantDTypeCast::set_dst_t(const int64_t dst_t) { AddAttr(kDstT, MakeValue(dst_t)); } | ||||
| int64_t QuantDTypeCast::get_dst_t() const { | |||||
| auto value_ptr = this->GetAttr(kDstT); | |||||
| return GetValue<int64_t>(value_ptr); | |||||
| } | |||||
| int64_t QuantDTypeCast::get_dst_t() const { return GetValue<int64_t>(GetAttr(kDstT)); } | |||||
| void QuantDTypeCast::Init(const int64_t src_t, const int64_t dst_t) { | void QuantDTypeCast::Init(const int64_t src_t, const int64_t dst_t) { | ||||
| this->set_src_t(src_t); | this->set_src_t(src_t); | ||||
| this->set_dst_t(dst_t); | this->set_dst_t(dst_t); | ||||
| @@ -35,16 +32,14 @@ void QuantDTypeCast::Init(const int64_t src_t, const int64_t dst_t) { | |||||
| AbstractBasePtr QuantDTypeCastInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr QuantDTypeCastInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const std::vector<AbstractBasePtr> &input_args) { | const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto QuantDTypeCast_prim = primitive->cast<PrimQuantDTypeCastPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(QuantDTypeCast_prim); | |||||
| auto op_name = QuantDTypeCast_prim->name(); | |||||
| auto op_name = primitive->name(); | |||||
| MS_EXCEPTION_IF_NULL(input_args[0]); | MS_EXCEPTION_IF_NULL(input_args[0]); | ||||
| auto input_type = input_args[0]->BuildType()->cast<TensorTypePtr>(); | auto input_type = input_args[0]->BuildType()->cast<TensorTypePtr>(); | ||||
| MS_EXCEPTION_IF_NULL(input_type); | MS_EXCEPTION_IF_NULL(input_type); | ||||
| MS_ASSERT(input_type->element() == TypeIdToType(TypeId(QuantDTypeCast_prim->get_dst_t()))); | |||||
| auto dst_type = GetValue<int64_t>(primitive->GetAttr(kDstT)); | |||||
| MS_ASSERT(input_type->element() == TypeIdToType(TypeId(dst_type))); | |||||
| auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), op_name); | auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), op_name); | ||||
| return std::make_shared<abstract::AbstractTensor>(TypeIdToType(TypeId(QuantDTypeCast_prim->get_dst_t())), | |||||
| input_shape); | |||||
| return std::make_shared<abstract::AbstractTensor>(TypeIdToType(TypeId(dst_type)), input_shape); | |||||
| } | } | ||||
| REGISTER_PRIMITIVE_C(kNameQuantDTypeCast, QuantDTypeCast); | REGISTER_PRIMITIVE_C(kNameQuantDTypeCast, QuantDTypeCast); | ||||
| } // namespace ops | } // namespace ops | ||||
| @@ -34,10 +34,7 @@ int64_t Range::get_d_type() const { | |||||
| void Range::set_start(const int64_t start) { this->AddAttr(kStart, MakeValue(start)); } | void Range::set_start(const int64_t start) { this->AddAttr(kStart, MakeValue(start)); } | ||||
| int64_t Range::get_start() const { | |||||
| auto value_ptr = GetAttr(kStart); | |||||
| return GetValue<int64_t>(value_ptr); | |||||
| } | |||||
| int64_t Range::get_start() const { return GetValue<int64_t>(GetAttr(kStart)); } | |||||
| void Range::set_limit(const int64_t limit) { this->AddAttr(kLimit, MakeValue(limit)); } | void Range::set_limit(const int64_t limit) { this->AddAttr(kLimit, MakeValue(limit)); } | ||||
| @@ -63,10 +60,7 @@ void Range::Init(const int64_t d_type, const int64_t start, const int64_t limit, | |||||
| AbstractBasePtr RangeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr RangeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const std::vector<AbstractBasePtr> &input_args) { | const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto prim = primitive->cast<PrimRangePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(prim); | |||||
| int64_t shape_size = 0; | int64_t shape_size = 0; | ||||
| TypeId dtype; | |||||
| if (input_args.size() == 3) { | if (input_args.size() == 3) { | ||||
| MS_EXCEPTION_IF_NULL(input_args[0]->BuildValue()); | MS_EXCEPTION_IF_NULL(input_args[0]->BuildValue()); | ||||
| MS_EXCEPTION_IF_NULL(input_args[1]->BuildValue()); | MS_EXCEPTION_IF_NULL(input_args[1]->BuildValue()); | ||||
| @@ -74,7 +68,7 @@ AbstractBasePtr RangeInfer(const abstract::AnalysisEnginePtr &, const PrimitiveP | |||||
| auto start_tensor = input_args[0]->BuildValue()->cast<tensor::TensorPtr>(); | auto start_tensor = input_args[0]->BuildValue()->cast<tensor::TensorPtr>(); | ||||
| auto limit_tensor = input_args[1]->BuildValue()->cast<tensor::TensorPtr>(); | auto limit_tensor = input_args[1]->BuildValue()->cast<tensor::TensorPtr>(); | ||||
| auto delta_tensor = input_args[2]->BuildValue()->cast<tensor::TensorPtr>(); | auto delta_tensor = input_args[2]->BuildValue()->cast<tensor::TensorPtr>(); | ||||
| dtype = static_cast<TypeId>(start_tensor->data_type_c()); | |||||
| auto dtype = start_tensor->data_type(); | |||||
| switch (dtype) { | switch (dtype) { | ||||
| case kNumberTypeInt: | case kNumberTypeInt: | ||||
| case kNumberTypeInt32: { | case kNumberTypeInt32: { | ||||
| @@ -97,9 +91,9 @@ AbstractBasePtr RangeInfer(const abstract::AnalysisEnginePtr &, const PrimitiveP | |||||
| } | } | ||||
| } | } | ||||
| } else { | } else { | ||||
| int64_t start = prim->get_start(); | |||||
| int64_t limit = prim->get_limit(); | |||||
| int64_t delta = prim->get_delta(); | |||||
| int64_t start = GetValue<int64_t>(primitive->GetAttr(kStart)); | |||||
| int64_t limit = GetValue<int64_t>(primitive->GetAttr(kLimit)); | |||||
| int64_t delta = GetValue<int64_t>(primitive->GetAttr(kDelta)); | |||||
| shape_size = | shape_size = | ||||
| std::max(static_cast<int64_t>(std::ceil(LongToDouble(limit - start) / delta)), static_cast<int64_t>(0)); | std::max(static_cast<int64_t>(std::ceil(LongToDouble(limit - start) / delta)), static_cast<int64_t>(0)); | ||||
| } | } | ||||
| @@ -21,9 +21,7 @@ namespace ops { | |||||
| namespace { | namespace { | ||||
| TypePtr RankInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { | TypePtr RankInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(prim); | MS_EXCEPTION_IF_NULL(prim); | ||||
| auto Rank_prim = prim->cast<PrimRankPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(Rank_prim); | |||||
| auto op_name = Rank_prim->name(); | |||||
| auto op_name = prim->name(); | |||||
| auto infer_dtype = input_args[0]->BuildType(); | auto infer_dtype = input_args[0]->BuildType(); | ||||
| CheckAndConvertUtils::CheckTensorTypeValid("x", infer_dtype, {kTensorType}, op_name); | CheckAndConvertUtils::CheckTensorTypeValid("x", infer_dtype, {kTensorType}, op_name); | ||||
| return kTypeNone; | return kTypeNone; | ||||
| @@ -29,9 +29,7 @@ namespace ops { | |||||
| AbstractBasePtr ReciprocalInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr ReciprocalInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const std::vector<AbstractBasePtr> &input_args) { | const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto reciprocal_prim = primitive->cast<PrimReciprocalPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(reciprocal_prim); | |||||
| auto prim_name = reciprocal_prim->name(); | |||||
| auto prim_name = primitive->name(); | |||||
| for (const auto &item : input_args) { | for (const auto &item : input_args) { | ||||
| MS_EXCEPTION_IF_NULL(item); | MS_EXCEPTION_IF_NULL(item); | ||||
| } | } | ||||
| @@ -70,13 +70,11 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A | |||||
| auto axis_value = input_args[1]->BuildValue(); | auto axis_value = input_args[1]->BuildValue(); | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto reduce_prim = primitive->cast<PrimReducePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(reduce_prim); | |||||
| auto prim_name = reduce_prim->name(); | |||||
| auto prim_name = primitive->name(); | |||||
| auto input_x_shape = | auto input_x_shape = | ||||
| CheckAndConvertUtils::ConvertShapePtrToShape("input_x_shape", input_args[0]->BuildShape(), prim_name); | CheckAndConvertUtils::ConvertShapePtrToShape("input_x_shape", input_args[0]->BuildShape(), prim_name); | ||||
| auto keep_dims = reduce_prim->get_keep_dims(); | |||||
| auto keep_dims = GetValue<bool>(primitive->GetAttr(kKeepDims)); | |||||
| auto out_shape = infer_shape_reduce(input_x_shape, axis_value, keep_dims, prim_name); | auto out_shape = infer_shape_reduce(input_x_shape, axis_value, keep_dims, prim_name); | ||||
| return std::make_shared<abstract::Shape>(out_shape); | return std::make_shared<abstract::Shape>(out_shape); | ||||
| @@ -93,10 +91,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> & | |||||
| void Reduce::set_keep_dims(const bool keep_dims) { this->AddAttr(kKeepDims, MakeValue(keep_dims)); } | void Reduce::set_keep_dims(const bool keep_dims) { this->AddAttr(kKeepDims, MakeValue(keep_dims)); } | ||||
| bool Reduce::get_keep_dims() const { | |||||
| auto value_ptr = GetAttr(kKeepDims); | |||||
| return GetValue<bool>(value_ptr); | |||||
| } | |||||
| bool Reduce::get_keep_dims() const { return GetValue<bool>(GetAttr(kKeepDims)); } | |||||
| void Reduce::Init(const bool keep_dims) { this->set_keep_dims(keep_dims); } | void Reduce::Init(const bool keep_dims) { this->set_keep_dims(keep_dims); } | ||||
| @@ -27,10 +27,7 @@ namespace mindspore { | |||||
| namespace ops { | namespace ops { | ||||
| void ResizeBilinear::set_size(const std::vector<int64_t> &size) { this->AddAttr(kSize, MakeValue(size)); } | void ResizeBilinear::set_size(const std::vector<int64_t> &size) { this->AddAttr(kSize, MakeValue(size)); } | ||||
| std::vector<int64_t> ResizeBilinear::get_size() const { | |||||
| auto value_ptr = GetAttr(kSize); | |||||
| return GetValue<std::vector<int64_t>>(value_ptr); | |||||
| } | |||||
| std::vector<int64_t> ResizeBilinear::get_size() const { return GetValue<std::vector<int64_t>>(GetAttr(kSize)); } | |||||
| void ResizeBilinear::set_align_corners(const bool align_corners) { | void ResizeBilinear::set_align_corners(const bool align_corners) { | ||||
| this->AddAttr(kAlignCorners, MakeValue(align_corners)); | this->AddAttr(kAlignCorners, MakeValue(align_corners)); | ||||
| @@ -48,9 +45,7 @@ void ResizeBilinear::Init(const std::vector<int64_t> &size, const bool align_cor | |||||
| AbstractBasePtr ResizeBilinearInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr ResizeBilinearInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const std::vector<AbstractBasePtr> &input_args) { | const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto resize_prim = primitive->cast<PrimResizeBilinearPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(resize_prim); | |||||
| auto prim_name = resize_prim->name(); | |||||
| auto prim_name = primitive->name(); | |||||
| CheckAndConvertUtils::CheckInteger("resize_bilinear_infer", input_args.size(), kEqual, 1, prim_name); | CheckAndConvertUtils::CheckInteger("resize_bilinear_infer", input_args.size(), kEqual, 1, prim_name); | ||||
| // Infer shape | // Infer shape | ||||
| @@ -58,7 +53,7 @@ AbstractBasePtr ResizeBilinearInfer(const abstract::AnalysisEnginePtr &, const P | |||||
| CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), prim_name); | CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), prim_name); | ||||
| CheckAndConvertUtils::CheckInteger("input_shape_rank", input_shape.size(), kEqual, 4, prim_name); | CheckAndConvertUtils::CheckInteger("input_shape_rank", input_shape.size(), kEqual, 4, prim_name); | ||||
| std::vector<int64_t> out_shape = {input_shape[0], input_shape[1]}; | std::vector<int64_t> out_shape = {input_shape[0], input_shape[1]}; | ||||
| auto size = resize_prim->get_size(); | |||||
| auto size = GetValue<std::vector<int64_t>>(primitive->GetAttr(kSize)); | |||||
| out_shape.insert(out_shape.end(), size.begin(), size.end()); | out_shape.insert(out_shape.end(), size.begin(), size.end()); | ||||
| // Infer type | // Infer type | ||||
| @@ -30,10 +30,7 @@ void ReverseSequence::Init(const int64_t seq_dim, const int64_t batch_dim) { | |||||
| void ReverseSequence::set_seq_dim(const int64_t seq_dim) { this->AddAttr(kSeqDim, MakeValue(seq_dim)); } | void ReverseSequence::set_seq_dim(const int64_t seq_dim) { this->AddAttr(kSeqDim, MakeValue(seq_dim)); } | ||||
| void ReverseSequence::set_batch_dim(const int64_t batch_dim) { this->AddAttr(kBatchDim, MakeValue(batch_dim)); } | void ReverseSequence::set_batch_dim(const int64_t batch_dim) { this->AddAttr(kBatchDim, MakeValue(batch_dim)); } | ||||
| int64_t ReverseSequence::get_seq_dim() const { | |||||
| auto value_ptr = this->GetAttr(kSeqDim); | |||||
| return GetValue<int64_t>(value_ptr); | |||||
| } | |||||
| int64_t ReverseSequence::get_seq_dim() const { return GetValue<int64_t>(GetAttr(kSeqDim)); } | |||||
| int64_t ReverseSequence::get_batch_dim() const { | int64_t ReverseSequence::get_batch_dim() const { | ||||
| auto value_ptr = this->GetAttr(kBatchDim); | auto value_ptr = this->GetAttr(kBatchDim); | ||||
| return GetValue<int64_t>(value_ptr); | return GetValue<int64_t>(value_ptr); | ||||
| @@ -41,9 +38,7 @@ int64_t ReverseSequence::get_batch_dim() const { | |||||
| AbstractBasePtr ReverseSequenceInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr ReverseSequenceInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const std::vector<AbstractBasePtr> &input_args) { | const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto reverse_prim = primitive->cast<PrimReverseSequence>(); | |||||
| MS_EXCEPTION_IF_NULL(reverse_prim); | |||||
| auto prim_name = reverse_prim->name(); | |||||
| auto prim_name = primitive->name(); | |||||
| CheckAndConvertUtils::CheckInteger("input numbers", input_args.size(), kEqual, 2, prim_name); | CheckAndConvertUtils::CheckInteger("input numbers", input_args.size(), kEqual, 2, prim_name); | ||||
| for (const auto &item : input_args) { | for (const auto &item : input_args) { | ||||
| MS_EXCEPTION_IF_NULL(item); | MS_EXCEPTION_IF_NULL(item); | ||||
| @@ -53,8 +48,8 @@ AbstractBasePtr ReverseSequenceInfer(const abstract::AnalysisEnginePtr &, const | |||||
| CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), prim_name); | CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), prim_name); | ||||
| auto seq_lengths = | auto seq_lengths = | ||||
| CheckAndConvertUtils::ConvertShapePtrToShape("seq_lengths", input_args[1]->BuildShape(), prim_name); | CheckAndConvertUtils::ConvertShapePtrToShape("seq_lengths", input_args[1]->BuildShape(), prim_name); | ||||
| auto seq_dim = reverse_prim->get_seq_dim(); | |||||
| auto batch_dim = reverse_prim->get_batch_dim(); | |||||
| auto seq_dim = GetValue<int64_t>(primitive->GetAttr(kSeqDim)); | |||||
| auto batch_dim = GetValue<int64_t>(primitive->GetAttr(kBatchDim)); | |||||
| CheckAndConvertUtils::CheckInteger("seq_dim", seq_dim, kLessEqual, input_shape.size(), prim_name); | CheckAndConvertUtils::CheckInteger("seq_dim", seq_dim, kLessEqual, input_shape.size(), prim_name); | ||||
| CheckAndConvertUtils::CheckInteger("batch_dim", batch_dim, kLessEqual, input_shape.size(), prim_name); | CheckAndConvertUtils::CheckInteger("batch_dim", batch_dim, kLessEqual, input_shape.size(), prim_name); | ||||
| CheckAndConvertUtils::CheckInteger("batch_dim", batch_dim, kNotEqual, seq_dim, prim_name); | CheckAndConvertUtils::CheckInteger("batch_dim", batch_dim, kNotEqual, seq_dim, prim_name); | ||||
| @@ -24,9 +24,7 @@ namespace ops { | |||||
| namespace { | namespace { | ||||
| abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto reverseV2_prim = primitive->cast<PrimReverseV2Ptr>(); | |||||
| MS_EXCEPTION_IF_NULL(reverseV2_prim); | |||||
| auto prim_name = reverseV2_prim->name(); | |||||
| auto prim_name = primitive->name(); | |||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | ||||
| return std::make_shared<abstract::Shape>(x_shape); | return std::make_shared<abstract::Shape>(x_shape); | ||||
| } | } | ||||
| @@ -24,13 +24,11 @@ namespace ops { | |||||
| namespace { | namespace { | ||||
| abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto rfft_prim = primitive->cast<PrimRfftPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(rfft_prim); | |||||
| auto prim_name = rfft_prim->name(); | |||||
| auto prim_name = primitive->name(); | |||||
| auto first_input_shape = | auto first_input_shape = | ||||
| CheckAndConvertUtils::ConvertShapePtrToShape("first_input_shape", input_args[0]->BuildShape(), prim_name); | CheckAndConvertUtils::ConvertShapePtrToShape("first_input_shape", input_args[0]->BuildShape(), prim_name); | ||||
| auto out_shape = first_input_shape; | auto out_shape = first_input_shape; | ||||
| out_shape[out_shape.size() - 1] = rfft_prim->get_fft_length() / 2 + 1; | |||||
| out_shape[out_shape.size() - 1] = GetValue<int64_t>(primitive->GetAttr(kFftLength)) / 2 + 1; | |||||
| out_shape.push_back(2); | out_shape.push_back(2); | ||||
| return std::make_shared<abstract::Shape>(out_shape); | return std::make_shared<abstract::Shape>(out_shape); | ||||
| } | } | ||||
| @@ -47,10 +45,7 @@ void Rfft::Init(const int64_t fft_length) { this->set_fft_length(fft_length); } | |||||
| void Rfft::set_fft_length(const int64_t fft_length) { this->AddAttr(kFftLength, MakeValue(fft_length)); } | void Rfft::set_fft_length(const int64_t fft_length) { this->AddAttr(kFftLength, MakeValue(fft_length)); } | ||||
| int64_t Rfft::get_fft_length() const { | |||||
| auto value_ptr = this->GetAttr(kFftLength); | |||||
| return GetValue<int64_t>(value_ptr); | |||||
| } | |||||
| int64_t Rfft::get_fft_length() const { return GetValue<int64_t>(GetAttr(kFftLength)); } | |||||
| AbstractBasePtr RfftInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr RfftInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const std::vector<AbstractBasePtr> &input_args) { | const std::vector<AbstractBasePtr> &input_args) { | ||||
| @@ -27,10 +27,7 @@ namespace mindspore { | |||||
| namespace ops { | namespace ops { | ||||
| void ROIPooling::set_pooled_h(const int64_t pooled_h) { this->AddAttr(kPooledH, MakeValue(pooled_h)); } | void ROIPooling::set_pooled_h(const int64_t pooled_h) { this->AddAttr(kPooledH, MakeValue(pooled_h)); } | ||||
| int64_t ROIPooling::get_pooled_h() const { | |||||
| auto value_ptr = GetAttr(kPooledH); | |||||
| return GetValue<int64_t>(value_ptr); | |||||
| } | |||||
| int64_t ROIPooling::get_pooled_h() const { return GetValue<int64_t>(GetAttr(kPooledH)); } | |||||
| void ROIPooling::set_pooled_w(const int64_t pooled_w) { this->AddAttr(kPooledW, MakeValue(pooled_w)); } | void ROIPooling::set_pooled_w(const int64_t pooled_w) { this->AddAttr(kPooledW, MakeValue(pooled_w)); } | ||||
| @@ -54,9 +51,7 @@ void ROIPooling::Init(const int64_t pooled_h, const int64_t pooled_w, const floa | |||||
| AbstractBasePtr ROIPoolingInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr ROIPoolingInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const std::vector<AbstractBasePtr> &input_args) { | const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto roi_prim = primitive->cast<PrimROIPoolingPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(roi_prim); | |||||
| auto prim_name = roi_prim->name(); | |||||
| auto prim_name = primitive->name(); | |||||
| CheckAndConvertUtils::CheckInteger("roi_pooling_infer", input_args.size(), kEqual, 2, prim_name); | CheckAndConvertUtils::CheckInteger("roi_pooling_infer", input_args.size(), kEqual, 2, prim_name); | ||||
| MS_EXCEPTION_IF_NULL(input_args[0]); | MS_EXCEPTION_IF_NULL(input_args[0]); | ||||
| MS_EXCEPTION_IF_NULL(input_args[1]); | MS_EXCEPTION_IF_NULL(input_args[1]); | ||||
| @@ -65,8 +60,8 @@ AbstractBasePtr ROIPoolingInfer(const abstract::AnalysisEnginePtr &, const Primi | |||||
| auto output_data_type = input_args[0]->BuildType()->cast<TensorTypePtr>()->element(); | auto output_data_type = input_args[0]->BuildType()->cast<TensorTypePtr>()->element(); | ||||
| // Infer shape | // Infer shape | ||||
| auto new_h = roi_prim->get_pooled_h(); | |||||
| auto new_w = roi_prim->get_pooled_w(); | |||||
| auto new_h = GetValue<int64_t>(primitive->GetAttr(kPooledH)); | |||||
| auto new_w = GetValue<int64_t>(primitive->GetAttr(kPooledW)); | |||||
| auto input_shape = | auto input_shape = | ||||
| CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), prim_name); | CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), prim_name); | ||||
| auto roi_shape = CheckAndConvertUtils::ConvertShapePtrToShape("roi_shape", input_args[1]->BuildShape(), prim_name); | auto roi_shape = CheckAndConvertUtils::ConvertShapePtrToShape("roi_shape", input_args[1]->BuildShape(), prim_name); | ||||
| @@ -29,9 +29,7 @@ namespace ops { | |||||
| namespace { | namespace { | ||||
| abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto rsqrt_prim = primitive->cast<PrimRsqrtPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(rsqrt_prim); | |||||
| auto prim_name = rsqrt_prim->name(); | |||||
| auto prim_name = primitive->name(); | |||||
| auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("in_shape", input_args[0]->GetShapeTrack(), prim_name); | auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("in_shape", input_args[0]->GetShapeTrack(), prim_name); | ||||
| CheckAndConvertUtils::CheckInteger("input shape", in_shape.size(), kEqual, 1, prim_name); | CheckAndConvertUtils::CheckInteger("input shape", in_shape.size(), kEqual, 1, prim_name); | ||||
| return std::make_shared<abstract::Shape>(in_shape); | return std::make_shared<abstract::Shape>(in_shape); | ||||
| @@ -29,9 +29,7 @@ namespace ops { | |||||
| AbstractBasePtr SigmoidCrossEntropyWithLogitsInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr SigmoidCrossEntropyWithLogitsInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const std::vector<AbstractBasePtr> &input_args) { | const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto sigmoid_prim = primitive->cast<PrimSigmoidCrossEntropyWithLogitsPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(sigmoid_prim); | |||||
| auto prim_name = sigmoid_prim->name(); | |||||
| auto prim_name = primitive->name(); | |||||
| CheckAndConvertUtils::CheckInteger("sigmoid_cross_extropy_with_logits_infer", input_args.size(), kEqual, 2, | CheckAndConvertUtils::CheckInteger("sigmoid_cross_extropy_with_logits_infer", input_args.size(), kEqual, 2, | ||||
| prim_name); | prim_name); | ||||
| @@ -23,9 +23,7 @@ namespace ops { | |||||
| namespace { | namespace { | ||||
| abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto SkipGram_prim = primitive->cast<PrimSkipGramPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(SkipGram_prim); | |||||
| auto prim_name = SkipGram_prim->name(); | |||||
| auto prim_name = primitive->name(); | |||||
| if (input_args.size() != 1) { | if (input_args.size() != 1) { | ||||
| MS_LOG(ERROR) << "Skip Gram should have one input"; | MS_LOG(ERROR) << "Skip Gram should have one input"; | ||||
| } | } | ||||
| @@ -36,9 +36,7 @@ float SmoothL1Loss::get_beta() const { | |||||
| AbstractBasePtr SmoothL1LossInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr SmoothL1LossInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const std::vector<AbstractBasePtr> &input_args) { | const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto smooth_prim = primitive->cast<PrimSmoothL1LossPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(smooth_prim); | |||||
| auto prim_name = smooth_prim->name(); | |||||
| auto prim_name = primitive->name(); | |||||
| CheckAndConvertUtils::CheckInteger("smooth_l1_loss_infer", input_args.size(), kEqual, 2, prim_name); | CheckAndConvertUtils::CheckInteger("smooth_l1_loss_infer", input_args.size(), kEqual, 2, prim_name); | ||||
| // Infer shape | // Infer shape | ||||
| @@ -29,9 +29,7 @@ namespace ops { | |||||
| AbstractBasePtr SoftmaxCrossEntropyWithLogitsInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr SoftmaxCrossEntropyWithLogitsInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const std::vector<AbstractBasePtr> &input_args) { | const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto softmax_prim = primitive->cast<PrimSoftmaxCrossEntropyWithLogitsPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(softmax_prim); | |||||
| auto prim_name = softmax_prim->name(); | |||||
| auto prim_name = primitive->name(); | |||||
| CheckAndConvertUtils::CheckInteger("softmax_cross_entropy_with_logics_infer", input_args.size(), kEqual, 2, | CheckAndConvertUtils::CheckInteger("softmax_cross_entropy_with_logics_infer", input_args.size(), kEqual, 2, | ||||
| prim_name); | prim_name); | ||||
| @@ -28,15 +28,13 @@ namespace ops { | |||||
| namespace { | namespace { | ||||
| abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto spacetobatch_prim = primitive->cast<PrimSpaceToBatchPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(spacetobatch_prim); | |||||
| auto prim_name = spacetobatch_prim->name(); | |||||
| auto prim_name = primitive->name(); | |||||
| auto input_shape = | auto input_shape = | ||||
| CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), prim_name); | CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), prim_name); | ||||
| CheckAndConvertUtils::CheckInteger("input shape", input_shape.size(), kEqual, 4, prim_name); | CheckAndConvertUtils::CheckInteger("input shape", input_shape.size(), kEqual, 4, prim_name); | ||||
| std::vector<int64_t> output_shape(input_shape.size()); | std::vector<int64_t> output_shape(input_shape.size()); | ||||
| auto block_shape_vector = spacetobatch_prim->get_block_size(); | |||||
| auto paddings = spacetobatch_prim->get_paddings(); | |||||
| auto block_shape_vector = GetValue<std::vector<int64_t>>(primitive->GetAttr(kBlockSize)); | |||||
| auto paddings = GetValue<std::vector<std::vector<int64_t>>>(primitive->GetAttr(kPaddings)); | |||||
| for (size_t i = 0; i < 2; i++) { | for (size_t i = 0; i < 2; i++) { | ||||
| auto padded = output_shape[i + 2] + paddings[i][0] + paddings[i][1]; | auto padded = output_shape[i + 2] + paddings[i][0] + paddings[i][1]; | ||||
| CheckAndConvertUtils::CheckInteger("padded shape", padded % block_shape_vector.size(), kEqual, 0, prim_name); | CheckAndConvertUtils::CheckInteger("padded shape", padded % block_shape_vector.size(), kEqual, 0, prim_name); | ||||
| @@ -77,8 +75,7 @@ void SpaceToBatch::set_block_size(const std::vector<int64_t> block_size) { | |||||
| } | } | ||||
| std::vector<int64_t> SpaceToBatch::get_block_size() const { | std::vector<int64_t> SpaceToBatch::get_block_size() const { | ||||
| auto value_ptr = GetAttr(kBlockSize); | |||||
| return GetValue<std::vector<int64_t>>(value_ptr); | |||||
| return GetValue<std::vector<int64_t>>(GetAttr(kBlockSize)); | |||||
| } | } | ||||
| void SpaceToBatch::Init(const std::vector<int64_t> block_size, const std::vector<std::vector<int64_t>> &paddings) { | void SpaceToBatch::Init(const std::vector<int64_t> block_size, const std::vector<std::vector<int64_t>> &paddings) { | ||||
| @@ -28,16 +28,14 @@ namespace ops { | |||||
| namespace { | namespace { | ||||
| abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto space_prim = primitive->cast<PrimSpaceToBatchNDPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(space_prim); | |||||
| auto prim_name = space_prim->name(); | |||||
| auto prim_name = primitive->name(); | |||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | ||||
| CheckAndConvertUtils::CheckInteger("input_x rank", x_shape.size(), kEqual, 4, prim_name); | CheckAndConvertUtils::CheckInteger("input_x rank", x_shape.size(), kEqual, 4, prim_name); | ||||
| auto out_shape = x_shape; | auto out_shape = x_shape; | ||||
| int64_t block_shape_prod = 1; | int64_t block_shape_prod = 1; | ||||
| const int64_t offset = 2; | const int64_t offset = 2; | ||||
| auto block_shape = space_prim->get_block_shape(); | |||||
| auto padding = space_prim->get_paddings(); | |||||
| auto block_shape = GetValue<std::vector<int64_t>>(primitive->GetAttr(kBlockShape)); | |||||
| auto padding = GetValue<std::vector<std::vector<int64_t>>>(primitive->GetAttr(kPaddings)); | |||||
| int64_t size = block_shape.size(); | int64_t size = block_shape.size(); | ||||
| for (int64_t i = 0; i < size; i++) { | for (int64_t i = 0; i < size; i++) { | ||||
| int64_t padded = out_shape[i + offset] + padding[i][0] + padding[i][1]; | int64_t padded = out_shape[i + offset] + padding[i][0] + padding[i][1]; | ||||
| @@ -87,8 +85,7 @@ void SpaceToBatchND::set_block_shape(std::vector<int64_t> block_shape) { | |||||
| } | } | ||||
| std::vector<int64_t> SpaceToBatchND::get_block_shape() const { | std::vector<int64_t> SpaceToBatchND::get_block_shape() const { | ||||
| auto value_ptr = GetAttr(kBlockShape); | |||||
| return GetValue<std::vector<int64_t>>(value_ptr); | |||||
| return GetValue<std::vector<int64_t>>(GetAttr(kBlockShape)); | |||||
| } | } | ||||
| void SpaceToBatchND::Init(std::vector<int64_t> block_shape, std::vector<std::vector<int64_t>> paddings) { | void SpaceToBatchND::Init(std::vector<int64_t> block_shape, std::vector<std::vector<int64_t>> paddings) { | ||||
| @@ -31,18 +31,13 @@ void SparseSoftmaxCrossEntropyWithLogits::set_is_grad(const bool is_grad) { | |||||
| this->AddAttr(kIsGrad, MakeValue(is_grad)); | this->AddAttr(kIsGrad, MakeValue(is_grad)); | ||||
| } | } | ||||
| bool SparseSoftmaxCrossEntropyWithLogits::get_is_grad() const { | |||||
| auto value_ptr = GetAttr(kIsGrad); | |||||
| return GetValue<bool>(value_ptr); | |||||
| } | |||||
| bool SparseSoftmaxCrossEntropyWithLogits::get_is_grad() const { return GetValue<bool>(GetAttr(kIsGrad)); } | |||||
| AbstractBasePtr SparseSoftmaxCrossEntropyWithLogitsInfer(const abstract::AnalysisEnginePtr &, | AbstractBasePtr SparseSoftmaxCrossEntropyWithLogitsInfer(const abstract::AnalysisEnginePtr &, | ||||
| const PrimitivePtr &primitive, | const PrimitivePtr &primitive, | ||||
| const std::vector<AbstractBasePtr> &input_args) { | const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto sparse_softmax_cross_entropy_prim = primitive->cast<PrimSparseSoftmaxCrossEntropyWithLogitsPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(sparse_softmax_cross_entropy_prim); | |||||
| auto prim_name = sparse_softmax_cross_entropy_prim->name(); | |||||
| auto prim_name = primitive->name(); | |||||
| CheckAndConvertUtils::CheckInteger("input numbers", input_args.size(), kEqual, 2, prim_name); | CheckAndConvertUtils::CheckInteger("input numbers", input_args.size(), kEqual, 2, prim_name); | ||||
| for (const auto &item : input_args) { | for (const auto &item : input_args) { | ||||
| MS_EXCEPTION_IF_NULL(item); | MS_EXCEPTION_IF_NULL(item); | ||||
| @@ -51,7 +46,7 @@ AbstractBasePtr SparseSoftmaxCrossEntropyWithLogitsInfer(const abstract::Analysi | |||||
| auto input_shape = | auto input_shape = | ||||
| CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), prim_name); | CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), prim_name); | ||||
| std::vector<int64_t> output_shape; | std::vector<int64_t> output_shape; | ||||
| if (sparse_softmax_cross_entropy_prim->get_is_grad() != 0) { | |||||
| if (GetValue<bool>(primitive->GetAttr(kIsGrad)) != 0) { | |||||
| output_shape = input_shape; | output_shape = input_shape; | ||||
| } else { | } else { | ||||
| output_shape.push_back(1); | output_shape.push_back(1); | ||||
| @@ -27,9 +27,7 @@ namespace ops { | |||||
| AbstractBasePtr SparseToDenseInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr SparseToDenseInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const std::vector<AbstractBasePtr> &input_args) { | const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto spasetodense_prim = primitive->cast<PrimSparseToDensePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(spasetodense_prim); | |||||
| auto prim_name = spasetodense_prim->name(); | |||||
| auto prim_name = primitive->name(); | |||||
| CheckAndConvertUtils::CheckInteger("input numbers", input_args.size(), kEqual, 3, prim_name); | CheckAndConvertUtils::CheckInteger("input numbers", input_args.size(), kEqual, 3, prim_name); | ||||
| for (const auto &item : input_args) { | for (const auto &item : input_args) { | ||||
| MS_EXCEPTION_IF_NULL(item); | MS_EXCEPTION_IF_NULL(item); | ||||
| @@ -27,9 +27,7 @@ namespace ops { | |||||
| namespace { | namespace { | ||||
| abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto squared_prim = primitive->cast<PrimSquaredDifferencePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(squared_prim); | |||||
| auto op_name = squared_prim->name(); | |||||
| auto op_name = primitive->name(); | |||||
| return BroadCastInferShape(op_name, input_args); | return BroadCastInferShape(op_name, input_args); | ||||
| } | } | ||||