| @@ -0,0 +1,72 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "src/ops/assert_op.h" | |||||
| #ifndef PRIMITIVE_WRITEABLE | |||||
| #include "src/ops/ops_register.h" | |||||
| #endif | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| #ifdef PRIMITIVE_WRITEABLE | |||||
| int AssertOP::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | |||||
| if (this->primitive_ == nullptr) { | |||||
| this->primitive_ = new (std::nothrow) schema::PrimitiveT; | |||||
| if (this->primitive_ == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitiveT failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| this->primitive_->value.type = schema::PrimitiveType_Assert; | |||||
| } | |||||
| if (this->primitive_->value.type != schema::PrimitiveType_Assert) { | |||||
| MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (this->primitive_->value.value == nullptr) { | |||||
| this->primitive_->value.value = new (std::nothrow) schema::AssertT(); | |||||
| if (this->primitive_->value.value == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitiveT value failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } | |||||
| PopulaterQuantParam(prim, inputs); | |||||
| return RET_OK; | |||||
| } | |||||
| #else | |||||
| int AssertOP::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | |||||
| MS_ASSERT(nullptr != primitive); | |||||
| MS_ASSERT(nullptr != fbb); | |||||
| auto attr = primitive->value_as_Assert(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "value_as_Assert return nullptr"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto val_offset = schema::CreateAssert(*fbb); | |||||
| auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Assert, val_offset.o); | |||||
| fbb->Finish(prim_offset); | |||||
| return RET_OK; | |||||
| } | |||||
| PrimitiveC *AssertCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<AssertOP>(primitive); } | |||||
| Registry AssertRegistry(schema::PrimitiveType_Assert, AssertCreator); | |||||
| #endif | |||||
| int AssertOP::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { return RET_OK; } | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,43 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef LITE_MINDSPORE_LITE_SRC_OPS_ASSERT_OP_H_ | |||||
| #define LITE_MINDSPORE_LITE_SRC_OPS_ASSERT_OP_H_ | |||||
| #include <vector> | |||||
| #include <set> | |||||
| #include <cmath> | |||||
| #include "src/ops/primitive_c.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| class AssertOP : public PrimitiveC { | |||||
| public: | |||||
| AssertOP() = default; | |||||
| ~AssertOP() = default; | |||||
| #ifdef PRIMITIVE_WRITEABLE | |||||
| MS_DECLARE_PARENT(AssertOP, PrimitiveC); | |||||
| explicit AssertOP(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | |||||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | |||||
| #else | |||||
| int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | |||||
| #endif | |||||
| int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override; | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // LITE_MINDSPORE_LITE_SRC_OPS_ASSERT_OP_H_ | |||||
| @@ -0,0 +1,37 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "src/ops/assert_op.h" | |||||
| #include "src/ops/primitive_c.h" | |||||
| #include "src/ops/populate/populate_register.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| OpParameter *PopulateAssertParameter(const mindspore::lite::PrimitiveC *primitive) { | |||||
| OpParameter *assert_parameter = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter))); | |||||
| if (assert_parameter == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc AssertParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(assert_parameter, 0, sizeof(OpParameter)); | |||||
| assert_parameter->type_ = primitive->Type(); | |||||
| return reinterpret_cast<OpParameter *>(assert_parameter); | |||||
| } | |||||
| Registry AssertParameterRegistry(schema::PrimitiveType_Assert, PopulateAssertParameter); | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,77 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "src/runtime/kernel/arm/base/assert.h" | |||||
| #include "src/kernel_registry.h" | |||||
| #include "include/errorcode.h" | |||||
| using mindspore::lite::KernelRegistrar; | |||||
| using mindspore::lite::RET_ERROR; | |||||
| using mindspore::lite::RET_OK; | |||||
| using mindspore::schema::PrimitiveType_Assert; | |||||
| namespace mindspore::kernel { | |||||
| int AssertCPUKernel::Init() { return RET_OK; } | |||||
| int AssertCPUKernel::ReSize() { return RET_OK; } | |||||
| int AssertCPUKernel::Run() { | |||||
| auto cond = reinterpret_cast<bool *>(in_tensors_.front()->data_c()); | |||||
| if (*cond) { | |||||
| return RET_OK; | |||||
| } else { | |||||
| for (size_t i = 1; i < in_tensors_.size(); i++) { | |||||
| MS_LOG(ERROR) << in_tensors_.at(i)->ToString(); | |||||
| } | |||||
| return RET_ERROR; | |||||
| } | |||||
| } | |||||
| kernel::LiteKernel *CpuAssertKernelCreator(const std::vector<lite::Tensor *> &inputs, | |||||
| const std::vector<lite::Tensor *> &outputs, OpParameter *parameter, | |||||
| const lite::InnerContext *ctx, const KernelKey &desc, | |||||
| const mindspore::lite::PrimitiveC *primitive) { | |||||
| if (parameter == nullptr) { | |||||
| MS_LOG(ERROR) << "parameter is nullptr"; | |||||
| return nullptr; | |||||
| } | |||||
| if (ctx == nullptr) { | |||||
| MS_LOG(ERROR) << "ctx is nullptr"; | |||||
| free(parameter); | |||||
| return nullptr; | |||||
| } | |||||
| MS_ASSERT(desc.type == PrimitiveType_Assert); | |||||
| auto *kernel = new (std::nothrow) AssertCPUKernel(parameter, inputs, outputs, ctx, primitive); | |||||
| if (kernel == nullptr) { | |||||
| MS_LOG(ERROR) << "Create kernel failed, name: " << parameter->name_; | |||||
| free(parameter); | |||||
| return nullptr; | |||||
| } | |||||
| auto ret = kernel->Init(); | |||||
| if (ret != RET_OK) { | |||||
| MS_LOG(ERROR) << "Init kernel failed, name: " << parameter->name_ | |||||
| << ", type: " << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(parameter->type_)); | |||||
| delete kernel; | |||||
| return nullptr; | |||||
| } | |||||
| return kernel; | |||||
| } | |||||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Assert, CpuAssertKernelCreator) | |||||
| REG_KERNEL(kCPU, kNumberTypeBool, PrimitiveType_Assert, CpuAssertKernelCreator) | |||||
| } // namespace mindspore::kernel | |||||
| @@ -0,0 +1,47 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_ASSERT_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_ASSERT_H_ | |||||
| #include <vector> | |||||
| #include "src/lite_kernel.h" | |||||
| namespace mindspore::kernel { | |||||
| typedef struct AssertParameter { | |||||
| OpParameter op_parameter_; | |||||
| } AssertParameter; | |||||
| class AssertCPUKernel : public LiteKernel { | |||||
| public: | |||||
| AssertCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||||
| const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, | |||||
| const mindspore::lite::PrimitiveC *primitive) | |||||
| : LiteKernel(parameter, inputs, outputs, ctx, primitive) { | |||||
| assert_param_ = reinterpret_cast<AssertParameter *>(op_parameter_); | |||||
| } | |||||
| ~AssertCPUKernel() override {} | |||||
| int Init() override; | |||||
| int ReSize() override; | |||||
| int Run() override; | |||||
| private: | |||||
| AssertParameter *assert_param_ = nullptr; | |||||
| }; | |||||
| } // namespace mindspore::kernel | |||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_ASSERT_H_ | |||||
| @@ -0,0 +1,68 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "tools/converter/parser/tf/tf_assert_parser.h" | |||||
| #include <string> | |||||
| #include <memory> | |||||
| #include <map> | |||||
| #include <vector> | |||||
| #include "tools/converter/parser/tf/tf_node_parser_registry.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| STATUS TFAssertParser::Parse(const tensorflow::NodeDef &tf_op, | |||||
| const std::map<string, const tensorflow::NodeDef *> &tf_node_map, PrimitiveC **primitiveC, | |||||
| std::vector<std::string> *inputs, int *output_size) { | |||||
| MS_LOG(INFO) << "TF AssertParser"; | |||||
| if (primitiveC == nullptr || output_size == nullptr) { | |||||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "New PrimitiveT failed"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| auto attr = std::make_unique<schema::AssertT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new attr failed"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| tensorflow::AttrValue attr_value; | |||||
| if (!TensorFlowUtils::FindAttrValue(tf_op, "summarize", &attr_value)) { | |||||
| MS_LOG(ERROR) << "The keep_dims attr should be specified"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| attr->summarize = attr_value.i(); | |||||
| primitive->value.type = schema::PrimitiveType_Assert; | |||||
| primitive->value.value = attr.release(); | |||||
| *primitiveC = PrimitiveC::Create(primitive.release()); | |||||
| if (*primitiveC == nullptr) { | |||||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| *output_size = 1; | |||||
| auto status = AddOpInput(tf_op, 0, inputs); | |||||
| if (status != RET_OK) { | |||||
| return status; | |||||
| } | |||||
| return status; | |||||
| } | |||||
| TFNodeRegistrar g_tfAssertParser("Assert", new TFAssertParser()); | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,37 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ASSERT_PARSER_H_ | |||||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ASSERT_PARSER_H_ | |||||
| #include <string> | |||||
| #include <memory> | |||||
| #include <map> | |||||
| #include <vector> | |||||
| #include "tools/converter/parser/tf/tf_node_parser.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| class TFAssertParser : public TFNodeParser { | |||||
| public: | |||||
| TFAssertParser() = default; | |||||
| ~TFAssertParser() override = default; | |||||
| STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||||
| PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override; | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ASSERT_PARSER_H_ | |||||