From: @cjh9368 Reviewed-by: @hangangqiang,@zhang_xue_tong Signed-off-by: @hangangqiangtags/v1.2.0-rc1
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -90,9 +90,7 @@ int Select::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outp | |||||
| output->set_shape(input->shape()); | output->set_shape(input->shape()); | ||||
| output->set_format(input->format()); | output->set_format(input->format()); | ||||
| auto data_type = input->data_type(); | auto data_type = input->data_type(); | ||||
| if (data_type != kObjectTypeTensorType) { | |||||
| continue; | |||||
| } else { | |||||
| if (data_type == kObjectTypeTensorType) { | |||||
| auto input_tensorlist = reinterpret_cast<TensorList *>(input); | auto input_tensorlist = reinterpret_cast<TensorList *>(input); | ||||
| auto output_tensorlist = reinterpret_cast<TensorList *>(output); | auto output_tensorlist = reinterpret_cast<TensorList *>(output); | ||||
| output_tensorlist->set_element_shape(input_tensorlist->element_shape()); | output_tensorlist->set_element_shape(input_tensorlist->element_shape()); | ||||
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -22,18 +22,41 @@ namespace mindspore { | |||||
| namespace lite { | namespace lite { | ||||
| PrimitiveC *CaffeReduceParser::ParseLitePrimitive(const caffe::LayerParameter &proto, | PrimitiveC *CaffeReduceParser::ParseLitePrimitive(const caffe::LayerParameter &proto, | ||||
| const caffe::LayerParameter &weight) { | const caffe::LayerParameter &weight) { | ||||
| std::unique_ptr<schema::PReLUT> attr = std::make_unique<schema::PReLUT>(); | |||||
| auto attr = std::make_unique<schema::ReduceT>(); | |||||
| if (attr == nullptr) { | if (attr == nullptr) { | ||||
| MS_LOG(ERROR) << "new op failed"; | MS_LOG(ERROR) << "new op failed"; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| const caffe::PReLUParameter &pReluParam = proto.prelu_param(); | |||||
| if (pReluParam.has_channel_shared()) { | |||||
| attr->channelShared = pReluParam.channel_shared(); | |||||
| attr->keepDims = false; | |||||
| const caffe::ReductionParameter &reduce_param = proto.reduction_param(); | |||||
| if (reduce_param.has_operation()) { | |||||
| if (reduce_param.operation() == caffe::ReductionParameter_ReductionOp_MEAN) { | |||||
| attr->mode = schema::ReduceMode_ReduceMean; | |||||
| } else if (reduce_param.operation() == caffe::ReductionParameter_ReductionOp_SUM) { | |||||
| attr->mode = schema::ReduceMode_ReduceSum; | |||||
| } else if (reduce_param.operation() == caffe::ReductionParameter_ReductionOp_SUMSQ) { | |||||
| attr->mode = schema::ReduceMode_ReduceSumSquare; | |||||
| } else if (reduce_param.operation() == caffe::ReductionParameter_ReductionOp_ASUM) { | |||||
| attr->mode = schema::ReduceMode_ReduceASum; | |||||
| } else { | |||||
| MS_LOG(ERROR) << "nsupported reduce mode: " << reduce_param.operation(); | |||||
| return nullptr; | |||||
| } | |||||
| } else { | |||||
| attr->mode = schema::ReduceMode_ReduceSum; | |||||
| } | |||||
| std::vector<int32_t> axes; | |||||
| if (reduce_param.has_axis()) { | |||||
| axes.push_back(1); | |||||
| axes.push_back(reduce_param.axis()); | |||||
| } else { | } else { | ||||
| attr->channelShared = false; | |||||
| axes.push_back(1); | |||||
| axes.push_back(0); | |||||
| } | } | ||||
| attr->axes = axes; | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | auto primitive = std::make_unique<schema::PrimitiveT>(); | ||||
| primitive->value.type = schema::PrimitiveType_Reduce; | primitive->value.type = schema::PrimitiveType_Reduce; | ||||