Browse Source

add tf broadcast-to parser

tags/v1.4.0
hangangqiang 4 years ago
parent
commit
70baede744
7 changed files with 146 additions and 10 deletions
  1. +4
    -4
      mindspore/lite/java/java/common/src/main/java/com/mindspore/lite/MSTensor.java
  2. +18
    -2
      mindspore/lite/java/native/runtime/ms_tensor.cpp
  3. +1
    -1
      mindspore/lite/src/ops/populate/broadcast_to_populate.cc
  4. +35
    -3
      mindspore/lite/src/runtime/kernel/arm/fp32/broadcast_to_fp32.cc
  5. +1
    -0
      mindspore/lite/src/runtime/kernel/arm/fp32/broadcast_to_fp32.h
  6. +49
    -0
      mindspore/lite/tools/converter/parser/tf/tf_broadcast_to_parser.cc
  7. +38
    -0
      mindspore/lite/tools/converter/parser/tf/tf_broadcast_to_parser.h

+ 4
- 4
mindspore/lite/java/java/common/src/main/java/com/mindspore/lite/MSTensor.java View File

@@ -57,12 +57,12 @@ public class MSTensor {
return this.getLongData(this.tensorPtr);
}

public void setData(byte[] data) {
this.setData(this.tensorPtr, data, data.length);
public boolean setData(byte[] data) {
return this.setData(this.tensorPtr, data, data.length);
}

public void setData(ByteBuffer data) {
this.setByteBufferData(this.tensorPtr, data);
public boolean setData(ByteBuffer data) {
return this.setByteBufferData(this.tensorPtr, data);
}

public long size() {


+ 18
- 2
mindspore/lite/java/native/runtime/ms_tensor.cpp View File

@@ -66,6 +66,10 @@ extern "C" JNIEXPORT jbyteArray JNICALL Java_com_mindspore_lite_MSTensor_getByte
}

auto local_size = ms_tensor_ptr->Size();
if (local_size <= 0) {
MS_LOGE("Size of tensor is negative: %zu", local_size);
return env->NewByteArray(0);
}
auto ret = env->NewByteArray(local_size);
env->SetByteArrayRegion(ret, 0, local_size, local_data);
return ret;
@@ -92,6 +96,10 @@ extern "C" JNIEXPORT jlongArray JNICALL Java_com_mindspore_lite_MSTensor_getLong
return env->NewLongArray(0);
}
auto local_element_num = ms_tensor_ptr->ElementsNum();
if (local_element_num <= 0) {
MS_LOGE("ElementsNum of tensor is negative: %d", local_element_num);
return env->NewLongArray(0);
}
auto ret = env->NewLongArray(local_element_num);
env->SetLongArrayRegion(ret, 0, local_element_num, local_data);
return ret;
@@ -118,6 +126,10 @@ extern "C" JNIEXPORT jintArray JNICALL Java_com_mindspore_lite_MSTensor_getIntDa
return env->NewIntArray(0);
}
auto local_element_num = ms_tensor_ptr->ElementsNum();
if (local_element_num <= 0) {
MS_LOGE("ElementsNum of tensor is negative: %d", local_element_num);
return env->NewIntArray(0);
}
auto ret = env->NewIntArray(local_element_num);
env->SetIntArrayRegion(ret, 0, local_element_num, local_data);
return ret;
@@ -144,6 +156,10 @@ extern "C" JNIEXPORT jfloatArray JNICALL Java_com_mindspore_lite_MSTensor_getFlo
return env->NewFloatArray(0);
}
auto local_element_num = ms_tensor_ptr->ElementsNum();
if (local_element_num <= 0) {
MS_LOGE("ElementsNum of tensor is negative: %d", local_element_num);
return env->NewFloatArray(0);
}
auto ret = env->NewFloatArray(local_element_num);
env->SetFloatArrayRegion(ret, 0, local_element_num, local_data);
return ret;
@@ -259,7 +275,7 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_MSTensor_createTensor
memcpy(tensor_data, p_data, data_len);
int tensor_size = static_cast<jint>(data_len / sizeof(float));
std::vector<int> shape = {tensor_size};
auto tensor = mindspore::tensor::MSTensor::CreateTensor(
env->GetStringUTFChars(tensor_name, JNI_FALSE), mindspore::kNumberTypeFloat32, shape, tensor_data, data_len);
auto tensor = mindspore::tensor::MSTensor::CreateTensor(env->GetStringUTFChars(tensor_name, JNI_FALSE),
mindspore::kNumberTypeFloat32, shape, tensor_data, data_len);
return jlong(tensor);
}

+ 1
- 1
mindspore/lite/src/ops/populate/broadcast_to_populate.cc View File

@@ -38,7 +38,7 @@ OpParameter *PopulateBroadcastToParameter(const void *prim) {
param->op_parameter_.type_ = primitive->value_type();
auto dst_shape = value->shape();
if (dst_shape == nullptr) {
MS_LOG(WARNING) << "unable to get dst_shape from attribute.";
MS_LOG(INFO) << "broadcast_to has not shape const tensor.";
} else {
param->shape_size_ = dst_shape->size();
for (size_t i = 0; i < param->shape_size_; ++i) {


+ 35
- 3
mindspore/lite/src/runtime/kernel/arm/fp32/broadcast_to_fp32.cc View File

@@ -44,6 +44,12 @@ int BroadcastToCPUKernel::ReSize() {
shape_info_->output_shape_[i] = output_shape[i];
}
shape_info_->output_shape_size_ = static_cast<int>(output_shape.size());

data_type_ = in_tensors_.at(0)->data_type();
if (data_type_ != out_tensors_.at(0)->data_type()) {
MS_LOG(ERROR) << "BroadcastTo infer has error";
return RET_ERROR;
}
return RET_OK;
}

@@ -60,10 +66,36 @@ int BroadcastToCPUKernel::Init() {
}

// Runs the broadcast. If a second (shape) input tensor is present, the target
// output shape is refreshed from it at run time; a -1 entry keeps the
// corresponding input dimension. Dispatches on the element type cached in
// ReSize(). Returns RET_OK on success, RET_ERROR otherwise.
int BroadcastToCPUKernel::Run() {
  if (in_tensors_.size() == 2) {
    auto shape_tensor = in_tensors_.at(1);
    MS_ASSERT(shape_tensor->data_type() == kNumberTypeInt32);
    if (shape_tensor->ElementsNum() > MAX_SHAPE_SIZE) {
      MS_LOG(ERROR) << "Size of broadcast_to shape exceed MAX_SHAPE_SIZE";
      return RET_ERROR;
    }
    auto shape_data = reinterpret_cast<int *>(shape_tensor->data());
    // Guard against a shape tensor whose buffer has not been materialized;
    // dereferencing a null data pointer below would crash.
    if (shape_data == nullptr) {
      MS_LOG(ERROR) << "Size of broadcast_to shape exceed MAX_SHAPE_SIZE";
      return RET_ERROR;
    }
    for (int i = 0; i < shape_tensor->ElementsNum(); i++) {
      // -1 means "keep the input dimension at this axis".
      shape_info_->output_shape_[i] = (shape_data[i] == -1) ? (shape_info_->input_shape_[i]) : shape_data[i];
    }
  }
  // data_type_ was validated against the output tensor's type in ReSize().
  switch (data_type_) {
    case kNumberTypeFloat32: {
      const auto input_data = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData());
      auto output_data = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData());
      return BroadcastTo(float, input_data, shape_info_, output_data);
    }
    case kNumberTypeInt32:
    case kNumberTypeInt: {
      const auto input_data = reinterpret_cast<int *>(in_tensors_.at(0)->MutableData());
      auto output_data = reinterpret_cast<int *>(out_tensors_.at(0)->MutableData());
      return BroadcastTo(int, input_data, shape_info_, output_data);
    }
    default:
      MS_LOG(ERROR) << "UnSupported data type: " << data_type_;
      return RET_ERROR;
  }
}

REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_BroadcastTo, LiteKernelCreator<BroadcastToCPUKernel>)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_BroadcastTo, LiteKernelCreator<BroadcastToCPUKernel>)
} // namespace mindspore::kernel

+ 1
- 0
mindspore/lite/src/runtime/kernel/arm/fp32/broadcast_to_fp32.h View File

@@ -35,6 +35,7 @@ class BroadcastToCPUKernel : public InnerKernel {

private:
BroadcastShapeInfo *shape_info_ = nullptr;
TypeId data_type_ = kNumberTypeFloat32;
};
} // namespace mindspore::kernel



+ 49
- 0
mindspore/lite/tools/converter/parser/tf/tf_broadcast_to_parser.cc View File

@@ -0,0 +1,49 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tf/tf_broadcast_to_parser.h"
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
#include "ops/broadcast_to.h"

namespace mindspore {
namespace lite {
// Parses a TF BroadcastTo node into a lite BroadcastTo primitive.
// Exactly two inputs (data, target shape) are supported; the single-input
// form is rejected. On success *output_size is 1 and ownership of the
// primitive is transferred to the caller; on failure returns nullptr.
ops::PrimitiveC *TFBroadcastToParser::Parse(const tensorflow::NodeDef &tf_op,
                                            const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
                                            std::vector<std::string> *inputs, int *output_size) {
  const int input_num = tf_op.input_size();
  if (input_num == 1) {
    MS_LOG(ERROR) << "tf broadcast_to parser not support one input now";
    return nullptr;
  }
  if (input_num != 2) {
    MS_LOG(ERROR) << "broadcast_to has " << input_num << " inputs, invalid";
    return nullptr;
  }

  auto prim = std::make_unique<ops::BroadcastTo>();
  *output_size = 1;
  for (int idx = 0; idx < 2; ++idx) {
    if (AddOpInput(tf_op, idx, inputs) != RET_OK) {
      MS_LOG(ERROR) << "add op input failed";
      return nullptr;
    }
  }
  return prim.release();
}

TFNodeRegistrar g_tfBroadcastToParser("BroadcastTo", new TFBroadcastToParser());
} // namespace lite
} // namespace mindspore

+ 38
- 0
mindspore/lite/tools/converter/parser/tf/tf_broadcast_to_parser.h View File

@@ -0,0 +1,38 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_BROADCAST_TO_PARSER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_BROADCAST_TO_PARSER_H_
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser.h"

namespace mindspore {
namespace lite {
// Converter-side parser that maps a TensorFlow "BroadcastTo" NodeDef onto the
// lite BroadcastTo primitive. Registered with the TF node-parser registry in
// the corresponding .cc file.
class TFBroadcastToParser : public TFNodeParser {
 public:
  TFBroadcastToParser() = default;
  ~TFBroadcastToParser() override = default;

  // Builds the primitive from tf_op; fills *inputs with the op's input names
  // and sets *output_size. Returns nullptr when the node is unsupported.
  ops::PrimitiveC *Parse(const tensorflow::NodeDef &tf_op,
                         const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
                         std::vector<std::string> *inputs, int *output_size) override;
};
} // namespace lite
} // namespace mindspore

#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_BROADCAST_TO_PARSER_H_

Loading…
Cancel
Save