@@ -275,7 +275,6 @@ union PrimitiveType {
     Erf,
     StridedSliceGrad,
     IsFinite,
-    BatchMatMul,
     LinSpace,
     UniformReal,
     AbsGrad
@@ -1280,12 +1280,6 @@ table Erf {
 table IsFinite {
 }
 
-table BatchMatMul {
-    transpose_a :bool;
-    transpose_b :bool;
-}
-
 table LinSpace {
 }
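This diff removes the dedicated BatchMatMul primitive from the schema; the converter hunks at the bottom retarget TensorFlow's BatchMatMul onto the existing MatMul primitive, whose transpose flags already carry the same information. Both primitives describe C[b] = op(A[b]) * op(B[b]) over a batch dimension, where op transposes its argument when the matching flag is set. A minimal reference sketch of that semantics (illustrative only, not code from this patch):

#include <cstddef>

// Reference batched matmul with transpose flags. Stored shapes per batch:
// A is (m x k), or (k x m) when transpose_a; B is (k x n), or (n x k) when
// transpose_b; C is always (m x n). This is exactly the contract that
// MatMul-with-transpose-flags covers, which makes a separate BatchMatMul
// primitive redundant.
void BatchMatMulRef(const float *a, const float *b, float *c, size_t batch, size_t m, size_t k, size_t n,
                    bool transpose_a, bool transpose_b) {
  for (size_t bi = 0; bi < batch; ++bi) {
    const float *A = a + bi * m * k;
    const float *B = b + bi * k * n;
    float *C = c + bi * m * n;
    for (size_t i = 0; i < m; ++i) {
      for (size_t j = 0; j < n; ++j) {
        float acc = 0.0f;
        for (size_t p = 0; p < k; ++p) {
          float av = transpose_a ? A[p * m + i] : A[i * k + p];  // op(A)[i][p]
          float bv = transpose_b ? B[j * k + p] : B[p * n + j];  // op(B)[p][j]
          acc += av * bv;
        }
        C[i * n + j] = acc;
      }
    }
  }
}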
@@ -1,85 +0,0 @@
-/**
- * Copyright 2021 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "src/ops/batch_matmul.h"
-
-#ifndef PRIMITIVE_WRITEABLE
-#include "src/ops/ops_register.h"
-#endif
-
-namespace mindspore {
-namespace lite {
-#ifdef PRIMITIVE_WRITEABLE
-bool BatchMatMul::GetTransposeA() const { return this->primitive_->value.AsBatchMatMul()->transpose_a; }
-bool BatchMatMul::GetTransposeB() const { return this->primitive_->value.AsBatchMatMul()->transpose_b; }
-
-void BatchMatMul::SetTransposeA(bool transpose_a) {
-  this->primitive_->value.AsBatchMatMul()->transpose_a = transpose_a;
-}
-void BatchMatMul::SetTransposeB(bool transpose_b) {
-  this->primitive_->value.AsBatchMatMul()->transpose_b = transpose_b;
-}
-
-int BatchMatMul::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
-  if (this->primitive_ == nullptr) {
-    this->primitive_ = new (std::nothrow) schema::PrimitiveT;
-    if (this->primitive_ == nullptr) {
-      MS_LOG(ERROR) << "new primitiveT failed";
-      return RET_ERROR;
-    }
-    this->primitive_->value.type = schema::PrimitiveType_BatchMatMul;
-  }
-  if (this->primitive_->value.type != schema::PrimitiveType_BatchMatMul) {
-    MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
-    return RET_ERROR;
-  }
-  if (this->primitive_->value.value == nullptr) {
-    auto attr = new (std::nothrow) schema::BatchMatMulT();
-    if (attr == nullptr) {
-      MS_LOG(ERROR) << "new FusedBatchMatMulT failed";
-      delete this->primitive_;
-      this->primitive_ = nullptr;
-      return RET_ERROR;
-    }
-    attr->transpose_a = GetValue<bool>(prim.GetAttr("transpose_a"));
-    attr->transpose_b = GetValue<bool>(prim.GetAttr("transpose_b"));
-    this->primitive_->value.value = attr;
-  }
-  return RET_OK;
-}
-
-#else
-bool BatchMatMul::GetTransposeA() const { return this->primitive_->value_as_BatchMatMul()->transpose_a(); }
-bool BatchMatMul::GetTransposeB() const { return this->primitive_->value_as_BatchMatMul()->transpose_b(); }
-
-int BatchMatMul::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
-  MS_ASSERT(nullptr != primitive);
-  MS_ASSERT(nullptr != fbb);
-  auto attr = primitive->value_as_BatchMatMul();
-  if (attr == nullptr) {
-    MS_LOG(ERROR) << "value_as_Add return nullptr";
-    return RET_ERROR;
-  }
-  auto val_offset = schema::CreateBatchMatMul(*fbb, attr->transpose_a(), attr->transpose_b());
-  auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_BatchMatMul, val_offset.o);
-  fbb->Finish(prim_offset);
-  return RET_OK;
-}
-
-PrimitiveC *BatchMatMulCreator(const schema::Primitive *primitive) {
-  return PrimitiveC::NewPrimitiveC<BatchMatMul>(primitive);
-}
-Registry BatchMatMulRegistry(schema::PrimitiveType_BatchMatMul, BatchMatMulCreator);
-#endif
-}  // namespace lite
-}  // namespace mindspore
@@ -1,45 +0,0 @@
-/**
- * Copyright 2021 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef LITE_MINDSPORE_LITE_C_OPS_BATCH_MATMUL_H_
-#define LITE_MINDSPORE_LITE_C_OPS_BATCH_MATMUL_H_
-
-#include <vector>
-#include <set>
-#include <cmath>
-#include "src/ops/primitive_c.h"
-
-namespace mindspore {
-namespace lite {
-class BatchMatMul : public PrimitiveC {
- public:
-  BatchMatMul() = default;
-  ~BatchMatMul() = default;
-#ifdef PRIMITIVE_WRITEABLE
-  MS_DECLARE_PARENT(BatchMatMul, PrimitiveC);
-  explicit BatchMatMul(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
-  int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
-  void SetTransposeA(bool transpose_a);
-  void SetTransposeB(bool transpose_b);
-#else
-  int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
-#endif
-  bool GetTransposeA() const;
-  bool GetTransposeB() const;
-};
-}  // namespace lite
-}  // namespace mindspore
-
-#endif  // LITE_MINDSPORE_LITE_C_OPS_BATCH_MATMUL_H_
@@ -170,7 +170,6 @@
 #include "src/ops/crop_and_resize.h"
 #include "src/ops/nonzero.h"
 #include "src/ops/erf.h"
-#include "src/ops/batch_matmul.h"
 #include "src/ops/lin_space.h"
 #include "src/ops/uniform_real.h"
 #include "src/ops/rank.h"
@@ -1057,8 +1056,6 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) {
       return new (std::nothrow) Erf(primitive);
     case schema::PrimitiveType_IsFinite:
       return new (std::nothrow) IsFinite(primitive);
-    case schema::PrimitiveType_BatchMatMul:
-      return new (std::nothrow) BatchMatMul(primitive);
     case schema::PrimitiveType_LinSpace:
       return new (std::nothrow) LinSpace(primitive);
     case schema::PrimitiveType_UniformReal:
@@ -125,6 +125,9 @@ int TensorListGetItem::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) {
   MS_ASSERT(inputs_.at(1) != nullptr);
   MS_ASSERT(inputs_.at(2) != nullptr);
   auto input0 = reinterpret_cast<TensorList *>(inputs_.at(0));
+  if (input0->root_tensor() != nullptr) {
+    input0 = reinterpret_cast<TensorList *>(input0->root_tensor());
+  }
   auto get_index = inputs_.at(1);
   MS_ASSERT(get_index != nullptr);
   if (get_index->ElementsNum() != 1) {
@@ -102,6 +102,7 @@ int TensorListFromTensorCPUKernel::Run() {
     memcpy(out_data, in_data, data_offset);
     in_data += data_offset;
   }
+  output0->set_tensors_data_type(dtype_);
   return RET_OK;
 }
@@ -45,6 +45,9 @@ int TensorListGetItemCPUKernel::Run() {
   MS_ASSERT(in_tensors_.at(1) != nullptr);
   MS_ASSERT(out_tensors_.at(0) != nullptr);
   auto input0 = reinterpret_cast<lite::TensorList *>(in_tensors_.at(0));
+  if (input0->root_tensor() != nullptr) {
+    input0 = reinterpret_cast<lite::TensorList *>(input0->root_tensor());
+  }
   if (dtype_ != input0->tensors_data_type()) {
     MS_LOG(ERROR) << "op dtype: " << dtype_ << " is not equal in_tensor[0] dtype: " << input0->tensors_data_type();
     return RET_ERROR;
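The three TensorList hunks above share one fix: a TensorList flowing between subgraphs may be an alias whose authoritative metadata lives on its root tensor, so both shape inference and the CPU kernel must follow root_tensor() before checking tensors_data_type(), and TensorListFromTensor must stamp its configured element dtype onto the output list. The repeated dereference could be captured by a helper along these lines (hypothetical name; it assumes only the root_tensor() accessor used in the hunks):

#include "src/tensorlist.h"

namespace {
// Resolve a possibly aliased TensorList to the tensor that actually owns
// its metadata, mirroring the pattern added in both hunks above.
mindspore::lite::TensorList *ResolveTensorList(mindspore::lite::Tensor *tensor) {
  auto *list = reinterpret_cast<mindspore::lite::TensorList *>(tensor);
  if (list->root_tensor() != nullptr) {
    list = reinterpret_cast<mindspore::lite::TensorList *>(list->root_tensor());
  }
  return list;
}
}  // namespace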
@@ -16,6 +16,7 @@
 #include "src/sub_graph_kernel.h"
 #include "src/tensor.h"
+#include "src/tensorlist.h"
 #if defined(ENABLE_ARM64) && defined(ENABLE_FP16)
 #include "src/runtime/kernel/arm/fp16/fp16_op_handler.h"
 #endif
@@ -176,7 +177,8 @@ int CpuSubGraph::Prepare() {
 #ifdef ENABLE_FP16
 void CpuFp16SubGraph::FreeOriginInputData() {
-  for (auto *data_store : this->origin_input_data_) {
+  for (auto &iter : this->origin_input_data_) {
+    auto *data_store = iter.second;
     if (data_store == nullptr) {
       continue;
     }
@@ -199,37 +201,99 @@ void CpuFp16SubGraph::FreeOriginInputData() {
   this->origin_input_data_.clear();
 }
+
+int CpuFp16SubGraph::Float32TensorToFloat16Tensor(lite::Tensor *tensor) {
+  auto float32_data = tensor->data_c();
+  if (float32_data == nullptr) {
+    MS_LOG(ERROR) << "tensor data is null.";
+    return lite::RET_NULL_PTR;
+  }
+  tensor->set_data(nullptr);
+  tensor->set_data_type(TypeId::kNumberTypeFloat16);
+  auto ret = tensor->MallocData();
+  if (RET_OK != ret) {
+    MS_LOG(ERROR) << "malloc data failed";
+    this->FreeOriginInputData();
+    return RET_ERROR;
+  }
+  MS_ASSERT(tensor->data_c() != nullptr);
+  Float32ToFloat16_fp16_handler(float32_data, tensor->data_c(), tensor->ElementsNum());
+  auto *data_store = DataStore::CreateDataStore(float32_data, tensor->allocator(), this->context_->allocator.get());
+  if (data_store == nullptr) {
+    MS_LOG(ERROR) << "Create DataStore failed";
+    this->FreeOriginInputData();
+    return RET_ERROR;
+  }
+  origin_input_data_[tensor] = data_store;
+  return RET_OK;
+}
+
+int CpuFp16SubGraph::Float16TensorToFloat32Tensor(lite::Tensor *tensor) {
+  auto float16_data = tensor->data_c();
+  if (float16_data == nullptr) {
+    MS_LOG(ERROR) << "tensor data is null.";
+    return lite::RET_NULL_PTR;
+  }
+  tensor->set_data(nullptr);
+  tensor->set_data_type(TypeId::kNumberTypeFloat32);
+  auto ret = tensor->MallocData();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "malloc data failed";
+    if (this->context_ != nullptr && this->context_->allocator != nullptr) {
+      this->context_->allocator->Free(float16_data);
+    } else {
+      free(float16_data);
+    }
+    return RET_ERROR;
+  }
+  MS_ASSERT(tensor->data_c() != nullptr);
+  Float16ToFloat32_fp16_handler(float16_data, tensor->data_c(), tensor->ElementsNum());
+  if (tensor->allocator() != nullptr) {
+    tensor->allocator()->Free(float16_data);
+  } else {
+    free(float16_data);
+  }
+  return RET_OK;
+}
+
 int CpuFp16SubGraph::PreProcess() {
 #ifdef ENABLE_ARM64
   if (!mindspore::lite::IsSupportFloat16()) {
-    MS_LOG(ERROR) << "Unsupport fp16 in this devices";
+    MS_LOG(ERROR) << "Unsupported fp16 on this device";
     return RET_ERROR;
   }
+  MS_ASSERT(origin_input_data_.empty());
+  int ret;
   for (auto tensor : this->in_tensors_) {
     MS_ASSERT(tensor != nullptr);
-    if (tensor->data_type() == kNumberTypeFloat32) {
-      auto float32_data = tensor->data_c();
-      MS_ASSERT(float32_data != nullptr);
-      tensor->set_data(nullptr);
-      tensor->set_data_type(TypeId::kNumberTypeFloat16);
-      auto ret = tensor->MallocData();
+    auto real_tensor = tensor;
+    if (tensor->root_tensor() != nullptr) {
+      real_tensor = tensor->root_tensor();
+      if (tensor->data_type() == kNumberTypeFloat32) {
+        tensor->set_data_type(kNumberTypeFloat16);
+      } else if (tensor->data_type() == kObjectTypeTensorType) {
+        auto tensorlist = reinterpret_cast<lite::TensorList *>(tensor);
+        if (tensorlist->tensors_data_type() == kNumberTypeFloat32) {
+          tensorlist->set_tensors_data_type(kNumberTypeFloat16);
+        }
+      }
+    }
+    if (real_tensor->data_type() == kNumberTypeFloat32) {
+      ret = Float32TensorToFloat16Tensor(real_tensor);
       if (RET_OK != ret) {
-        MS_LOG(ERROR) << "malloc data failed";
-        this->FreeOriginInputData();
-        return RET_ERROR;
+        MS_LOG(ERROR) << "Float32TensorToFloat16Tensor failed.";
+        return ret;
       }
-      MS_ASSERT(tensor->data_c() != nullptr);
-      Float32ToFloat16_fp16_handler(float32_data, tensor->data_c(), tensor->ElementsNum());
-      auto *data_store = DataStore::CreateDataStore(float32_data, tensor->allocator(), this->context_->allocator.get());
-      if (data_store == nullptr) {
-        MS_LOG(ERROR) << "Create DataStore failed";
-        this->FreeOriginInputData();
-        return RET_ERROR;
+    } else if (real_tensor->data_type() == kObjectTypeTensorType) {
+      auto tensorlist = reinterpret_cast<lite::TensorList *>(real_tensor);
+      if (tensorlist->tensors_data_type() == kNumberTypeFloat32) {
+        tensorlist->set_tensors_data_type(kNumberTypeFloat16);
+        for (auto inner_tensor : tensorlist->tensors()) {
+          ret = Float32TensorToFloat16Tensor(inner_tensor);
+          if (RET_OK != ret) {
+            MS_LOG(ERROR) << "Float32TensorToFloat16Tensor failed.";
+            return ret;
+          }
+        }
       }
-      origin_input_data_.emplace_back(data_store);
-    } else {
-      origin_input_data_.emplace_back(nullptr);
     }
   }
   for (auto kernel : this->nodes_) {
@@ -239,6 +303,11 @@ int CpuFp16SubGraph::PreProcess() {
       }
       if (tensor->data_type() == kNumberTypeFloat32) {
        tensor->set_data_type(kNumberTypeFloat16);
+      } else if (tensor->data_type() == kObjectTypeTensorType) {
+        auto tensorlist = reinterpret_cast<lite::TensorList *>(tensor);
+        if (tensorlist->tensors_data_type() == kNumberTypeFloat32) {
+          tensorlist->set_tensors_data_type(kNumberTypeFloat16);
+        }
       }
     }
   }
@@ -251,47 +320,72 @@ int CpuFp16SubGraph::PreProcess() {
 int CpuFp16SubGraph::PostProcess() {
 #ifdef ENABLE_ARM64
   if (!mindspore::lite::IsSupportFloat16()) {
-    MS_LOG(ERROR) << "Unsupport fp16 in this devices";
+    MS_LOG(ERROR) << "Unsupported fp16 on this device";
     return RET_ERROR;
   }
+  int ret;
   for (auto tensor : this->out_tensors_) {
     MS_ASSERT(tensor != nullptr);
     if (tensor->data_type() == kNumberTypeFloat16) {
-      auto float16_data = tensor->data_c();
-      MS_ASSERT(float16_data != nullptr);
-      tensor->set_data(nullptr);
-      tensor->set_data_type(TypeId::kNumberTypeFloat32);
-      auto ret = tensor->MallocData();
-      if (ret != RET_OK) {
-        MS_LOG(ERROR) << "malloc data failed";
-        if (this->context_ != nullptr && this->context_->allocator != nullptr) {
-          this->context_->allocator->Free(float16_data);
-        } else {
-          free(float16_data);
-        }
-        return RET_ERROR;
+      ret = Float16TensorToFloat32Tensor(tensor);
+      if (RET_OK != ret) {
+        MS_LOG(ERROR) << "Float16TensorToFloat32Tensor failed.";
+        return ret;
       }
-      MS_ASSERT(tensor->data_c() != nullptr);
-      Float16ToFloat32_fp16_handler(float16_data, tensor->data_c(), tensor->ElementsNum());
-      if (tensor->allocator() != nullptr) {
-        tensor->allocator()->Free(float16_data);
-      } else {
-        free(float16_data);
+    } else if (tensor->data_type() == kObjectTypeTensorType) {
+      auto tensorlist = reinterpret_cast<lite::TensorList *>(tensor);
+      if (tensorlist->tensors_data_type() == kNumberTypeFloat16) {
+        tensorlist->set_tensors_data_type(kNumberTypeFloat32);
+        for (auto inner_tensor : tensorlist->tensors()) {
+          ret = Float16TensorToFloat32Tensor(inner_tensor);
+          if (RET_OK != ret) {
+            MS_LOG(ERROR) << "Float16TensorToFloat32Tensor failed.";
+            return ret;
+          }
+        }
       }
     }
   }
-  MS_ASSERT(this->origin_input_data_.size() == this->in_tensors_.size());
+  int tensor_count = 0;
   for (size_t i = 0; i < this->in_tensors_.size(); i++) {
     auto tensor = in_tensors_.at(i);
     MS_ASSERT(tensor != nullptr);
-    auto origin_tensor_data = origin_input_data_.at(i);
-    if (tensor->data_type() == kNumberTypeFloat16 && origin_tensor_data != nullptr) {
-      MS_ASSERT(tensor != nullptr);
-      tensor->FreeData();
+    auto real_tensor = tensor;
+    if (tensor->root_tensor() != nullptr) {
+      real_tensor = tensor->root_tensor();
+      if (tensor->data_type() == kNumberTypeFloat16) {
+        tensor->set_data_type(kNumberTypeFloat32);
+      } else if (tensor->data_type() == kObjectTypeTensorType) {
+        auto tensorlist = reinterpret_cast<lite::TensorList *>(tensor);
+        if (tensorlist->tensors_data_type() == kNumberTypeFloat16) {
+          tensorlist->set_tensors_data_type(kNumberTypeFloat32);
+        }
+      }
+    }
+    if (real_tensor->data_type() == kNumberTypeFloat16 && origin_input_data_.at(real_tensor) != nullptr) {
+      auto origin_tensor_data = origin_input_data_.at(real_tensor);
+      real_tensor->FreeData();
       MS_ASSERT(origin_tensor_data->data_ != nullptr);
-      tensor->set_data(origin_tensor_data->data_);
-      tensor->set_data_type(kNumberTypeFloat32);
+      real_tensor->set_data(origin_tensor_data->data_);
+      real_tensor->set_data_type(kNumberTypeFloat32);
       origin_tensor_data->data_ = nullptr;
+      tensor_count++;
+    } else if (real_tensor->data_type() == kObjectTypeTensorType) {
+      auto tensorlist = reinterpret_cast<lite::TensorList *>(real_tensor);
+      if (tensorlist->tensors_data_type() == kNumberTypeFloat16) {
+        tensorlist->set_tensors_data_type(kNumberTypeFloat32);
+        for (auto inner_tensor : tensorlist->tensors()) {
+          MS_ASSERT(inner_tensor != nullptr);
+          auto origin_tensor_data = origin_input_data_.at(inner_tensor);
+          inner_tensor->FreeData();
+          MS_ASSERT(origin_tensor_data->data_ != nullptr);
+          inner_tensor->set_data(origin_tensor_data->data_);
+          inner_tensor->set_data_type(kNumberTypeFloat32);
+          origin_tensor_data->data_ = nullptr;
+          tensor_count++;
+        }
+      }
     }
   }
   this->FreeOriginInputData();
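Taken together, these hunks extend the fp16 subgraph round-trip to TensorList inputs and outputs: PreProcess converts fp32 inputs (and every element tensor of an fp32 TensorList) to fp16 and stashes the original fp32 buffers, while PostProcess converts results back and reinstalls the stashed buffers. Because a TensorList's inner tensors are stashed alongside the top-level inputs, positional indexing no longer works, which is why origin_input_data_ becomes a Tensor*-keyed map in the header hunk below. The core stash/restore idea, reduced to a sketch with a hypothetical Buffer type standing in for lite::Tensor and DataStore:

#include <cstdlib>
#include <map>

struct Buffer {
  void *data = nullptr;  // stands in for lite::Tensor's data pointer
};

// tensor -> original fp32 buffer; keyed by pointer, not input position,
// so TensorList element tensors can be tracked too.
std::map<Buffer *, void *> origin_data;

void StashAndSwap(Buffer *t, void *fp16_copy) {
  origin_data[t] = t->data;  // keep the fp32 buffer for PostProcess
  t->data = fp16_copy;       // kernels run on the fp16 copy
}

void Restore(Buffer *t) {
  auto it = origin_data.find(t);
  if (it == origin_data.end()) {
    return;  // never converted (e.g. the input was not fp32)
  }
  free(t->data);         // drop the fp16 copy (the real code uses the allocator)
  t->data = it->second;  // reinstall the original fp32 buffer
  origin_data.erase(it);
}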
@@ -20,6 +20,7 @@
 #include <utility>
 #include <string>
 #include <vector>
+#include <map>
 #include "src/lite_kernel.h"
 #include "src/executor.h"
 #include "src/common/log_adapter.h"
@@ -179,9 +180,11 @@ class CpuFp16SubGraph : public CpuSubGraph {
  private:
   void FreeOriginInputData();
+  int Float32TensorToFloat16Tensor(lite::Tensor *tensor);
+  int Float16TensorToFloat32Tensor(lite::Tensor *tensor);
 
  private:
-  std::vector<DataStore *> origin_input_data_{};
+  std::map<lite::Tensor *, DataStore *> origin_input_data_;
 };
 #endif
 }  // namespace mindspore::kernel
@@ -35,7 +35,7 @@ STATUS TFBatchMatMulParser::Parse(const tensorflow::NodeDef &tf_op,
     MS_LOG(ERROR) << "New PrimitiveT failed";
     return RET_NULL_PTR;
   }
-  auto attr = std::make_unique<schema::BatchMatMulT>();
+  auto attr = std::make_unique<schema::MatMulT>();
   if (attr == nullptr) {
     MS_LOG(ERROR) << "new attr failed";
     return RET_NULL_PTR;
@@ -45,13 +45,13 @@ STATUS TFBatchMatMulParser::Parse(const tensorflow::NodeDef &tf_op,
     MS_LOG(ERROR) << "The begin_mask attr should be specified";
     return RET_ERROR;
   }
-  attr->transpose_a = attr_value.b();
+  attr->transposeA = attr_value.b();
   if (!TensorFlowUtils::FindAttrValue(tf_op, "adj_y", &attr_value)) {
     MS_LOG(ERROR) << "The begin_mask attr should be specified";
     return RET_ERROR;
   }
-  attr->transpose_b = attr_value.b();
-  primitive->value.type = schema::PrimitiveType_BatchMatMul;
+  attr->transposeB = attr_value.b();
+  primitive->value.type = schema::PrimitiveType_MatMul;
   primitive->value.value = attr.release();
   *primitiveC = PrimitiveC::Create(primitive.release());
   if (*primitiveC == nullptr) {
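With the schema entry gone, the parser emits schema::MatMulT directly: TensorFlow's adj_x/adj_y attributes are exactly MatMul's transposeA/transposeB, so the retargeting loses no information. Condensed to its essence (hypothetical struct names, not the converter's real types):

// adj_x/adj_y ask for A/B to be transposed before the product, which is
// precisely what MatMul's transpose flags request.
struct TfBatchMatMulAttrs {
  bool adj_x;
  bool adj_y;
};

struct MatMulAttrs {
  bool transposeA;
  bool transposeB;
};

MatMulAttrs ToMatMul(const TfBatchMatMulAttrs &tf) {
  return MatMulAttrs{tf.adj_x, tf.adj_y};  // C[b] = op(A[b]) * op(B[b])
}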