Merge pull request !4999 from yeyunpeng2020/primitivetags/v0.7.0-beta
| @@ -19,22 +19,22 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| long Crop::GetAxis() const { return this->primitive_->value.AsCrop()->axis; } | |||||
| std::vector<long> Crop::GetOffsets() const { return this->primitive_->value.AsCrop()->offsets; } | |||||
| int64_t Crop::GetAxis() const { return this->primitive_->value.AsCrop()->axis; } | |||||
| std::vector<int64_t> Crop::GetOffsets() const { return this->primitive_->value.AsCrop()->offsets; } | |||||
| void Crop::SetAxis(long axis) { this->primitive_->value.AsCrop()->axis = axis; } | |||||
| void Crop::SetOffsets(const std::vector<long> &offsets) { this->primitive_->value.AsCrop()->offsets = offsets; } | |||||
| void Crop::SetAxis(int64_t axis) { this->primitive_->value.AsCrop()->axis = axis; } | |||||
| void Crop::SetOffsets(const std::vector<int64_t> &offsets) { this->primitive_->value.AsCrop()->offsets = offsets; } | |||||
| #else | #else | ||||
| long Crop::GetAxis() const { return this->primitive_->value_as_Crop()->axis(); } | |||||
| std::vector<long> Crop::GetOffsets() const { | |||||
| int64_t Crop::GetAxis() const { return this->primitive_->value_as_Crop()->axis(); } | |||||
| std::vector<int64_t> Crop::GetOffsets() const { | |||||
| auto fb_vector = this->primitive_->value_as_Crop()->offsets(); | auto fb_vector = this->primitive_->value_as_Crop()->offsets(); | ||||
| return std::vector<long>(fb_vector->begin(), fb_vector->end()); | |||||
| return std::vector<int64_t>(fb_vector->begin(), fb_vector->end()); | |||||
| } | } | ||||
| void Crop::SetAxis(long axis) {} | |||||
| void Crop::SetOffsets(const std::vector<long> &offsets) {} | |||||
| void Crop::SetAxis(int64_t axis) {} | |||||
| void Crop::SetOffsets(const std::vector<int64_t> &offsets) {} | |||||
| #endif | #endif | ||||
| namespace { | namespace { | ||||
| constexpr int kCropOutputNum = 1; | constexpr int kCropOutputNum = 1; | ||||
| @@ -34,10 +34,10 @@ class Crop : public PrimitiveC { | |||||
| explicit Crop(schema::Primitive *primitive) : PrimitiveC(primitive) {} | explicit Crop(schema::Primitive *primitive) : PrimitiveC(primitive) {} | ||||
| #endif | #endif | ||||
| int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override; | int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override; | ||||
| long GetAxis() const; | |||||
| std::vector<long> GetOffsets() const; | |||||
| void SetAxis(long axis); | |||||
| void SetOffsets(const std::vector<long> &offsets); | |||||
| int64_t GetAxis() const; | |||||
| std::vector<int64_t> GetOffsets() const; | |||||
| void SetAxis(int64_t axis); | |||||
| void SetOffsets(const std::vector<int64_t> &offsets); | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -31,16 +31,16 @@ float DetectionPostProcess::GetNmsIouThreshold() const { | |||||
| float DetectionPostProcess::GetNmsScoreThreshold() const { | float DetectionPostProcess::GetNmsScoreThreshold() const { | ||||
| return this->primitive_->value.AsDetectionPostProcess()->NmsScoreThreshold; | return this->primitive_->value.AsDetectionPostProcess()->NmsScoreThreshold; | ||||
| } | } | ||||
| long DetectionPostProcess::GetMaxDetections() const { | |||||
| int64_t DetectionPostProcess::GetMaxDetections() const { | |||||
| return this->primitive_->value.AsDetectionPostProcess()->MaxDetections; | return this->primitive_->value.AsDetectionPostProcess()->MaxDetections; | ||||
| } | } | ||||
| long DetectionPostProcess::GetDetectionsPreClass() const { | |||||
| int64_t DetectionPostProcess::GetDetectionsPreClass() const { | |||||
| return this->primitive_->value.AsDetectionPostProcess()->DetectionsPreClass; | return this->primitive_->value.AsDetectionPostProcess()->DetectionsPreClass; | ||||
| } | } | ||||
| long DetectionPostProcess::GetMaxClassesPreDetection() const { | |||||
| int64_t DetectionPostProcess::GetMaxClassesPreDetection() const { | |||||
| return this->primitive_->value.AsDetectionPostProcess()->MaxClassesPreDetection; | return this->primitive_->value.AsDetectionPostProcess()->MaxClassesPreDetection; | ||||
| } | } | ||||
| long DetectionPostProcess::GetNumClasses() const { | |||||
| int64_t DetectionPostProcess::GetNumClasses() const { | |||||
| return this->primitive_->value.AsDetectionPostProcess()->NumClasses; | return this->primitive_->value.AsDetectionPostProcess()->NumClasses; | ||||
| } | } | ||||
| bool DetectionPostProcess::GetUseRegularNms() const { | bool DetectionPostProcess::GetUseRegularNms() const { | ||||
| @@ -71,16 +71,16 @@ void DetectionPostProcess::SetNmsIouThreshold(float nms_iou_threshold) { | |||||
| void DetectionPostProcess::SetNmsScoreThreshold(float nms_score_threshold) { | void DetectionPostProcess::SetNmsScoreThreshold(float nms_score_threshold) { | ||||
| this->primitive_->value.AsDetectionPostProcess()->NmsScoreThreshold = nms_score_threshold; | this->primitive_->value.AsDetectionPostProcess()->NmsScoreThreshold = nms_score_threshold; | ||||
| } | } | ||||
| void DetectionPostProcess::SetMaxDetections(long max_detections) { | |||||
| void DetectionPostProcess::SetMaxDetections(int64_t max_detections) { | |||||
| this->primitive_->value.AsDetectionPostProcess()->MaxClassesPreDetection = max_detections; | this->primitive_->value.AsDetectionPostProcess()->MaxClassesPreDetection = max_detections; | ||||
| } | } | ||||
| void DetectionPostProcess::SetDetectionsPreClass(long detections_pre_class) { | |||||
| void DetectionPostProcess::SetDetectionsPreClass(int64_t detections_pre_class) { | |||||
| this->primitive_->value.AsDetectionPostProcess()->DetectionsPreClass = detections_pre_class; | this->primitive_->value.AsDetectionPostProcess()->DetectionsPreClass = detections_pre_class; | ||||
| } | } | ||||
| void DetectionPostProcess::SetMaxClassesPreDetection(long max_classes_pre_detection) { | |||||
| void DetectionPostProcess::SetMaxClassesPreDetection(int64_t max_classes_pre_detection) { | |||||
| this->primitive_->value.AsDetectionPostProcess()->MaxClassesPreDetection = max_classes_pre_detection; | this->primitive_->value.AsDetectionPostProcess()->MaxClassesPreDetection = max_classes_pre_detection; | ||||
| } | } | ||||
| void DetectionPostProcess::SetNumClasses(long num_classes) { | |||||
| void DetectionPostProcess::SetNumClasses(int64_t num_classes) { | |||||
| this->primitive_->value.AsDetectionPostProcess()->NumClasses = num_classes; | this->primitive_->value.AsDetectionPostProcess()->NumClasses = num_classes; | ||||
| } | } | ||||
| void DetectionPostProcess::SetUseRegularNms(bool use_regular_nms) { | void DetectionPostProcess::SetUseRegularNms(bool use_regular_nms) { | ||||
| @@ -103,16 +103,16 @@ float DetectionPostProcess::GetNmsIouThreshold() const { | |||||
| float DetectionPostProcess::GetNmsScoreThreshold() const { | float DetectionPostProcess::GetNmsScoreThreshold() const { | ||||
| return this->primitive_->value_as_DetectionPostProcess()->NmsScoreThreshold(); | return this->primitive_->value_as_DetectionPostProcess()->NmsScoreThreshold(); | ||||
| } | } | ||||
| long DetectionPostProcess::GetMaxDetections() const { | |||||
| int64_t DetectionPostProcess::GetMaxDetections() const { | |||||
| return this->primitive_->value_as_DetectionPostProcess()->MaxDetections(); | return this->primitive_->value_as_DetectionPostProcess()->MaxDetections(); | ||||
| } | } | ||||
| long DetectionPostProcess::GetDetectionsPreClass() const { | |||||
| int64_t DetectionPostProcess::GetDetectionsPreClass() const { | |||||
| return this->primitive_->value_as_DetectionPostProcess()->DetectionsPreClass(); | return this->primitive_->value_as_DetectionPostProcess()->DetectionsPreClass(); | ||||
| } | } | ||||
| long DetectionPostProcess::GetMaxClassesPreDetection() const { | |||||
| int64_t DetectionPostProcess::GetMaxClassesPreDetection() const { | |||||
| return this->primitive_->value_as_DetectionPostProcess()->MaxClassesPreDetection(); | return this->primitive_->value_as_DetectionPostProcess()->MaxClassesPreDetection(); | ||||
| } | } | ||||
| long DetectionPostProcess::GetNumClasses() const { | |||||
| int64_t DetectionPostProcess::GetNumClasses() const { | |||||
| return this->primitive_->value_as_DetectionPostProcess()->NumClasses(); | return this->primitive_->value_as_DetectionPostProcess()->NumClasses(); | ||||
| } | } | ||||
| bool DetectionPostProcess::GetUseRegularNms() const { | bool DetectionPostProcess::GetUseRegularNms() const { | ||||
| @@ -127,10 +127,10 @@ void DetectionPostProcess::SetXScale(float x_scale) {} | |||||
| void DetectionPostProcess::SetYScale(float y_scale) {} | void DetectionPostProcess::SetYScale(float y_scale) {} | ||||
| void DetectionPostProcess::SetNmsIouThreshold(float nms_iou_threshold) {} | void DetectionPostProcess::SetNmsIouThreshold(float nms_iou_threshold) {} | ||||
| void DetectionPostProcess::SetNmsScoreThreshold(float nms_score_threshold) {} | void DetectionPostProcess::SetNmsScoreThreshold(float nms_score_threshold) {} | ||||
| void DetectionPostProcess::SetMaxDetections(long max_detections) {} | |||||
| void DetectionPostProcess::SetDetectionsPreClass(long detections_pre_class) {} | |||||
| void DetectionPostProcess::SetMaxClassesPreDetection(long max_classes_pre_detection) {} | |||||
| void DetectionPostProcess::SetNumClasses(long num_classes) {} | |||||
| void DetectionPostProcess::SetMaxDetections(int64_t max_detections) {} | |||||
| void DetectionPostProcess::SetDetectionsPreClass(int64_t detections_pre_class) {} | |||||
| void DetectionPostProcess::SetMaxClassesPreDetection(int64_t max_classes_pre_detection) {} | |||||
| void DetectionPostProcess::SetNumClasses(int64_t num_classes) {} | |||||
| void DetectionPostProcess::SetUseRegularNms(bool use_regular_nms) {} | void DetectionPostProcess::SetUseRegularNms(bool use_regular_nms) {} | ||||
| #endif | #endif | ||||
| } // namespace lite | } // namespace lite | ||||
| @@ -41,10 +41,10 @@ class DetectionPostProcess : public PrimitiveC { | |||||
| float GetYScale() const; | float GetYScale() const; | ||||
| float GetNmsIouThreshold() const; | float GetNmsIouThreshold() const; | ||||
| float GetNmsScoreThreshold() const; | float GetNmsScoreThreshold() const; | ||||
| long GetMaxDetections() const; | |||||
| long GetDetectionsPreClass() const; | |||||
| long GetMaxClassesPreDetection() const; | |||||
| long GetNumClasses() const; | |||||
| int64_t GetMaxDetections() const; | |||||
| int64_t GetDetectionsPreClass() const; | |||||
| int64_t GetMaxClassesPreDetection() const; | |||||
| int64_t GetNumClasses() const; | |||||
| bool GetUseRegularNms() const; | bool GetUseRegularNms() const; | ||||
| void SetFormat(int format); | void SetFormat(int format); | ||||
| void SetInputSize(int input_size); | void SetInputSize(int input_size); | ||||
| @@ -54,10 +54,10 @@ class DetectionPostProcess : public PrimitiveC { | |||||
| void SetYScale(float y_scale); | void SetYScale(float y_scale); | ||||
| void SetNmsIouThreshold(float nms_iou_threshold); | void SetNmsIouThreshold(float nms_iou_threshold); | ||||
| void SetNmsScoreThreshold(float nms_score_threshold); | void SetNmsScoreThreshold(float nms_score_threshold); | ||||
| void SetMaxDetections(long max_detections); | |||||
| void SetDetectionsPreClass(long detections_pre_class); | |||||
| void SetMaxClassesPreDetection(long max_classes_pre_detection); | |||||
| void SetNumClasses(long num_classes); | |||||
| void SetMaxDetections(int64_t max_detections); | |||||
| void SetDetectionsPreClass(int64_t detections_pre_class); | |||||
| void SetMaxClassesPreDetection(int64_t max_classes_pre_detection); | |||||
| void SetNumClasses(int64_t num_classes); | |||||
| void SetUseRegularNms(bool use_regular_nms); | void SetUseRegularNms(bool use_regular_nms); | ||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| @@ -19,18 +19,18 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| std::vector<long> Permute::GetOrder() const { return this->primitive_->value.AsPermute()->order; } | |||||
| std::vector<int64_t> Permute::GetOrder() const { return this->primitive_->value.AsPermute()->order; } | |||||
| void Permute::SetOrder(const std::vector<long> &order) { this->primitive_->value.AsPermute()->order = order; } | |||||
| void Permute::SetOrder(const std::vector<int64_t> &order) { this->primitive_->value.AsPermute()->order = order; } | |||||
| #else | #else | ||||
| std::vector<long> Permute::GetOrder() const { | |||||
| std::vector<int64_t> Permute::GetOrder() const { | |||||
| auto fb_vector = this->primitive_->value_as_Permute()->order(); | auto fb_vector = this->primitive_->value_as_Permute()->order(); | ||||
| return std::vector<long>(fb_vector->begin(), fb_vector->end()); | |||||
| return std::vector<int64_t>(fb_vector->begin(), fb_vector->end()); | |||||
| } | } | ||||
| void Permute::SetOrder(const std::vector<long> &order) {} | |||||
| void Permute::SetOrder(const std::vector<int64_t> &order) {} | |||||
| #endif | #endif | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -33,8 +33,8 @@ class Permute : public PrimitiveC { | |||||
| #else | #else | ||||
| explicit Permute(schema::Primitive *primitive) : PrimitiveC(primitive) {} | explicit Permute(schema::Primitive *primitive) : PrimitiveC(primitive) {} | ||||
| #endif | #endif | ||||
| std::vector<long> GetOrder() const; | |||||
| void SetOrder(const std::vector<long> &order); | |||||
| std::vector<int64_t> GetOrder() const; | |||||
| void SetOrder(const std::vector<int64_t> &order); | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -410,11 +410,32 @@ PrimitiveC *PrimitiveC::UnPackFromSchemaPrimitiveT(mindspore::schema::PrimitiveT | |||||
| return new Shape(primitive); | return new Shape(primitive); | ||||
| case schema::PrimitiveType_Unsqueeze: | case schema::PrimitiveType_Unsqueeze: | ||||
| return new Unsqueeze(primitive); | return new Unsqueeze(primitive); | ||||
| case schema::PrimitiveType_BatchToSpace: | |||||
| return new BatchToSpace(primitive); | |||||
| case schema::PrimitiveType_SpaceToBatch: | |||||
| return new SpaceToBatch(primitive); | |||||
| case schema::PrimitiveType_BroadcastTo: | |||||
| return new BroadcastTo(primitive); | |||||
| case schema::PrimitiveType_DepthToSpace: | |||||
| return new DepthToSpace(primitive); | |||||
| case schema::PrimitiveType_Lstm: | |||||
| return new Lstm(primitive); | |||||
| case schema::PrimitiveType_ZerosLike: | |||||
| return new ZerosLike(primitive); | |||||
| case schema::PrimitiveType_MakeTuple: | |||||
| return new MakeTuple(primitive); | |||||
| case schema::PrimitiveType_Where: | |||||
| return new Where(primitive); | |||||
| case schema::PrimitiveType_ScatterND: | |||||
| return new ScatterND(primitive); | |||||
| case schema::PrimitiveType_ConstantOfShape: | |||||
| return new ConstantOfShape(primitive); | |||||
| default: | default: | ||||
| MS_LOG(ERROR) << "Unsupported primitive type in UnPackFromSchemaPrimitiveT : " | MS_LOG(ERROR) << "Unsupported primitive type in UnPackFromSchemaPrimitiveT : " | ||||
| << schema::EnumNamePrimitiveType(op_type); | << schema::EnumNamePrimitiveType(op_type); | ||||
| return nullptr; | |||||
| break; | |||||
| } | } | ||||
| return nullptr; | |||||
| } | } | ||||
| #else | #else | ||||
| PrimitiveC *PrimitiveC::UnPackFromSchemaPrimitive(mindspore::schema::Primitive *primitive) { | PrimitiveC *PrimitiveC::UnPackFromSchemaPrimitive(mindspore::schema::Primitive *primitive) { | ||||
| @@ -433,6 +454,8 @@ PrimitiveC *PrimitiveC::UnPackFromSchemaPrimitive(mindspore::schema::Primitive * | |||||
| return new Reduce(const_cast<schema::Primitive *>(primitive)); | return new Reduce(const_cast<schema::Primitive *>(primitive)); | ||||
| case schema::PrimitiveType_Pooling: | case schema::PrimitiveType_Pooling: | ||||
| return new Pooling(const_cast<schema::Primitive *>(primitive)); | return new Pooling(const_cast<schema::Primitive *>(primitive)); | ||||
| case schema::PrimitiveType_ROIPooling: | |||||
| return new ROIPooling(const_cast<schema::Primitive *>(primitive)); | |||||
| case schema::PrimitiveType_DepthwiseConv2D: | case schema::PrimitiveType_DepthwiseConv2D: | ||||
| return new DepthwiseConv2D(const_cast<schema::Primitive *>(primitive)); | return new DepthwiseConv2D(const_cast<schema::Primitive *>(primitive)); | ||||
| case schema::PrimitiveType_FusedBatchNorm: | case schema::PrimitiveType_FusedBatchNorm: | ||||
| @@ -443,6 +466,8 @@ PrimitiveC *PrimitiveC::UnPackFromSchemaPrimitive(mindspore::schema::Primitive * | |||||
| return new FullConnection(const_cast<schema::Primitive *>(primitive)); | return new FullConnection(const_cast<schema::Primitive *>(primitive)); | ||||
| case schema::PrimitiveType_Power: | case schema::PrimitiveType_Power: | ||||
| return new Power(const_cast<schema::Primitive *>(primitive)); | return new Power(const_cast<schema::Primitive *>(primitive)); | ||||
| case schema::PrimitiveType_Pad: | |||||
| return new Pad(const_cast<schema::Primitive *>(primitive)); | |||||
| case schema::PrimitiveType_Range: | case schema::PrimitiveType_Range: | ||||
| return new Range(const_cast<schema::Primitive *>(primitive)); | return new Range(const_cast<schema::Primitive *>(primitive)); | ||||
| case schema::PrimitiveType_Mul: | case schema::PrimitiveType_Mul: | ||||
| @@ -469,20 +494,22 @@ PrimitiveC *PrimitiveC::UnPackFromSchemaPrimitive(mindspore::schema::Primitive * | |||||
| return new Scale(const_cast<schema::Primitive *>(primitive)); | return new Scale(const_cast<schema::Primitive *>(primitive)); | ||||
| case schema::PrimitiveType_Eltwise: | case schema::PrimitiveType_Eltwise: | ||||
| return new Eltwise(const_cast<schema::Primitive *>(primitive)); | return new Eltwise(const_cast<schema::Primitive *>(primitive)); | ||||
| case schema::PrimitiveType_Ceil: | |||||
| return new Ceil(const_cast<schema::Primitive *>(primitive)); | |||||
| case schema::PrimitiveType_Concat: | case schema::PrimitiveType_Concat: | ||||
| return new Concat(const_cast<schema::Primitive *>(primitive)); | return new Concat(const_cast<schema::Primitive *>(primitive)); | ||||
| case schema::PrimitiveType_Fill: | case schema::PrimitiveType_Fill: | ||||
| return new Fill(const_cast<schema::Primitive *>(primitive)); | return new Fill(const_cast<schema::Primitive *>(primitive)); | ||||
| case schema::PrimitiveType_Nhwc2Nchw: | |||||
| return new Nhwc2Nchw(const_cast<schema::Primitive *>(primitive)); | |||||
| case schema::PrimitiveType_Nchw2Nhwc: | |||||
| return new Nchw2Nhwc(const_cast<schema::Primitive *>(primitive)); | |||||
| case schema::PrimitiveType_Transpose: | case schema::PrimitiveType_Transpose: | ||||
| return new Transpose(const_cast<schema::Primitive *>(primitive)); | return new Transpose(const_cast<schema::Primitive *>(primitive)); | ||||
| case schema::PrimitiveType_Slice: | case schema::PrimitiveType_Slice: | ||||
| return new Slice(const_cast<schema::Primitive *>(primitive)); | return new Slice(const_cast<schema::Primitive *>(primitive)); | ||||
| case schema::PrimitiveType_Squeeze: | case schema::PrimitiveType_Squeeze: | ||||
| return new Squeeze(const_cast<schema::Primitive *>(primitive)); | return new Squeeze(const_cast<schema::Primitive *>(primitive)); | ||||
| case schema::PrimitiveType_Nchw2Nhwc: | |||||
| return new Nchw2Nhwc(const_cast<schema::Primitive *>(primitive)); | |||||
| case schema::PrimitiveType_Nhwc2Nchw: | |||||
| return new Nhwc2Nchw(const_cast<schema::Primitive *>(primitive)); | |||||
| case schema::PrimitiveType_Flatten: | case schema::PrimitiveType_Flatten: | ||||
| return new Flatten(const_cast<schema::Primitive *>(primitive)); | return new Flatten(const_cast<schema::Primitive *>(primitive)); | ||||
| case schema::PrimitiveType_Mean: | case schema::PrimitiveType_Mean: | ||||
| @@ -521,8 +548,6 @@ PrimitiveC *PrimitiveC::UnPackFromSchemaPrimitive(mindspore::schema::Primitive * | |||||
| return new Maximum(const_cast<schema::Primitive *>(primitive)); | return new Maximum(const_cast<schema::Primitive *>(primitive)); | ||||
| case schema::PrimitiveType_Minimum: | case schema::PrimitiveType_Minimum: | ||||
| return new Minimum(const_cast<schema::Primitive *>(primitive)); | return new Minimum(const_cast<schema::Primitive *>(primitive)); | ||||
| case schema::PrimitiveType_Pad: | |||||
| return new Pad(const_cast<schema::Primitive *>(primitive)); | |||||
| case schema::PrimitiveType_StridedSlice: | case schema::PrimitiveType_StridedSlice: | ||||
| return new StridedSlice(const_cast<schema::Primitive *>(primitive)); | return new StridedSlice(const_cast<schema::Primitive *>(primitive)); | ||||
| case schema::PrimitiveType_Prelu: | case schema::PrimitiveType_Prelu: | ||||
| @@ -559,12 +584,12 @@ PrimitiveC *PrimitiveC::UnPackFromSchemaPrimitive(mindspore::schema::Primitive * | |||||
| return new GreaterEqual(const_cast<schema::Primitive *>(primitive)); | return new GreaterEqual(const_cast<schema::Primitive *>(primitive)); | ||||
| case schema::PrimitiveType_Floor: | case schema::PrimitiveType_Floor: | ||||
| return new Floor(const_cast<schema::Primitive *>(primitive)); | return new Floor(const_cast<schema::Primitive *>(primitive)); | ||||
| case schema::PrimitiveType_Ceil: | |||||
| return new Ceil(const_cast<schema::Primitive *>(primitive)); | |||||
| case schema::PrimitiveType_Split: | case schema::PrimitiveType_Split: | ||||
| return new Split(const_cast<schema::Primitive *>(primitive)); | return new Split(const_cast<schema::Primitive *>(primitive)); | ||||
| case schema::PrimitiveType_OneHot: | case schema::PrimitiveType_OneHot: | ||||
| return new OneHot(const_cast<schema::Primitive *>(primitive)); | return new OneHot(const_cast<schema::Primitive *>(primitive)); | ||||
| case schema::PrimitiveType_PriorBox: | |||||
| return new PriorBox(const_cast<schema::Primitive *>(primitive)); | |||||
| case schema::PrimitiveType_SpaceToDepth: | case schema::PrimitiveType_SpaceToDepth: | ||||
| return new SpaceToDepth(const_cast<schema::Primitive *>(primitive)); | return new SpaceToDepth(const_cast<schema::Primitive *>(primitive)); | ||||
| case schema::PrimitiveType_Tile: | case schema::PrimitiveType_Tile: | ||||
| @@ -591,7 +616,29 @@ PrimitiveC *PrimitiveC::UnPackFromSchemaPrimitive(mindspore::schema::Primitive * | |||||
| return new Shape(const_cast<schema::Primitive *>(primitive)); | return new Shape(const_cast<schema::Primitive *>(primitive)); | ||||
| case schema::PrimitiveType_Unsqueeze: | case schema::PrimitiveType_Unsqueeze: | ||||
| return new Unsqueeze(const_cast<schema::Primitive *>(primitive)); | return new Unsqueeze(const_cast<schema::Primitive *>(primitive)); | ||||
| case schema::PrimitiveType_BatchToSpace: | |||||
| return new BatchToSpace(const_cast<schema::Primitive *>(primitive)); | |||||
| case schema::PrimitiveType_SpaceToBatch: | |||||
| return new SpaceToBatch(const_cast<schema::Primitive *>(primitive)); | |||||
| case schema::PrimitiveType_BroadcastTo: | |||||
| return new BroadcastTo(const_cast<schema::Primitive *>(primitive)); | |||||
| case schema::PrimitiveType_DepthToSpace: | |||||
| return new DepthToSpace(const_cast<schema::Primitive *>(primitive)); | |||||
| case schema::PrimitiveType_Lstm: | |||||
| return new Lstm(const_cast<schema::Primitive *>(primitive)); | |||||
| case schema::PrimitiveType_ZerosLike: | |||||
| return new ZerosLike(const_cast<schema::Primitive *>(primitive)); | |||||
| case schema::PrimitiveType_MakeTuple: | |||||
| return new MakeTuple(const_cast<schema::Primitive *>(primitive)); | |||||
| case schema::PrimitiveType_Where: | |||||
| return new Where(const_cast<schema::Primitive *>(primitive)); | |||||
| case schema::PrimitiveType_ScatterND: | |||||
| return new ScatterND(const_cast<schema::Primitive *>(primitive)); | |||||
| case schema::PrimitiveType_ConstantOfShape: | |||||
| return new ConstantOfShape(const_cast<schema::Primitive *>(primitive)); | |||||
| default: | default: | ||||
| MS_LOG(ERROR) << "Unsupported primitive type in UnPackFromSchemaPrimitive : " | |||||
| << schema::EnumNamePrimitiveType(op_type); | |||||
| break; | break; | ||||
| } | } | ||||
| return nullptr; | return nullptr; | ||||
| @@ -25,10 +25,10 @@ namespace mindspore { | |||||
| namespace lite { | namespace lite { | ||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| int Reshape::GetFormat() const { return this->primitive_->value.AsReshape()->format; } | int Reshape::GetFormat() const { return this->primitive_->value.AsReshape()->format; } | ||||
| std::vector<long> Reshape::GetShape() const { return this->primitive_->value.AsReshape()->shape; } | |||||
| std::vector<int64_t> Reshape::GetShape() const { return this->primitive_->value.AsReshape()->shape; } | |||||
| void Reshape::SetFormat(int format) { this->primitive_->value.AsReshape()->format = (schema::Format)format; } | void Reshape::SetFormat(int format) { this->primitive_->value.AsReshape()->format = (schema::Format)format; } | ||||
| void Reshape::SetShape(const std::vector<long> &shape) { this->primitive_->value.AsReshape()->shape = shape; } | |||||
| void Reshape::SetShape(const std::vector<int64_t> &shape) { this->primitive_->value.AsReshape()->shape = shape; } | |||||
| int Reshape::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | int Reshape::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | ||||
| this->primitive_ = new (schema::PrimitiveT); | this->primitive_ = new (schema::PrimitiveT); | ||||
| auto attr = std::make_unique<schema::ReshapeT>(); | auto attr = std::make_unique<schema::ReshapeT>(); | ||||
| @@ -59,13 +59,13 @@ int Reshape::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &in | |||||
| #else | #else | ||||
| int Reshape::GetFormat() const { return this->primitive_->value_as_Reshape()->format(); } | int Reshape::GetFormat() const { return this->primitive_->value_as_Reshape()->format(); } | ||||
| std::vector<long> Reshape::GetShape() const { | |||||
| std::vector<int64_t> Reshape::GetShape() const { | |||||
| auto fb_vector = this->primitive_->value_as_Reshape()->shape(); | auto fb_vector = this->primitive_->value_as_Reshape()->shape(); | ||||
| return std::vector<long>(fb_vector->begin(), fb_vector->end()); | |||||
| return std::vector<int64_t>(fb_vector->begin(), fb_vector->end()); | |||||
| } | } | ||||
| void Reshape::SetFormat(int format) {} | void Reshape::SetFormat(int format) {} | ||||
| void Reshape::SetShape(const std::vector<long> &shape) {} | |||||
| void Reshape::SetShape(const std::vector<int64_t> &shape) {} | |||||
| #endif | #endif | ||||
| int Reshape::CalNewShape(const tensor::Tensor *in_tensor, std::vector<int> *out_shape) const { | int Reshape::CalNewShape(const tensor::Tensor *in_tensor, std::vector<int> *out_shape) const { | ||||
| @@ -36,9 +36,9 @@ class Reshape : public PrimitiveC { | |||||
| #endif | #endif | ||||
| int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override; | int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override; | ||||
| int GetFormat() const; | int GetFormat() const; | ||||
| std::vector<long> GetShape() const; | |||||
| std::vector<int64_t> GetShape() const; | |||||
| void SetFormat(int format); | void SetFormat(int format); | ||||
| void SetShape(const std::vector<long> &shape); | |||||
| void SetShape(const std::vector<int64_t> &shape); | |||||
| private: | private: | ||||
| int CalNewShape(const lite::tensor::Tensor *in_tensor, std::vector<int> *out_shape) const; | int CalNewShape(const lite::tensor::Tensor *in_tensor, std::vector<int> *out_shape) const; | ||||
| @@ -21,15 +21,15 @@ namespace lite { | |||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| int Resize::GetFormat() const { return this->primitive_->value.AsResize()->format; } | int Resize::GetFormat() const { return this->primitive_->value.AsResize()->format; } | ||||
| int Resize::GetMethod() const { return this->primitive_->value.AsResize()->method; } | int Resize::GetMethod() const { return this->primitive_->value.AsResize()->method; } | ||||
| long Resize::GetNewHeight() const { return this->primitive_->value.AsResize()->newHeight; } | |||||
| long Resize::GetNewWidth() const { return this->primitive_->value.AsResize()->newWidth; } | |||||
| int64_t Resize::GetNewHeight() const { return this->primitive_->value.AsResize()->newHeight; } | |||||
| int64_t Resize::GetNewWidth() const { return this->primitive_->value.AsResize()->newWidth; } | |||||
| bool Resize::GetAlignCorners() const { return this->primitive_->value.AsResize()->alignCorners; } | bool Resize::GetAlignCorners() const { return this->primitive_->value.AsResize()->alignCorners; } | ||||
| bool Resize::GetPreserveAspectRatio() const { return this->primitive_->value.AsResize()->preserveAspectRatio; } | bool Resize::GetPreserveAspectRatio() const { return this->primitive_->value.AsResize()->preserveAspectRatio; } | ||||
| void Resize::SetFormat(int format) { this->primitive_->value.AsResize()->format = (schema::Format)format; } | void Resize::SetFormat(int format) { this->primitive_->value.AsResize()->format = (schema::Format)format; } | ||||
| void Resize::SetMethod(int method) { this->primitive_->value.AsResize()->method = (schema::ResizeMethod)method; } | void Resize::SetMethod(int method) { this->primitive_->value.AsResize()->method = (schema::ResizeMethod)method; } | ||||
| void Resize::SetNewHeight(long new_height) { this->primitive_->value.AsResize()->newHeight = new_height; } | |||||
| void Resize::SetNewWidth(long new_width) { this->primitive_->value.AsResize()->newWidth = new_width; } | |||||
| void Resize::SetNewHeight(int64_t new_height) { this->primitive_->value.AsResize()->newHeight = new_height; } | |||||
| void Resize::SetNewWidth(int64_t new_width) { this->primitive_->value.AsResize()->newWidth = new_width; } | |||||
| void Resize::SetAlignCorners(bool align_corners) { this->primitive_->value.AsResize()->alignCorners = align_corners; } | void Resize::SetAlignCorners(bool align_corners) { this->primitive_->value.AsResize()->alignCorners = align_corners; } | ||||
| void Resize::SetPreserveAspectRatio(bool preserve_aspect_ratio) { | void Resize::SetPreserveAspectRatio(bool preserve_aspect_ratio) { | ||||
| this->primitive_->value.AsResize()->preserveAspectRatio = preserve_aspect_ratio; | this->primitive_->value.AsResize()->preserveAspectRatio = preserve_aspect_ratio; | ||||
| @@ -39,15 +39,15 @@ void Resize::SetPreserveAspectRatio(bool preserve_aspect_ratio) { | |||||
| int Resize::GetFormat() const { return this->primitive_->value_as_Resize()->format(); } | int Resize::GetFormat() const { return this->primitive_->value_as_Resize()->format(); } | ||||
| int Resize::GetMethod() const { return this->primitive_->value_as_Resize()->method(); } | int Resize::GetMethod() const { return this->primitive_->value_as_Resize()->method(); } | ||||
| long Resize::GetNewHeight() const { return this->primitive_->value_as_Resize()->newHeight(); } | |||||
| long Resize::GetNewWidth() const { return this->primitive_->value_as_Resize()->newWidth(); } | |||||
| int64_t Resize::GetNewHeight() const { return this->primitive_->value_as_Resize()->newHeight(); } | |||||
| int64_t Resize::GetNewWidth() const { return this->primitive_->value_as_Resize()->newWidth(); } | |||||
| bool Resize::GetAlignCorners() const { return this->primitive_->value_as_Resize()->alignCorners(); } | bool Resize::GetAlignCorners() const { return this->primitive_->value_as_Resize()->alignCorners(); } | ||||
| bool Resize::GetPreserveAspectRatio() const { return this->primitive_->value_as_Resize()->preserveAspectRatio(); } | bool Resize::GetPreserveAspectRatio() const { return this->primitive_->value_as_Resize()->preserveAspectRatio(); } | ||||
| void Resize::SetFormat(int format) {} | void Resize::SetFormat(int format) {} | ||||
| void Resize::SetMethod(int method) {} | void Resize::SetMethod(int method) {} | ||||
| void Resize::SetNewHeight(long new_height) {} | |||||
| void Resize::SetNewWidth(long new_width) {} | |||||
| void Resize::SetNewHeight(int64_t new_height) {} | |||||
| void Resize::SetNewWidth(int64_t new_width) {} | |||||
| void Resize::SetAlignCorners(bool align_corners) {} | void Resize::SetAlignCorners(bool align_corners) {} | ||||
| void Resize::SetPreserveAspectRatio(bool preserve_aspect_ratio) {} | void Resize::SetPreserveAspectRatio(bool preserve_aspect_ratio) {} | ||||
| #endif | #endif | ||||
| @@ -36,14 +36,14 @@ class Resize : public PrimitiveC { | |||||
| int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override; | int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override; | ||||
| int GetFormat() const; | int GetFormat() const; | ||||
| int GetMethod() const; | int GetMethod() const; | ||||
| long GetNewHeight() const; | |||||
| long GetNewWidth() const; | |||||
| int64_t GetNewHeight() const; | |||||
| int64_t GetNewWidth() const; | |||||
| bool GetAlignCorners() const; | bool GetAlignCorners() const; | ||||
| bool GetPreserveAspectRatio() const; | bool GetPreserveAspectRatio() const; | ||||
| void SetFormat(int format); | void SetFormat(int format); | ||||
| void SetMethod(int method); | void SetMethod(int method); | ||||
| void SetNewHeight(long new_height); | |||||
| void SetNewWidth(long new_width); | |||||
| void SetNewHeight(int64_t new_height); | |||||
| void SetNewWidth(int64_t new_width); | |||||
| void SetAlignCorners(bool align_corners); | void SetAlignCorners(bool align_corners); | ||||
| void SetPreserveAspectRatio(bool preserve_aspect_ratio); | void SetPreserveAspectRatio(bool preserve_aspect_ratio); | ||||
| }; | }; | ||||