From 9fb9de9f5ffcdfff5733d43649638527acee790f Mon Sep 17 00:00:00 2001 From: yefeng Date: Tue, 1 Dec 2020 10:01:15 +0800 Subject: [PATCH] 1103_assert_7 --- mindspore/lite/src/ops/assert_op.cc | 72 +++++++++++++++++ mindspore/lite/src/ops/assert_op.h | 43 +++++++++++ .../lite/src/ops/populate/assert_populate.cc | 37 +++++++++ .../src/runtime/kernel/arm/base/assert.cc | 77 +++++++++++++++++++ .../lite/src/runtime/kernel/arm/base/assert.h | 47 +++++++++++ .../converter/parser/tf/tf_assert_parser.cc | 68 ++++++++++++++++ .../converter/parser/tf/tf_assert_parser.h | 37 +++++++++ 7 files changed, 381 insertions(+) create mode 100644 mindspore/lite/src/ops/assert_op.cc create mode 100644 mindspore/lite/src/ops/assert_op.h create mode 100644 mindspore/lite/src/ops/populate/assert_populate.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/base/assert.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/base/assert.h create mode 100644 mindspore/lite/tools/converter/parser/tf/tf_assert_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/tf/tf_assert_parser.h diff --git a/mindspore/lite/src/ops/assert_op.cc b/mindspore/lite/src/ops/assert_op.cc new file mode 100644 index 0000000000..fce3cd8b43 --- /dev/null +++ b/mindspore/lite/src/ops/assert_op.cc @@ -0,0 +1,72 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/assert_op.h" +#ifndef PRIMITIVE_WRITEABLE +#include "src/ops/ops_register.h" +#endif + +namespace mindspore { +namespace lite { +#ifdef PRIMITIVE_WRITEABLE + +int AssertOP::UnPackAttr(const Primitive &prim, const std::vector &inputs) { + if (this->primitive_ == nullptr) { + this->primitive_ = new (std::nothrow) schema::PrimitiveT; + if (this->primitive_ == nullptr) { + MS_LOG(ERROR) << "new primitiveT failed"; + return RET_ERROR; + } + this->primitive_->value.type = schema::PrimitiveType_Assert; + } + if (this->primitive_->value.type != schema::PrimitiveType_Assert) { + MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; + return RET_ERROR; + } + if (this->primitive_->value.value == nullptr) { + this->primitive_->value.value = new (std::nothrow) schema::AssertT(); + if (this->primitive_->value.value == nullptr) { + MS_LOG(ERROR) << "new primitiveT value failed"; + return RET_ERROR; + } + } + PopulaterQuantParam(prim, inputs); + return RET_OK; +} + +#else +int AssertOP::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_Assert(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_Assert return nullptr"; + return RET_ERROR; + } + auto val_offset = schema::CreateAssert(*fbb); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Assert, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} + +PrimitiveC *AssertCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC(primitive); } +Registry AssertRegistry(schema::PrimitiveType_Assert, AssertCreator); +#endif + +int AssertOP::InferShape(std::vector inputs_, std::vector outputs_) { return RET_OK; } + +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/ops/assert_op.h b/mindspore/lite/src/ops/assert_op.h new file mode 100644 index 0000000000..ba0399e07d --- /dev/null +++ b/mindspore/lite/src/ops/assert_op.h @@ -0,0 +1,43 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LITE_MINDSPORE_LITE_SRC_OPS_ASSERT_OP_H_ +#define LITE_MINDSPORE_LITE_SRC_OPS_ASSERT_OP_H_ + +#include +#include +#include +#include "src/ops/primitive_c.h" + +namespace mindspore { +namespace lite { +class AssertOP : public PrimitiveC { + public: + AssertOP() = default; + ~AssertOP() = default; +#ifdef PRIMITIVE_WRITEABLE + MS_DECLARE_PARENT(AssertOP, PrimitiveC); + explicit AssertOP(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} + int UnPackAttr(const Primitive &prim, const std::vector &inputs) override; +#else + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; +#endif + int InferShape(std::vector inputs_, std::vector outputs_) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // LITE_MINDSPORE_LITE_SRC_OPS_ASSERT_OP_H_ diff --git a/mindspore/lite/src/ops/populate/assert_populate.cc b/mindspore/lite/src/ops/populate/assert_populate.cc new file mode 100644 index 0000000000..02db20243d --- /dev/null +++ b/mindspore/lite/src/ops/populate/assert_populate.cc @@ -0,0 +1,37 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/assert_op.h" +#include "src/ops/primitive_c.h" +#include "src/ops/populate/populate_register.h" + +namespace mindspore { +namespace lite { + +OpParameter *PopulateAssertParameter(const mindspore::lite::PrimitiveC *primitive) { + OpParameter *assert_parameter = reinterpret_cast(malloc(sizeof(OpParameter))); + if (assert_parameter == nullptr) { + MS_LOG(ERROR) << "malloc AssertParameter failed."; + return nullptr; + } + memset(assert_parameter, 0, sizeof(OpParameter)); + assert_parameter->type_ = primitive->Type(); + + return reinterpret_cast(assert_parameter); +} +Registry AssertParameterRegistry(schema::PrimitiveType_Assert, PopulateAssertParameter); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/runtime/kernel/arm/base/assert.cc b/mindspore/lite/src/runtime/kernel/arm/base/assert.cc new file mode 100644 index 0000000000..49250b4d46 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/base/assert.cc @@ -0,0 +1,77 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/base/assert.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" + +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Assert; + +namespace mindspore::kernel { + +int AssertCPUKernel::Init() { return RET_OK; } + +int AssertCPUKernel::ReSize() { return RET_OK; } + +int AssertCPUKernel::Run() { + auto cond = reinterpret_cast(in_tensors_.front()->data_c()); + if (*cond) { + return RET_OK; + } else { + for (size_t i = 1; i < in_tensors_.size(); i++) { + MS_LOG(ERROR) << in_tensors_.at(i)->ToString(); + } + return RET_ERROR; + } +} + +kernel::LiteKernel *CpuAssertKernelCreator(const std::vector &inputs, + const std::vector &outputs, OpParameter *parameter, + const lite::InnerContext *ctx, const KernelKey &desc, + const mindspore::lite::PrimitiveC *primitive) { + if (parameter == nullptr) { + MS_LOG(ERROR) << "parameter is nullptr"; + return nullptr; + } + if (ctx == nullptr) { + MS_LOG(ERROR) << "ctx is nullptr"; + free(parameter); + return nullptr; + } + MS_ASSERT(desc.type == PrimitiveType_Assert); + auto *kernel = new (std::nothrow) AssertCPUKernel(parameter, inputs, outputs, ctx, primitive); + if (kernel == nullptr) { + MS_LOG(ERROR) << "Create kernel failed, name: " << parameter->name_; + free(parameter); + return nullptr; + } + + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init kernel failed, name: " << parameter->name_ + << ", type: " << schema::EnumNamePrimitiveType(static_cast(parameter->type_)); + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Assert, CpuAssertKernelCreator) +REG_KERNEL(kCPU, kNumberTypeBool, PrimitiveType_Assert, CpuAssertKernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/base/assert.h b/mindspore/lite/src/runtime/kernel/arm/base/assert.h new file mode 100644 index 0000000000..5e7e109d02 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/base/assert.h @@ -0,0 +1,47 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_ASSERT_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_ASSERT_H_ + +#include +#include "src/lite_kernel.h" + +namespace mindspore::kernel { + +typedef struct AssertParameter { + OpParameter op_parameter_; +} AssertParameter; + +class AssertCPUKernel : public LiteKernel { + public: + AssertCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::InnerContext *ctx, + const mindspore::lite::PrimitiveC *primitive) + : LiteKernel(parameter, inputs, outputs, ctx, primitive) { + assert_param_ = reinterpret_cast(op_parameter_); + } + ~AssertCPUKernel() override {} + + int Init() override; + int ReSize() override; + int Run() override; + + private: + AssertParameter *assert_param_ = nullptr; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_ASSERT_H_ diff --git a/mindspore/lite/tools/converter/parser/tf/tf_assert_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_assert_parser.cc new file mode 100644 index 0000000000..f9da640ff0 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_assert_parser.cc @@ -0,0 +1,68 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tools/converter/parser/tf/tf_assert_parser.h" +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser_registry.h" + +namespace mindspore { +namespace lite { +STATUS TFAssertParser::Parse(const tensorflow::NodeDef &tf_op, + const std::map &tf_node_map, PrimitiveC **primitiveC, + std::vector *inputs, int *output_size) { + MS_LOG(INFO) << "TF AssertParser"; + if (primitiveC == nullptr || output_size == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_NULL_PTR; + } + + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "New PrimitiveT failed"; + return RET_NULL_PTR; + } + auto attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new attr failed"; + return RET_NULL_PTR; + } + tensorflow::AttrValue attr_value; + if (!TensorFlowUtils::FindAttrValue(tf_op, "summarize", &attr_value)) { + MS_LOG(ERROR) << "The keep_dims attr should be specified"; + return RET_ERROR; + } + attr->summarize = attr_value.i(); + + primitive->value.type = schema::PrimitiveType_Assert; + primitive->value.value = attr.release(); + *primitiveC = PrimitiveC::Create(primitive.release()); + if (*primitiveC == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_ERROR; + } + + *output_size = 1; + auto status = AddOpInput(tf_op, 0, inputs); + if (status != RET_OK) { + return status; + } + return status; +} +TFNodeRegistrar g_tfAssertParser("Assert", new TFAssertParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_assert_parser.h b/mindspore/lite/tools/converter/parser/tf/tf_assert_parser.h new file mode 100644 index 0000000000..cf00d1f997 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_assert_parser.h @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ASSERT_PARSER_H_ +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ASSERT_PARSER_H_ + +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser.h" + +namespace mindspore { +namespace lite { +class TFAssertParser : public TFNodeParser { + public: + TFAssertParser() = default; + ~TFAssertParser() override = default; + + STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map &tf_node_map, + PrimitiveC **primitiveC, std::vector *inputs, int *output_size) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ASSERT_PARSER_H_