| @@ -57,12 +57,12 @@ public class MSTensor { | |||
| return this.getLongData(this.tensorPtr); | |||
| } | |||
| public void setData(byte[] data) { | |||
| this.setData(this.tensorPtr, data, data.length); | |||
| public boolean setData(byte[] data) { | |||
| return this.setData(this.tensorPtr, data, data.length); | |||
| } | |||
| public void setData(ByteBuffer data) { | |||
| this.setByteBufferData(this.tensorPtr, data); | |||
| public boolean setData(ByteBuffer data) { | |||
| return this.setByteBufferData(this.tensorPtr, data); | |||
| } | |||
| public long size() { | |||
| @@ -66,6 +66,10 @@ extern "C" JNIEXPORT jbyteArray JNICALL Java_com_mindspore_lite_MSTensor_getByte | |||
| } | |||
| auto local_size = ms_tensor_ptr->Size(); | |||
| if (local_size <= 0) { | |||
| MS_LOGE("Size of tensor is negative: %zu", local_size); | |||
| return env->NewByteArray(0); | |||
| } | |||
| auto ret = env->NewByteArray(local_size); | |||
| env->SetByteArrayRegion(ret, 0, local_size, local_data); | |||
| return ret; | |||
| @@ -92,6 +96,10 @@ extern "C" JNIEXPORT jlongArray JNICALL Java_com_mindspore_lite_MSTensor_getLong | |||
| return env->NewLongArray(0); | |||
| } | |||
| auto local_element_num = ms_tensor_ptr->ElementsNum(); | |||
| if (local_element_num <= 0) { | |||
| MS_LOGE("ElementsNum of tensor is negative: %d", local_element_num); | |||
| return env->NewLongArray(0); | |||
| } | |||
| auto ret = env->NewLongArray(local_element_num); | |||
| env->SetLongArrayRegion(ret, 0, local_element_num, local_data); | |||
| return ret; | |||
| @@ -118,6 +126,10 @@ extern "C" JNIEXPORT jintArray JNICALL Java_com_mindspore_lite_MSTensor_getIntDa | |||
| return env->NewIntArray(0); | |||
| } | |||
| auto local_element_num = ms_tensor_ptr->ElementsNum(); | |||
| if (local_element_num <= 0) { | |||
| MS_LOGE("ElementsNum of tensor is negative: %d", local_element_num); | |||
| return env->NewIntArray(0); | |||
| } | |||
| auto ret = env->NewIntArray(local_element_num); | |||
| env->SetIntArrayRegion(ret, 0, local_element_num, local_data); | |||
| return ret; | |||
| @@ -144,6 +156,10 @@ extern "C" JNIEXPORT jfloatArray JNICALL Java_com_mindspore_lite_MSTensor_getFlo | |||
| return env->NewFloatArray(0); | |||
| } | |||
| auto local_element_num = ms_tensor_ptr->ElementsNum(); | |||
| if (local_element_num <= 0) { | |||
| MS_LOGE("ElementsNum of tensor is negative: %d", local_element_num); | |||
| return env->NewFloatArray(0); | |||
| } | |||
| auto ret = env->NewFloatArray(local_element_num); | |||
| env->SetFloatArrayRegion(ret, 0, local_element_num, local_data); | |||
| return ret; | |||
| @@ -259,7 +275,7 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_MSTensor_createTensor | |||
| memcpy(tensor_data, p_data, data_len); | |||
| int tensor_size = static_cast<jint>(data_len / sizeof(float)); | |||
| std::vector<int> shape = {tensor_size}; | |||
| auto tensor = mindspore::tensor::MSTensor::CreateTensor( | |||
| env->GetStringUTFChars(tensor_name, JNI_FALSE), mindspore::kNumberTypeFloat32, shape, tensor_data, data_len); | |||
| auto tensor = mindspore::tensor::MSTensor::CreateTensor(env->GetStringUTFChars(tensor_name, JNI_FALSE), | |||
| mindspore::kNumberTypeFloat32, shape, tensor_data, data_len); | |||
| return jlong(tensor); | |||
| } | |||
| @@ -38,7 +38,7 @@ OpParameter *PopulateBroadcastToParameter(const void *prim) { | |||
| param->op_parameter_.type_ = primitive->value_type(); | |||
| auto dst_shape = value->shape(); | |||
| if (dst_shape == nullptr) { | |||
| MS_LOG(WARNING) << "unable to get dst_shape from attribute."; | |||
| MS_LOG(INFO) << "broadcast_to has not shape const tensor."; | |||
| } else { | |||
| param->shape_size_ = dst_shape->size(); | |||
| for (size_t i = 0; i < param->shape_size_; ++i) { | |||
| @@ -44,6 +44,12 @@ int BroadcastToCPUKernel::ReSize() { | |||
| shape_info_->output_shape_[i] = output_shape[i]; | |||
| } | |||
| shape_info_->output_shape_size_ = static_cast<int>(output_shape.size()); | |||
| data_type_ = in_tensors_.at(0)->data_type(); | |||
| if (data_type_ != out_tensors_.at(0)->data_type()) { | |||
| MS_LOG(ERROR) << "BroadcastTo infer has error"; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| @@ -60,10 +66,36 @@ int BroadcastToCPUKernel::Init() { | |||
| } | |||
| int BroadcastToCPUKernel::Run() { | |||
| const auto input_data = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData()); | |||
| auto output_data = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData()); | |||
| return BroadcastTo(float, input_data, shape_info_, output_data); | |||
| if (in_tensors_.size() == 2) { | |||
| auto shape_tensor = in_tensors_.at(1); | |||
| MS_ASSERT(shape_tensor->data_type() == kNumberTypeInt32); | |||
| if (shape_tensor->ElementsNum() > MAX_SHAPE_SIZE) { | |||
| MS_LOG(ERROR) << "Size of broadcast_to shape exceed MAX_SHAPE_SIZE"; | |||
| return RET_ERROR; | |||
| } | |||
| auto shape_data = reinterpret_cast<int *>(shape_tensor->data()); | |||
| for (int i = 0; i < shape_tensor->ElementsNum(); i++) { | |||
| shape_info_->output_shape_[i] = (shape_data[i] == -1) ? (shape_info_->input_shape_[i]) : shape_data[i]; | |||
| } | |||
| } | |||
| switch (data_type_) { | |||
| case kNumberTypeFloat32: { | |||
| const auto input_data = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData()); | |||
| auto output_data = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData()); | |||
| return BroadcastTo(float, input_data, shape_info_, output_data); | |||
| } | |||
| case kNumberTypeInt32: | |||
| case kNumberTypeInt: { | |||
| const auto input_data = reinterpret_cast<int *>(in_tensors_.at(0)->MutableData()); | |||
| auto output_data = reinterpret_cast<int *>(out_tensors_.at(0)->MutableData()); | |||
| return BroadcastTo(int, input_data, shape_info_, output_data); | |||
| } | |||
| default: | |||
| MS_LOG(ERROR) << "UnSupported data type: " << data_type_; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_BroadcastTo, LiteKernelCreator<BroadcastToCPUKernel>) | |||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_BroadcastTo, LiteKernelCreator<BroadcastToCPUKernel>) | |||
| } // namespace mindspore::kernel | |||
| @@ -35,6 +35,7 @@ class BroadcastToCPUKernel : public InnerKernel { | |||
| private: | |||
| BroadcastShapeInfo *shape_info_ = nullptr; | |||
| TypeId data_type_ = kNumberTypeFloat32; | |||
| }; | |||
| } // namespace mindspore::kernel | |||
| @@ -0,0 +1,49 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/converter/parser/tf/tf_broadcast_to_parser.h" | |||
| #include <string> | |||
| #include <memory> | |||
| #include <map> | |||
| #include <vector> | |||
| #include "tools/converter/parser/tf/tf_node_parser_registry.h" | |||
| #include "ops/broadcast_to.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| ops::PrimitiveC *TFBroadcastToParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||
| std::vector<std::string> *inputs, int *output_size) { | |||
| auto prim = std::make_unique<ops::BroadcastTo>(); | |||
| if (tf_op.input_size() == 1) { | |||
| MS_LOG(ERROR) << "tf broadcast_to parser not support one input now"; | |||
| return nullptr; | |||
| } else if (tf_op.input_size() == 2) { | |||
| *output_size = 1; | |||
| if (AddOpInput(tf_op, 0, inputs) != RET_OK || AddOpInput(tf_op, 1, inputs) != RET_OK) { | |||
| MS_LOG(ERROR) << "add op input failed"; | |||
| return nullptr; | |||
| } | |||
| return prim.release(); | |||
| } else { | |||
| MS_LOG(ERROR) << "broadcast_to has " << tf_op.input_size() << " inputs, invalid"; | |||
| return nullptr; | |||
| } | |||
| } | |||
| TFNodeRegistrar g_tfBroadcastToParser("BroadcastTo", new TFBroadcastToParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,38 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_BROADCAST_TO_PARSER_H_ | |||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_BROADCAST_TO_PARSER_H_ | |||
| #include <string> | |||
| #include <memory> | |||
| #include <map> | |||
| #include <vector> | |||
| #include "tools/converter/parser/tf/tf_node_parser.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class TFBroadcastToParser : public TFNodeParser { | |||
| public: | |||
| TFBroadcastToParser() = default; | |||
| ~TFBroadcastToParser() override = default; | |||
| ops::PrimitiveC *Parse(const tensorflow::NodeDef &tf_op, | |||
| const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||
| std::vector<std::string> *inputs, int *output_size) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_BROADCAST_TO_PARSER_H_ | |||