|
|
|
namespace {
constexpr int kInputRank = 4;
} // namespace |
|
|
|
template <typename T>
void CalShape(const T *data, const std::vector<Tensor *> &inputs, std::vector<int> *out_shape, int shape_size) {
  int input_count = inputs[0]->ElementsNum();
  int index = 0;
  int size = 1;
  for (int i = 0; i < shape_size; i++) {
    if (static_cast<int>(data[i]) == -1) {
      index = i;
    } else {
      size *= data[i];
    }
    out_shape->push_back(data[i]);
  }
  if (static_cast<int>(data[index]) == -1) {
    (*out_shape)[index] = input_count / size;
  }
}
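
// CalShape copies the requested dimensions from the shape tensor into out_shape and infers a single
// -1 entry from the element count of inputs[0]. Worked example (hypothetical values): if inputs[0]
// holds 1 * 8 * 2 * 2 = 32 elements and data = {1, 8, -1, 2}, the loop records index = 2 and
// size = 1 * 8 * 2 = 16, so the -1 dimension resolves to 32 / 16 = 2 and out_shape becomes {1, 8, 2, 2}.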
|
|
|
|
|
|
|
int Resize::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) {
  MS_ASSERT(this->primitive_ != nullptr);
  auto input = inputs_.front();
  if (input == nullptr) {
    return RET_ERROR;
  }
  if (input->shape().size() != kInputRank) {
    MS_LOG(ERROR) << "Size of input shape is wrong.";
    return RET_ERROR;
  }

  auto output = outputs_.front();
  if (output == nullptr) {
    return RET_NULL_PTR;
  }
  output->set_data_type(input->data_type());
  output->SetFormat(input->GetFormat());
  if (!GetInferFlag()) {
    return RET_OK;
  }
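
  // The output shape is assembled as {Batch, H, W, Channel}. With two inputs the new height and width
  // are read from the second (shape) tensor; with a single input they come from the NewHeight/NewWidth
  // attributes. For example, a hypothetical 1 x 32 x 32 x 3 NHWC input with new_height = 64 and
  // new_width = 64 would produce an output shape of {1, 64, 64, 3}.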
|
|
|
  std::vector<int> output_shape;
  output_shape.push_back(input->Batch());
  if (inputs_.size() == kDoubleNum) {
    auto shape_tensor = inputs_.at(1);
    if (shape_tensor->data_c() == nullptr) {
      MS_LOG(INFO) << "Do infer shape in runtime.";
      return RET_INFER_INVALID;
    }
    size_t shape_size = shape_tensor->ElementsNum();
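    // The shape tensor may be stored with different element types, so CalShape is instantiated per
    // type below; only the types listed in the switch are accepted.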
|
|
|
    switch (shape_tensor->data_type()) {
      case kNumberTypeInt8: {
        auto data = reinterpret_cast<int8_t *>(shape_tensor->MutableData());
        CalShape<int8_t>(data, inputs_, &output_shape, shape_size);
      } break;
      case kNumberTypeInt32: {
        auto data = reinterpret_cast<int32_t *>(shape_tensor->MutableData());
        CalShape<int32_t>(data, inputs_, &output_shape, shape_size);
      } break;
      case kNumberTypeInt64: {
        auto data = reinterpret_cast<int64_t *>(shape_tensor->MutableData());
        CalShape<int64_t>(data, inputs_, &output_shape, shape_size);
      } break;
      case kNumberTypeFloat: {
        auto data = reinterpret_cast<float *>(shape_tensor->MutableData());
        CalShape<float>(data, inputs_, &output_shape, shape_size);
      } break;
      case kNumberTypeUInt32: {
        auto data = reinterpret_cast<uint32_t *>(shape_tensor->MutableData());
        CalShape<uint32_t>(data, inputs_, &output_shape, shape_size);
      } break;
      default: {
        MS_LOG(ERROR) << "Resize shape tensor has unsupported dataType: " << shape_tensor->data_type();
        return RET_INFER_ERR;
      }
    }
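    // At this point a hypothetical two-element shape tensor {64, 64} would have extended output_shape
    // to {Batch, 64, 64}; the channel dimension is appended after the branch below.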
|
|
|
  } else if (inputs_.size() == kSingleNum) {
    auto new_height = GetNewHeight();
    auto new_width = GetNewWidth();
    output_shape.push_back(new_height);
    output_shape.push_back(new_width);
  } else {
    MS_LOG(ERROR) << "inputs tensor size invalid.";
    return RET_INFER_ERR;
  }
  output_shape.push_back(input->Channel());
  output->set_shape(output_shape);

  return RET_OK;
}
|
|
|
|
|
|
|
|