Merge pull request !4784 from yeyunpeng2020/master_cops_4tags/v0.7.0-beta
| @@ -14,10 +14,10 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/ops/constant_of_shape.h" | |||
| #include "include/errorcode.h" | |||
| #include "utils/log_adapter.h" | |||
| #include "src/ir/tensor.h" | |||
| #include "src/ops/constant_of_shape.h" | |||
| namespace mindspore::lite { | |||
| namespace { | |||
| @@ -25,9 +25,9 @@ constexpr int kShapeInputNum = 1; | |||
| constexpr int kShapeOutputNum = 1; | |||
| } // namespace | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| int ConstantOfShape::GetValue() const { return this->primitive->value.AsConstantOfShape()->Value; } | |||
| float ConstantOfShape::GetValue() const { return this->primitive->value.AsConstantOfShape()->value; } | |||
| void ConstantOfShape::SetValue(float value) { this->primitive->value.AsConstantOfShape()->Value = value; } | |||
| void ConstantOfShape::SetValue(float value) { this->primitive->value.AsConstantOfShape()->value = value; } | |||
| #else | |||
| @@ -104,19 +104,18 @@ void Conv2D::SetActivationType(int activation_type) {} | |||
| #endif | |||
| void Conv2D::ConvInferShape(int input_h, int input_w, int *output_h, int *output_w) { | |||
| MS_ASSERT(this->primitive != nullptr); | |||
| auto conv2DPrim = this->primitive->value_as_Conv2D(); | |||
| int kernel_w = conv2DPrim->kernelW(); | |||
| int kernel_h = conv2DPrim->kernelH(); | |||
| int stride_w = conv2DPrim->strideW(); | |||
| int stride_h = conv2DPrim->strideH(); | |||
| int dilate_w = conv2DPrim->dilateW(); | |||
| int dilate_h = conv2DPrim->dilateH(); | |||
| pad_l_ = conv2DPrim->padLeft(); | |||
| pad_u_ = conv2DPrim->padUp(); | |||
| pad_d_ = conv2DPrim->padDown(); | |||
| pad_r_ = conv2DPrim->padRight(); | |||
| int kernel_w = GetKernelW(); | |||
| int kernel_h = GetKernelH(); | |||
| int stride_w = GetStrideW(); | |||
| int stride_h = GetStrideH(); | |||
| int dilate_w = GetDilateW(); | |||
| int dilate_h = GetDilateH(); | |||
| pad_l_ = GetPadLeft(); | |||
| pad_u_ = GetPadUp(); | |||
| pad_d_ = GetPadDown(); | |||
| pad_r_ = GetPadRight(); | |||
| if (conv2DPrim->padMode() == schema::PadMode_SAME) { | |||
| if (GetPadMode() == schema::PadMode_SAME) { | |||
| *output_w = std::ceil(static_cast<float>(input_w) / static_cast<float>(stride_w)); | |||
| *output_h = std::ceil(static_cast<float>(input_h) / static_cast<float>(stride_h)); | |||
| auto pad_h_all = ((*output_h - 1) * stride_h + (kernel_h - 1) * dilate_h + 1 - input_h); | |||
| @@ -23,7 +23,7 @@ int DepthToSpace::GetBlockSize() const { return this->primitive->value.AsDepthTo | |||
| int DepthToSpace::GetFormat() const { return this->primitive->value.AsDepthToSpace()->format; } | |||
| void DepthToSpace::SetBlockSize(int block_size) { this->primitive->value.AsDepthToSpace()->blockSize = block_size; } | |||
| void DepthToSpace::SetFormat(int format) { this->primitive->value.AsDepthToSpace()->format = format; } | |||
| void DepthToSpace::SetFormat(int format) { this->primitive->value.AsDepthToSpace()->format = (schema::Format)format; } | |||
| #else | |||
| @@ -50,13 +50,12 @@ int Fill::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor:: | |||
| if (!GetInferFlag()) { | |||
| return RET_OK; | |||
| } | |||
| auto fill_prim = this->primitive->value_as_Fill(); | |||
| if (fill_prim == nullptr) { | |||
| MS_LOG(ERROR) << "Fill primitive is null!"; | |||
| return RET_ERROR; | |||
| } | |||
| std::vector<int> output_shape; | |||
| (void)output_shape.insert(output_shape.begin(), fill_prim->dims()->begin(), fill_prim->dims()->end()); | |||
| for (int i = 0; i < GetDims().size(); i++) { | |||
| output_shape.push_back(GetDims()[i]); | |||
| } | |||
| // (void)output_shape.insert(output_shape.begin(), GetDims().begin(), GetDims().end()); | |||
| output->set_shape(output_shape); | |||
| return RET_OK; | |||
| } | |||
| @@ -22,13 +22,13 @@ namespace lite { | |||
| bool FullConnection::GetHasBias() const { return this->primitive->value.AsFullConnection()->hasBias; } | |||
| int FullConnection::GetAxis() const { return this->primitive->value.AsFullConnection()->axis; } | |||
| bool FullConnection::GetUseAxis() const { return this->primitive->value.AsFullConnection()->useAxis; } | |||
| int FullConnection::GetActivationType() const { return this->primitive->value.AsFullConnection()->activationType(); } | |||
| int FullConnection::GetActivationType() const { return this->primitive->value.AsFullConnection()->activationType; } | |||
| void FullConnection::SetHasBias(bool has_bias) { this->primitive->value.AsFullConnection()->hasBias = has_bias; } | |||
| void FullConnection::SetAxis(int axis) { this->primitive->value.AsFullConnection()->axis = axis; } | |||
| void FullConnection::SetUseAxis(bool use_axis) { this->primitive->value.AsFullConnection()->useAxis = use_axis; } | |||
| void FullConnection::SetActivationType(int activationType) { | |||
| his->primitive->value.AsFullConnection()->activationType = (schema::ActivationType)activationType; | |||
| this->primitive->value.AsFullConnection()->activationType = (schema::ActivationType) activationType; | |||
| } | |||
| #else | |||
| @@ -21,7 +21,9 @@ namespace lite { | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| int Mul::GetActivationType() const { return this->primitive->value.AsMul()->activationType; } | |||
| void Mul::SetActivationType(int activation_type) { this->primitive->value.AsMul()->activationType = activation_type; } | |||
| void Mul::SetActivationType(int activation_type) { | |||
| this->primitive->value.AsMul()->activationType = (schema::ActivationType) activation_type; | |||
| } | |||
| #else | |||
| @@ -24,7 +24,9 @@ int Pad::GetPaddingMode() const { return this->primitive->value.AsPad()->padding | |||
| float Pad::GetConstantValue() const { return this->primitive->value.AsPad()->constantValue; } | |||
| void Pad::SetPaddings(const std::vector<int> &paddings) { this->primitive->value.AsPad()->paddings = paddings; } | |||
| void Pad::SetPaddingMode(int padding_mode) { this->primitive->value.AsPad()->paddingMode = padding_mode; } | |||
| void Pad::SetPaddingMode(int padding_mode) { | |||
| this->primitive->value.AsPad()->paddingMode = (schema::PaddingMode) padding_mode; | |||
| } | |||
| void Pad::SetConstantValue(float constant_value) { this->primitive->value.AsPad()->constantValue = constant_value; } | |||
| #else | |||
| @@ -34,22 +34,22 @@ int Pooling::GetPadLeft() const { return this->primitive->value.AsPooling()->pad | |||
| int Pooling::GetPadRight() const { return this->primitive->value.AsPooling()->padRight; } | |||
| int Pooling::GetRoundMode() const { return this->primitive->value.AsPooling()->roundMode; } | |||
| void Pooling::SetFormat(int format) { this->primitive->value.AsPooling()->format = (schema::Format)format; } | |||
| void Pooling::SetFormat(int format) { this->primitive->value.AsPooling()->format = (schema::Format) format; } | |||
| void Pooling::SetPoolingMode(int pooling_mode) { | |||
| this->primitive->value.AsPooling()->poolingMode = (schema::PoolMode)pooling_mode; | |||
| this->primitive->value.AsPooling()->poolingMode = (schema::PoolMode) pooling_mode; | |||
| } | |||
| void Pooling::SetGlobal(bool global) { this->primitive->value.AsPooling()->global = global; } | |||
| void Pooling::SetWindowW(int window_w) { this->primitive->value.AsPooling()->windowW = window_w; } | |||
| void Pooling::SetWindowH(int window_h) { this->primitive->value.AsPooling()->windowH = window_h; } | |||
| void Pooling::SetStrideW(int stride_w) { this->primitive->value.AsPooling()->strideW = stride_w; } | |||
| void Pooling::SetStrideH(int stride_h) { this->primitive->value.AsPooling()->strideH = stride_h; } | |||
| void Pooling::SetPadMode(int pad_mode) { this->primitive->value.AsPooling()->padMode = (schema::PadMode)pad_mode; } | |||
| void Pooling::SetPadMode(int pad_mode) { this->primitive->value.AsPooling()->padMode = (schema::PadMode) pad_mode; } | |||
| void Pooling::SetPadUp(int pad_up) { this->primitive->value.AsPooling()->padUp = pad_up; } | |||
| void Pooling::SetPadDown(int pad_down) { this->primitive->value.AsPooling()->padDown = pad_down; } | |||
| void Pooling::SetPadLeft(int pad_left) { this->primitive->value.AsPooling()->padLeft = pad_left; } | |||
| void Pooling::SetPadRight(int pad_right) { this->primitive->value.AsPooling()->padRight = pad_right; } | |||
| void Pooling::SetRoundMode(int round_mode) { | |||
| this->primitive->value.AsPooling()->roundMode = (schema::RoundMode)round_mode; | |||
| this->primitive->value.AsPooling()->roundMode = (schema::RoundMode) round_mode; | |||
| } | |||
| #else | |||
| @@ -82,13 +82,13 @@ void Pooling::SetPadLeft(int pad_left) {} | |||
| void Pooling::SetPadRight(int pad_right) {} | |||
| void Pooling::SetRoundMode(int round_mode) {} | |||
| #endif | |||
| int Pooling::PadUp() const { return this->pad_u_; } | |||
| int Pooling::PadDown() const { return this->pad_d_; } | |||
| int Pooling::PadLeft() const { return this->pad_l_; } | |||
| int Pooling::PadRight() const { return this->pad_r_; } | |||
| #endif | |||
| int Pooling::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::Tensor *> outputs_) { | |||
| MS_ASSERT(this->primitive != nullptr); | |||
| auto input = inputs_.front(); | |||
| @@ -102,37 +102,37 @@ int Pooling::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tenso | |||
| } | |||
| int input_h = input->shape().at(1); | |||
| int input_w = input->shape().at(2); | |||
| auto pooling_prim = this->primitive->value_as_Pooling(); | |||
| MS_ASSERT(pooling_prim != nullptr); | |||
| auto window_h = pooling_prim->windowH(); | |||
| auto window_w = pooling_prim->windowW(); | |||
| if (pooling_prim->global()) { | |||
| auto window_h = GetWindowH(); | |||
| auto window_w = GetWindowW(); | |||
| if (GetGlobal()) { | |||
| window_h = input_h; | |||
| window_w = input_w; | |||
| } | |||
| int output_h = 0; | |||
| int output_w = 0; | |||
| pad_l_ = pooling_prim->padLeft(); | |||
| pad_u_ = pooling_prim->padUp(); | |||
| pad_d_ = pooling_prim->padDown(); | |||
| pad_r_ = pooling_prim->padRight(); | |||
| if (pooling_prim->padMode() == schema::PadMode_SAME) { | |||
| output_w = std::ceil(static_cast<float>(input_w) / static_cast<float>(pooling_prim->strideW())); | |||
| output_h = std::ceil(static_cast<float>(input_h) / static_cast<float>(pooling_prim->strideH())); | |||
| auto pad_h_all = ((output_h - 1) * pooling_prim->strideH() + (window_h - 1) + 1 - input_h); | |||
| auto pad_w_all = ((output_w - 1) * pooling_prim->strideW() + (window_w - 1) + 1 - input_w); | |||
| pad_l_ = GetPadLeft(); | |||
| pad_u_ = GetPadUp(); | |||
| pad_d_ = GetPadDown(); | |||
| pad_r_ = GetPadRight(); | |||
| if (GetPadMode() == schema::PadMode_SAME) { | |||
| output_w = std::ceil(static_cast<float>(input_w) / static_cast<float>(GetStrideW())); | |||
| output_h = std::ceil(static_cast<float>(input_h) / static_cast<float>(GetStrideH())); | |||
| auto pad_h_all = ((output_h - 1) * GetStrideH() + (window_h - 1) + 1 - input_h); | |||
| auto pad_w_all = ((output_w - 1) * GetStrideW() + (window_w - 1) + 1 - input_w); | |||
| pad_u_ = pad_h_all / 2; | |||
| pad_d_ = pad_h_all - pad_u_; | |||
| pad_l_ = pad_w_all / 2; | |||
| pad_r_ = pad_w_all - pad_l_; | |||
| } else { | |||
| auto round_mode = pooling_prim->roundMode(); | |||
| auto round_mode = (schema::RoundMode) GetRoundMode(); | |||
| if (round_mode == schema::RoundMode_FLOOR) { | |||
| output_h = std::floor(static_cast<float>(input_h + pad_u_ + pad_d_ - window_h) / pooling_prim->strideH()) + 1; | |||
| output_w = std::floor(static_cast<float>(input_w + pad_l_ + pad_r_ - window_w) / pooling_prim->strideW()) + 1; | |||
| output_h = std::floor(static_cast<float>(input_h + pad_u_ + pad_d_ - window_h) / GetStrideH()) + 1; | |||
| output_w = std::floor(static_cast<float>(input_w + pad_l_ + pad_r_ - window_w) / GetStrideW()) + 1; | |||
| } else if (round_mode == schema::RoundMode_CEIL) { | |||
| output_h = std::ceil(static_cast<float>(input_h + pad_u_ + pad_d_ - window_h) / pooling_prim->strideH()) + 1; | |||
| output_w = std::ceil(static_cast<float>(input_w + pad_l_ + pad_r_ - window_w) / pooling_prim->strideW()) + 1; | |||
| output_h = std::ceil(static_cast<float>(input_h + pad_u_ + pad_d_ - window_h) / GetStrideH()) + 1; | |||
| output_w = std::ceil(static_cast<float>(input_w + pad_l_ + pad_r_ - window_w) / GetStrideW()) + 1; | |||
| } else { | |||
| MS_LOG(ERROR) << "unsupported round mode."; | |||
| } | |||
| @@ -26,7 +26,7 @@ namespace lite { | |||
| int Reshape::GetFormat() const { return this->primitive->value.AsReshape()->format; } | |||
| std::vector<long> Reshape::GetShape() const { return this->primitive->value.AsReshape()->shape; } | |||
| void Reshape::SetFormat(int format) { this->primitive->value.AsReshape()->format = format; } | |||
| void Reshape::SetFormat(int format) { this->primitive->value.AsReshape()->format = (schema::Format) format; } | |||
| void Reshape::SetShape(const std::vector<long> &shape) { this->primitive->value.AsReshape()->shape = shape; } | |||
| #else | |||
| @@ -75,7 +75,7 @@ int Reshape::CalNewShape(const tensor::Tensor *in_tensor, std::vector<int> *out_ | |||
| } | |||
| return RET_OK; | |||
| } | |||
| template <typename T> | |||
| template<typename T> | |||
| void CalShape(const T *data, const std::vector<tensor::Tensor *> &inputs, std::vector<int> *out_shape, int shape_size) { | |||
| int input_count = inputs[0]->ElementsNum(); | |||
| int index = 0; | |||
| @@ -103,7 +103,7 @@ int Reshape::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tenso | |||
| if (!GetInferFlag()) { | |||
| return RET_OK; | |||
| } | |||
| auto reshape_prim = this->primitive->value_as_Reshape(); | |||
| MS_ASSERT(reshape_prim != nullptr); | |||
| std::vector<int> out_shape; | |||
| if (inputs_.size() == kDoubleNum) { | |||
| @@ -117,30 +117,38 @@ int Reshape::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tenso | |||
| case kNumberTypeInt8: { | |||
| auto data = reinterpret_cast<int8_t *>(shape_tensor->Data()); | |||
| CalShape<int8_t>(data, inputs_, &out_shape, shape_size); | |||
| } break; | |||
| } | |||
| break; | |||
| case kNumberTypeInt32: { | |||
| auto data = reinterpret_cast<int32_t *>(shape_tensor->Data()); | |||
| CalShape<int32_t>(data, inputs_, &out_shape, shape_size); | |||
| } break; | |||
| } | |||
| break; | |||
| case kNumberTypeInt64: { | |||
| auto data = reinterpret_cast<int64_t *>(shape_tensor->Data()); | |||
| CalShape<int64_t>(data, inputs_, &out_shape, shape_size); | |||
| } break; | |||
| } | |||
| break; | |||
| case kNumberTypeFloat: { | |||
| auto data = reinterpret_cast<float *>(shape_tensor->Data()); | |||
| CalShape<float>(data, inputs_, &out_shape, shape_size); | |||
| } break; | |||
| } | |||
| break; | |||
| case kNumberTypeUInt32: { | |||
| auto data = reinterpret_cast<uint32_t *>(shape_tensor->Data()); | |||
| CalShape<uint32_t>(data, inputs_, &out_shape, shape_size); | |||
| } break; | |||
| } | |||
| break; | |||
| default: { | |||
| MS_LOG(ERROR) << "Reshape weight tensor has unsupported dataType: " << shape_tensor->data_type(); | |||
| return RET_INFER_ERR; | |||
| } | |||
| } | |||
| } else if (inputs_.size() == kSingleNum) { | |||
| std::copy(reshape_prim->shape()->begin(), reshape_prim->shape()->end(), std::back_inserter(out_shape)); | |||
| for (int i = 0; i < GetShape().size(); ++i) { | |||
| out_shape.push_back(GetShape()[i]); | |||
| } | |||
| // std::copy(GetShape().begin(), GetShape().end(), std::back_inserter(out_shape)); | |||
| } else { | |||
| MS_LOG(ERROR) << "inputs tensor size invalid."; | |||
| return RET_INFER_ERR; | |||
| @@ -30,7 +30,7 @@ int SliceOp::GetFormat() const { return this->primitive->value.AsSlice()->format | |||
| std::vector<int> SliceOp::GetBegin() const { return this->primitive->value.AsSlice()->begin; } | |||
| std::vector<int> SliceOp::GetSize() const { return this->primitive->value.AsSlice()->size; } | |||
| void SliceOp::SetFormat(int format) { this->primitive->value.AsSlice()->format = format; } | |||
| void SliceOp::SetFormat(int format) { this->primitive->value.AsSlice()->format = (schema::Format)format; } | |||
| void SliceOp::SetBegin(const std::vector<int> &begin) { this->primitive->value.AsSlice()->begin = begin; } | |||
| void SliceOp::SetSize(const std::vector<int> &size) { this->primitive->value.AsSlice()->size = size; } | |||
| @@ -24,7 +24,7 @@ int SpaceToDepth::GetBlockSize() const { return this->primitive->value.AsSpaceTo | |||
| int SpaceToDepth::GetFormat() const { return this->primitive->value.AsSpaceToDepth()->format; } | |||
| void SpaceToDepth::SetBlockSize(int block_size) { this->primitive->value.AsSpaceToDepth()->blockSize = block_size; } | |||
| void SpaceToDepth::SetFormat(int format) { this->primitive->value.AsSpaceToDepth()->format = format; } | |||
| void SpaceToDepth::SetFormat(int format) { this->primitive->value.AsSpaceToDepth()->format = (schema::Format)format; } | |||
| #else | |||
| @@ -50,7 +50,6 @@ int Split::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor: | |||
| MS_ASSERT(this->primitive != nullptr); | |||
| auto input = inputs_.front(); | |||
| MS_ASSERT(input != nullptr); | |||
| auto spilt_prim = this->primitive->value_as_Split(); | |||
| MS_ASSERT(spilt_prim != nullptr); | |||
| if (inputs_.size() != kSplitInputNum) { | |||
| MS_LOG(ERROR) << "inputs number is not equal to " << kSplitInputNum; | |||
| @@ -61,7 +60,7 @@ int Split::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor: | |||
| MS_LOG(ERROR) << "output null pointer dereferencing."; | |||
| return RET_ERROR; | |||
| } | |||
| int number_split = spilt_prim->numberSplit(); | |||
| int number_split = GetNumberSplit(); | |||
| if (static_cast<int>(outputs_.size()) != number_split) { | |||
| MS_LOG(ERROR) << "outputs number is not equal to " << number_split; | |||
| return RET_ERROR; | |||
| @@ -73,10 +72,12 @@ int Split::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor: | |||
| if (!GetInferFlag()) { | |||
| return RET_OK; | |||
| } | |||
| int split_dim = spilt_prim->splitDim(); | |||
| int split_dim = GetSplitDim(); | |||
| std::vector<int> input_shape = input->shape(); | |||
| std::vector<int> size_split; | |||
| size_split.insert(size_split.begin(), spilt_prim->sizeSplits()->begin(), spilt_prim->sizeSplits()->end()); | |||
| for (int i = 0; i < GetSizeSplits().size(); ++i) { | |||
| size_split.push_back(GetSizeSplits()[i]); | |||
| } | |||
| for (int i = 0; i < number_split; ++i) { | |||
| std::vector<int> output_shape; | |||
| output_shape.insert(output_shape.begin(), input_shape.begin(), input_shape.end()); | |||
| @@ -24,6 +24,10 @@ std::vector<int> Tile::GetMultiples() const { return this->primitive->value.AsTi | |||
| void Tile::SetMultiples(const std::vector<int> &multiples) { this->primitive->value.AsTile()->multiples = multiples; } | |||
| std::vector<int> Tile::GetDims() const { return this->primitive->value.AsTile()->multiples; } | |||
| void Tile::SetDims(const std::vector<int> &dims) { this->primitive->value.AsTile()->dims = dims; } | |||
| #else | |||
| std::vector<int> Tile::GetMultiples() const { | |||
| @@ -32,6 +36,13 @@ std::vector<int> Tile::GetMultiples() const { | |||
| } | |||
| void Tile::SetMultiples(const std::vector<int> &multiples) {} | |||
| std::vector<int> Tile::GetDims() const { | |||
| auto fb_vector = this->primitive->value_as_Tile()->dims(); | |||
| return std::vector<int>(fb_vector->begin(), fb_vector->end()); | |||
| } | |||
| void Tile::SetDims(const std::vector<int> &dims) {} | |||
| #endif | |||
| int Tile::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::Tensor *> outputs_) { | |||
| @@ -45,11 +56,14 @@ int Tile::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor:: | |||
| if (!GetInferFlag()) { | |||
| return RET_OK; | |||
| } | |||
| auto tile_prim = this->primitive->value_as_Tile(); | |||
| MS_ASSERT(tile_prim != nullptr); | |||
| std::vector<int> out_shape; | |||
| std::vector<int> multiples; | |||
| std::copy(tile_prim->multiples()->begin(), tile_prim->multiples()->end(), std::back_inserter(multiples)); | |||
| for (int i = 0; i < GetMultiples().size(); ++i) { | |||
| multiples.push_back(GetMultiples()[i]); | |||
| } | |||
| // std::copy(GetMultiples().begin(), GetMultiples().end(), std::back_inserter(multiples)); | |||
| for (size_t i = 0; i < input->shape().size(); ++i) { | |||
| int tmp = input->shape()[i] * multiples[i]; | |||
| out_shape.push_back(tmp); | |||
| @@ -37,6 +37,8 @@ class Tile : public PrimitiveC { | |||
| int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override; | |||
| std::vector<int> GetMultiples() const; | |||
| void SetMultiples(const std::vector<int> &multiples); | |||
| std::vector<int> GetDims() const; | |||
| void SetDims(const std::vector<int> &dims); | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -52,14 +52,17 @@ int Transpose::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<ten | |||
| } | |||
| MS_ASSERT(inputs_.size() == kSingleNum); | |||
| MS_ASSERT(outputs_.size() == kSingleNum); | |||
| auto transpore_prim = this->primitive->value_as_Transpose(); | |||
| int conjugate = transpore_prim->conjugate(); | |||
| int conjugate = GetConjugate(); | |||
| if (conjugate) { | |||
| MS_LOG(ERROR) << "Transpose conjugate is not support currently"; | |||
| return RET_ERROR; | |||
| } | |||
| std::vector<int> perm; | |||
| perm.insert(perm.begin(), transpore_prim->perm()->begin(), transpore_prim->perm()->end()); | |||
| for (int i = 0; i < GetPerm().size(); i++) { | |||
| perm.push_back(GetPerm()[i]); | |||
| } | |||
| // perm.insert(perm.begin(), GetPerm().begin(), GetPerm().end()); | |||
| std::vector<int> in_shape = input->shape(); | |||
| std::vector<int> out_shape; | |||
| out_shape.resize(perm.size()); | |||
| @@ -988,7 +988,7 @@ OpParameter *PopulateSliceParameter(const mindspore::lite::PrimitiveC *primitive | |||
| } | |||
| slice_param->param_length_ = static_cast<int32_t>(param_begin.size()); | |||
| for (int32_t i = 0; i < slice_param->param_length_; ++i) { | |||
| slice_param->begin_[i] = param_begin[1]; | |||
| slice_param->begin_[i] = param_begin[i]; | |||
| slice_param->size_[i] = param_size[i]; | |||
| } | |||
| return reinterpret_cast<OpParameter *>(slice_param); | |||