| @@ -304,8 +304,6 @@ AbstractBasePtr InferImplArgMaxWithValue(const AnalysisEnginePtr &, const Primit | |||||
| const AbstractBasePtrList &args_spec_list); | const AbstractBasePtrList &args_spec_list); | ||||
| AbstractBasePtr InferImplSparseSoftmaxCrossEntropyWithLogits(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr InferImplSparseSoftmaxCrossEntropyWithLogits(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const AbstractBasePtrList &args_spec_list); | const AbstractBasePtrList &args_spec_list); | ||||
| AbstractBasePtr InferImplDType(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list); | |||||
| AbstractBasePtr InferImplLoad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr InferImplLoad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const AbstractBasePtrList &args_spec_list); | const AbstractBasePtrList &args_spec_list); | ||||
| AbstractBasePtr InferImplAssign(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr InferImplAssign(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| @@ -545,22 +545,6 @@ AbstractBasePtr InferImplGpuConvertToDynamicShape(const AnalysisEnginePtr &, con | |||||
| return std::make_shared<AbstractTensor>(input->element(), shape); | return std::make_shared<AbstractTensor>(input->element(), shape); | ||||
| } | } | ||||
| AbstractBasePtr InferImplDType(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list) { | |||||
| MS_EXCEPTION_IF_NULL(primitive); | |||||
| auto op_name = primitive->name(); | |||||
| CheckArgsSize(op_name, args_spec_list, 1); | |||||
| MS_EXCEPTION_IF_NULL(args_spec_list[0]); | |||||
| auto type = args_spec_list[0]->BuildType(); | |||||
| MS_EXCEPTION_IF_NULL(type); | |||||
| auto tensor_type = type->cast<TensorTypePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(tensor_type); | |||||
| auto value = tensor_type->element(); | |||||
| auto abstract = std::make_shared<abstract::AbstractType>(value); | |||||
| abstract->set_value(value); | |||||
| return abstract; | |||||
| } | |||||
| AbstractBasePtr InferImplLoad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr InferImplLoad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const AbstractBasePtrList &args_spec_list) { | const AbstractBasePtrList &args_spec_list) { | ||||
| // Inputs: Ref/Tensor, universal | // Inputs: Ref/Tensor, universal | ||||
| @@ -34,22 +34,33 @@ void LogSoftmax::Init(const int64_t axis) { this->set_axis(axis); } | |||||
| abstract::ShapePtr LogSoftmaxInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | abstract::ShapePtr LogSoftmaxInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto LogSoftmax_prim = primitive->cast<PrimLogSoftmaxPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(LogSoftmax_prim); | |||||
| auto op_name = LogSoftmax_prim->name(); | |||||
| auto axis = LogSoftmax_prim->get_axis(); | |||||
| auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->GetShapeTrack(), op_name); | |||||
| auto op_name = primitive->name(); | |||||
| auto axis = GetValue<int64_t>(primitive->GetAttr(kAxis)); | |||||
| CheckAndConvertUtils::CheckInteger("log_softmax infer", input_args.size(), kEqual, 1, op_name); | |||||
| MS_EXCEPTION_IF_NULL(input_args[0]); | |||||
| auto shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape()); | |||||
| if (shape_map.empty()) { | |||||
| // Scalar input, has no shape | |||||
| return std::make_shared<abstract::Shape>(std::vector<int64_t>()); | |||||
| } | |||||
| auto in_shape = shape_map[kShape]; | |||||
| auto min_shape = shape_map[kMinShape]; | |||||
| auto max_shape = shape_map[kMaxShape]; | |||||
| auto rank = SizeToLong(in_shape.size()); | auto rank = SizeToLong(in_shape.size()); | ||||
| CheckAndConvertUtils::CheckInRange<int64_t>("axis", axis, kIncludeLeft, {-rank, rank}, op_name); | CheckAndConvertUtils::CheckInRange<int64_t>("axis", axis, kIncludeLeft, {-rank, rank}, op_name); | ||||
| if (min_shape.size() != 0 && max_shape.size() != 0) { | |||||
| return std::make_shared<abstract::Shape>(in_shape, min_shape, max_shape); | |||||
| } | |||||
| return std::make_shared<abstract::Shape>(in_shape); | return std::make_shared<abstract::Shape>(in_shape); | ||||
| } | } | ||||
| TypePtr LogSoftmaxInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { | TypePtr LogSoftmaxInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { | ||||
| if (std::any_of(input_args.begin(), input_args.end(), [](const AbstractBasePtr &a) { return a == nullptr; })) { | |||||
| MS_LOG(EXCEPTION) << "nullptr"; | |||||
| } | |||||
| MS_EXCEPTION_IF_NULL(prim); | |||||
| auto op_name = prim->name(); | |||||
| CheckAndConvertUtils::CheckInteger("log_softmax infer", input_args.size(), kEqual, 1, op_name); | |||||
| MS_EXCEPTION_IF_NULL(input_args[0]); | |||||
| const std::set<TypePtr> valid_types = {kFloat16, kFloat32}; | const std::set<TypePtr> valid_types = {kFloat16, kFloat32}; | ||||
| return CheckAndConvertUtils::CheckTensorTypeValid("x", input_args[0]->BuildType(), valid_types, prim->name()); | |||||
| return CheckAndConvertUtils::CheckTensorTypeValid("x", input_args[0]->BuildType(), valid_types, op_name); | |||||
| } | } | ||||
| AbstractBasePtr LogSoftmaxInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr LogSoftmaxInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| @@ -57,6 +68,6 @@ AbstractBasePtr LogSoftmaxInfer(const abstract::AnalysisEnginePtr &, const Primi | |||||
| return std::make_shared<abstract::AbstractTensor>(LogSoftmaxInferType(primitive, input_args), | return std::make_shared<abstract::AbstractTensor>(LogSoftmaxInferType(primitive, input_args), | ||||
| LogSoftmaxInferShape(primitive, input_args)->shape()); | LogSoftmaxInferShape(primitive, input_args)->shape()); | ||||
| } | } | ||||
| REGISTER_PRIMITIVE_C(kNameLogSoftmax, LogSoftmax); | |||||
| REGISTER_PRIMITIVE_EVAL_IMPL(LogSoftmax, prim::kPrimLogSoftmax, LogSoftmaxInfer, nullptr, true); | |||||
| } // namespace ops | } // namespace ops | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -37,9 +37,6 @@ class LogSoftmax : public PrimitiveC { | |||||
| void set_axis(const int64_t axis); | void set_axis(const int64_t axis); | ||||
| int64_t get_axis() const; | int64_t get_axis() const; | ||||
| }; | }; | ||||
| AbstractBasePtr LogSoftmaxInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const std::vector<AbstractBasePtr> &input_args); | |||||
| using PrimLogSoftmaxPtr = std::shared_ptr<LogSoftmax>; | using PrimLogSoftmaxPtr = std::shared_ptr<LogSoftmax>; | ||||
| } // namespace ops | } // namespace ops | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -178,7 +178,7 @@ class Softmax(Primitive): | |||||
| validator.check_value_type("item of axis", item, [int], self.name) | validator.check_value_type("item of axis", item, [int], self.name) | ||||
| class LogSoftmax(PrimitiveWithInfer): | |||||
| class LogSoftmax(Primitive): | |||||
| r""" | r""" | ||||
| Log Softmax activation function. | Log Softmax activation function. | ||||
| @@ -220,15 +220,6 @@ class LogSoftmax(PrimitiveWithInfer): | |||||
| def __init__(self, axis=-1): | def __init__(self, axis=-1): | ||||
| validator.check_value_type("axis", axis, [int], self.name) | validator.check_value_type("axis", axis, [int], self.name) | ||||
| def infer_shape(self, logits): | |||||
| rank = len(logits) | |||||
| validator.check_int_range(self.axis, -rank, rank, Rel.INC_LEFT, 'axis', self.name) | |||||
| return logits | |||||
| def infer_dtype(self, logits): | |||||
| validator.check_tensor_dtype_valid("logits", logits, (mstype.float16, mstype.float32), self.name) | |||||
| return logits | |||||
| class Softplus(Primitive): | class Softplus(Primitive): | ||||
| r""" | r""" | ||||