| @@ -227,6 +227,12 @@ union PrimitiveType { | |||
| Identity, | |||
| LayerNorm, | |||
| While, | |||
| ControlDepend, | |||
| UnsortedSegmentSum, | |||
| AssignAdd, | |||
| OnesLike, | |||
| BinaryCrossEntropyGrad, | |||
| BinaryCrossEntropy | |||
| } | |||
| enum QuantType: int { | |||
| @@ -966,6 +966,8 @@ table Adam { | |||
| table Assign { | |||
| } | |||
| table AssignAdd { | |||
| } | |||
| table Where{ | |||
| condition: [bool]; | |||
| @@ -1010,6 +1012,9 @@ table ToFormat { | |||
| table Depend { | |||
| } | |||
| table ControlDepend { | |||
| } | |||
| table Return { | |||
| } | |||
| @@ -1108,3 +1113,18 @@ table While { | |||
| bodySubgraphIndex : int; | |||
| } | |||
| table UnsortedSegmentSum { | |||
| numSegments : int; | |||
| } | |||
| table OnesLike { | |||
| } | |||
| table BinaryCrossEntropy { | |||
| reduction : int = 1; | |||
| } | |||
| table BinaryCrossEntropyGrad { | |||
| reduction : int = 1; | |||
| } | |||
| @@ -73,7 +73,7 @@ Registry AdamRegistry(schema::PrimitiveType_Adam, AdamCreator); | |||
| int Adam::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs) { | |||
| if (10 != inputs.size()) { | |||
| MS_LOG(ERROR) << "Adam should have at least 8 input tensors"; | |||
| MS_LOG(ERROR) << "Adam should have at 10 input tensors"; | |||
| return RET_ERROR; | |||
| } | |||
| @@ -0,0 +1,81 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/ops/assign_add.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| int AssignAdd::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | |||
| if (this->primitive_ == nullptr) { | |||
| this->primitive_ = new (std::nothrow) schema::PrimitiveT; | |||
| if (this->primitive_ == nullptr) { | |||
| MS_LOG(ERROR) << "new primitive error"; | |||
| return RET_ERROR; | |||
| } | |||
| this->primitive_->value.type = schema::PrimitiveType_AssignAdd; | |||
| } | |||
| if (this->primitive_->value.type != schema::PrimitiveType_AssignAdd) { | |||
| MS_LOG(ERROR) << "PrimitiveType_AssignAdd primitive value type : " | |||
| << schema::EnumNamePrimitiveType(primitive_->value.type) << "is not equal" | |||
| << schema::EnumNamePrimitiveType(schema::PrimitiveType_AssignAdd); | |||
| delete this->primitive_; | |||
| return RET_ERROR; | |||
| } | |||
| if (this->primitive_->value.value == nullptr) { | |||
| this->primitive_->value.value = new (std::nothrow) schema::AssignAddT(); | |||
| if (this->primitive_->value.value == nullptr) { | |||
| MS_LOG(ERROR) << "new primitiveT value failed"; | |||
| delete this->primitive_; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| #else | |||
| int AssignAdd::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | |||
| MS_ASSERT(nullptr != primitive); | |||
| MS_ASSERT(nullptr != fbb); | |||
| auto attr = primitive->value_as_AssignAdd(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "value_as_AssignAdd return nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| auto val_offset = schema::CreateAssignAdd(*fbb); | |||
| auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_AssignAdd, val_offset.o); | |||
| fbb->Finish(prim_offset); | |||
| return RET_OK; | |||
| } | |||
| #endif | |||
| int AssignAdd::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | |||
| Tensor *x = inputs_[0]; | |||
| Tensor *y = inputs_[1]; | |||
| Tensor *out = outputs_[0]; | |||
| std::vector<int> x_shape = x->shape(); | |||
| if (x->data_type() != y->data_type()) { | |||
| MS_LOG(ERROR) << "no matched shape of x and y"; | |||
| return RET_ERROR; | |||
| } | |||
| std::vector<int> output_shape(x_shape.size()); | |||
| for (int i = 0; i < static_cast<int>(x_shape.size()); i++) { | |||
| output_shape[i] = x_shape[i]; | |||
| } | |||
| out->set_shape(output_shape); | |||
| out->SetFormat(x->GetFormat()); | |||
| out->set_data_type(x->data_type()); | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,41 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <vector> | |||
| #include <set> | |||
| #include <cmath> | |||
| #include "src/ops/primitive_c.h" | |||
| #ifndef LITE_SRC_OPS_ASSIGN_ADD_H_ | |||
| #define LITE_SRC_OPS_ASSIGN_ADD_H_ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class AssignAdd : public PrimitiveC { | |||
| public: | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| MS_DECLARE_PARENT(AssignAdd, PrimitiveC); | |||
| AssignAdd() = default; | |||
| explicit AssignAdd(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | |||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | |||
| #else | |||
| AssignAdd() = default; | |||
| int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | |||
| #endif | |||
| int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // LITE_SRC_OPS_ASSIGN_ADD_H_ | |||
| @@ -47,7 +47,12 @@ int BiasAdd::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &in | |||
| MS_LOG(ERROR) << "new primitiveT value failed"; | |||
| return RET_ERROR; | |||
| } | |||
| attr->axis = {0}; | |||
| if (prim.GetAttr("axis") == nullptr) { | |||
| MS_LOG(WARNING) << "get axis failed"; | |||
| attr->axis = {1}; | |||
| } else { | |||
| attr->axis = GetValue<std::vector<int>>(prim.GetAttr("axis")); | |||
| } | |||
| this->primitive_->value.value = attr; | |||
| if (this->primitive_->value.value == nullptr) { | |||
| MS_LOG(ERROR) << "primitive value is nullptr"; | |||
| @@ -0,0 +1,106 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <string> | |||
| #include "src/ops/binary_cross_entropy.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| int BinaryCrossEntropy::GetReduction() const { return this->primitive_->value.AsBinaryCrossEntropy()->reduction; } | |||
| int BinaryCrossEntropy::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | |||
| if (this->primitive_ == nullptr) { | |||
| this->primitive_ = new (std::nothrow) schema::PrimitiveT; | |||
| if (this->primitive_ == nullptr) { | |||
| MS_LOG(ERROR) << "new primitive error"; | |||
| return RET_ERROR; | |||
| } | |||
| this->primitive_->value.type = schema::PrimitiveType_BinaryCrossEntropy; | |||
| } | |||
| if (this->primitive_->value.type != schema::PrimitiveType_BinaryCrossEntropy) { | |||
| MS_LOG(ERROR) << "PrimitiveType_BinaryCrossEntropy primitive value type : " | |||
| << schema::EnumNamePrimitiveType(primitive_->value.type) << "is not equal" | |||
| << schema::EnumNamePrimitiveType(schema::PrimitiveType_BinaryCrossEntropy); | |||
| delete this->primitive_; | |||
| return RET_ERROR; | |||
| } | |||
| if (this->primitive_->value.value == nullptr) { | |||
| schema::BinaryCrossEntropyT *attr = new (std::nothrow) schema::BinaryCrossEntropyT(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new binary cross entropy attr failed!"; | |||
| delete this->primitive_; | |||
| return RET_ERROR; | |||
| } | |||
| // default is mean | |||
| string reduction = "mean"; | |||
| if (prim.GetAttr("reduction") == nullptr) { | |||
| MS_LOG(ERROR) << "get reduction failed!"; | |||
| delete this->primitive_; | |||
| delete attr; | |||
| return RET_ERROR; | |||
| } else { | |||
| reduction = GetValue<string>(prim.GetAttr("reduction")); | |||
| } | |||
| if (reduction == "none") { | |||
| attr->reduction = 0; | |||
| } else if (reduction == "sum") { | |||
| attr->reduction = 2; | |||
| } else { | |||
| // default is mean | |||
| attr->reduction = 1; | |||
| } | |||
| this->primitive_->value.value = attr; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| #else | |||
| int BinaryCrossEntropy::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | |||
| MS_ASSERT(nullptr != primitive); | |||
| MS_ASSERT(nullptr != fbb); | |||
| auto attr = primitive->value_as_BinaryCrossEntropy(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "value_as_BinaryCrossEntropy return nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| int reduction = attr->reduction(); | |||
| auto val_offset = schema::CreateBinaryCrossEntropy(*fbb, reduction); | |||
| auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_BinaryCrossEntropy, val_offset.o); | |||
| fbb->Finish(prim_offset); | |||
| return RET_OK; | |||
| } | |||
| int BinaryCrossEntropy::GetReduction() const { return this->primitive_->value_as_BinaryCrossEntropy()->reduction(); } | |||
| #endif | |||
| int BinaryCrossEntropy::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | |||
| Tensor *x = inputs_[0]; | |||
| Tensor *out = outputs_[0]; | |||
| out->SetFormat(x->GetFormat()); | |||
| out->set_data_type(x->data_type()); | |||
| int reduction = GetReduction(); | |||
| if (reduction == 1 || reduction == 2) { | |||
| out->set_shape({1}); | |||
| } else { | |||
| std::vector<int> x_shape = x->shape(); | |||
| std::vector<int> output_shape(x_shape.size()); | |||
| output_shape.assign(x_shape.begin(), x_shape.end()); | |||
| out->set_shape(output_shape); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,49 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <vector> | |||
| #include <set> | |||
| #include <cmath> | |||
| #include "src/ops/primitive_c.h" | |||
| #ifndef LITE_SRC_OPS_BINARYCROSSENTROPY_H_ | |||
| #define LITE_SRC_OPS_BINARYCROSSENTROPY_H_ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class BinaryCrossEntropy : public PrimitiveC { | |||
| public: | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| MS_DECLARE_PARENT(BinaryCrossEntropy, PrimitiveC); | |||
| BinaryCrossEntropy() = default; | |||
| explicit BinaryCrossEntropy(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | |||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | |||
| int GetReduction() const; | |||
| #else | |||
| BinaryCrossEntropy() = default; | |||
| int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | |||
| int GetReduction() const; | |||
| #endif | |||
| int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // LITE_SRC_OPS_BINARYCROSSENTROPY_H_ | |||
| @@ -0,0 +1,108 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <string> | |||
| #include "src/ops/binary_cross_entropy_grad.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| int BinaryCrossEntropyGrad::GetReduction() const { | |||
| return this->primitive_->value.AsBinaryCrossEntropyGrad()->reduction; | |||
| } | |||
| int BinaryCrossEntropyGrad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | |||
| if (this->primitive_ == nullptr) { | |||
| this->primitive_ = new (std::nothrow) schema::PrimitiveT; | |||
| if (this->primitive_ == nullptr) { | |||
| MS_LOG(ERROR) << "new primitive error"; | |||
| return RET_ERROR; | |||
| } | |||
| this->primitive_->value.type = schema::PrimitiveType_BinaryCrossEntropyGrad; | |||
| } | |||
| if (this->primitive_->value.type != schema::PrimitiveType_BinaryCrossEntropyGrad) { | |||
| MS_LOG(ERROR) << "PrimitiveType_BinaryCrossEntropyGrad primitive value type : " | |||
| << schema::EnumNamePrimitiveType(primitive_->value.type) << "is not equal" | |||
| << schema::EnumNamePrimitiveType(schema::PrimitiveType_BinaryCrossEntropyGrad); | |||
| delete this->primitive_; | |||
| return RET_ERROR; | |||
| } | |||
| if (this->primitive_->value.value == nullptr) { | |||
| schema::BinaryCrossEntropyGradT *attr = new (std::nothrow) schema::BinaryCrossEntropyGradT(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new binary cross entropy attr failed!"; | |||
| delete this->primitive_; | |||
| return RET_ERROR; | |||
| } | |||
| // default is mean | |||
| string reduction = "mean"; | |||
| if (prim.GetAttr("reduction") == nullptr) { | |||
| MS_LOG(ERROR) << "get reduction failed!"; | |||
| delete this->primitive_; | |||
| delete attr; | |||
| return RET_ERROR; | |||
| } else { | |||
| reduction = GetValue<string>(prim.GetAttr("reduction")); | |||
| } | |||
| if (reduction == "none") { | |||
| attr->reduction = 0; | |||
| } else if (reduction == "sum") { | |||
| attr->reduction = 2; | |||
| } else { | |||
| // default is mean | |||
| attr->reduction = 1; | |||
| } | |||
| this->primitive_->value.value = attr; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| #else | |||
| int BinaryCrossEntropyGrad::UnPackToFlatBuilder(const schema::Primitive *primitive, | |||
| flatbuffers::FlatBufferBuilder *fbb) { | |||
| MS_ASSERT(nullptr != primitive); | |||
| MS_ASSERT(nullptr != fbb); | |||
| auto attr = primitive->value_as_BinaryCrossEntropyGrad(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "value_as_BinaryCrossEntropyGrad return nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| int reduction = attr->reduction(); | |||
| auto val_offset = schema::CreateBinaryCrossEntropyGrad(*fbb, reduction); | |||
| auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_BinaryCrossEntropyGrad, val_offset.o); | |||
| fbb->Finish(prim_offset); | |||
| return RET_OK; | |||
| } | |||
| int BinaryCrossEntropyGrad::GetReduction() const { | |||
| return this->primitive_->value_as_BinaryCrossEntropyGrad()->reduction(); | |||
| } | |||
| #endif | |||
| int BinaryCrossEntropyGrad::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | |||
| Tensor *x = inputs_[0]; | |||
| Tensor *out = outputs_[0]; | |||
| out->SetFormat(x->GetFormat()); | |||
| out->set_data_type(x->data_type()); | |||
| std::vector<int> x_shape = x->shape(); | |||
| std::vector<int> output_shape(x_shape.size()); | |||
| output_shape.assign(x_shape.begin(), x_shape.end()); | |||
| out->set_shape(output_shape); | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,49 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <vector> | |||
| #include <set> | |||
| #include <cmath> | |||
| #include "src/ops/primitive_c.h" | |||
| #ifndef LITE_SRC_OPS_BINARY_CROSS_ENTROPY_GRAD_H_ | |||
| #define LITE_SRC_OPS_BINARY_CROSS_ENTROPY_GRAD_H_ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class BinaryCrossEntropyGrad : public PrimitiveC { | |||
| public: | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| MS_DECLARE_PARENT(BinaryCrossEntropyGrad, PrimitiveC); | |||
| BinaryCrossEntropyGrad() = default; | |||
| explicit BinaryCrossEntropyGrad(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | |||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | |||
| int GetReduction() const; | |||
| #else | |||
| BinaryCrossEntropyGrad() = default; | |||
| int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | |||
| int GetReduction() const; | |||
| #endif | |||
| int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // LITE_SRC_OPS_BINARY_CROSS_ENTROPY_GRAD_H_ | |||
| @@ -0,0 +1,59 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/ops/control_depend.h" | |||
| #include <vector> | |||
| #include <memory> | |||
| namespace mindspore { | |||
| namespace lite { | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| int ControlDepend::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | |||
| if (this->primitive_ == nullptr) { | |||
| this->primitive_ = new (std::nothrow) schema::PrimitiveT; | |||
| if (this->primitive_ == nullptr) { | |||
| MS_LOG(ERROR) << "new primitiveT failed"; | |||
| return RET_ERROR; | |||
| } | |||
| this->primitive_->value.type = schema::PrimitiveType_ControlDepend; | |||
| } | |||
| if (this->primitive_->value.type != schema::PrimitiveType_ControlDepend) { | |||
| MS_LOG(ERROR) << "primitive_ type is error:" << this->primitive_->value.type; | |||
| delete this->primitive_; | |||
| return RET_ERROR; | |||
| } | |||
| if (this->primitive_->value.value == nullptr) { | |||
| auto attr = new (std::nothrow)(schema::ControlDependT); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "attr is nullptr"; | |||
| delete this->primitive_; | |||
| return RET_ERROR; | |||
| } | |||
| this->primitive_->value.value = attr; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| #else | |||
| int ControlDepend::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | |||
| MS_ASSERT(nullptr != primitive); | |||
| MS_ASSERT(nullptr != fbb); | |||
| auto val_offset = schema::CreateControlDepend(*fbb); | |||
| auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_ControlDepend, val_offset.o); | |||
| fbb->Finish(prim_offset); | |||
| return RET_OK; | |||
| } | |||
| #endif | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,40 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef LITE_MINDSPORE_LITE_SRC_OPS_CONTROL_DEPEND_H_ | |||
| #define LITE_MINDSPORE_LITE_SRC_OPS_CONTROL_DEPEND_H_ | |||
| #include <vector> | |||
| #include "src/ops/primitive_c.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class ControlDepend : public PrimitiveC { | |||
| public: | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| MS_DECLARE_PARENT(ControlDepend, PrimitiveC); | |||
| ControlDepend() = default; | |||
| explicit ControlDepend(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | |||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | |||
| #else | |||
| ControlDepend() = default; | |||
| int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | |||
| #endif | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // LITE_MINDSPORE_LITE_SRC_OPS_CONTROL_DEPEND_H_ | |||
| @@ -27,6 +27,43 @@ int ExpandDims::GetDim() const { return this->primitive_->value.AsExpandDims()-> | |||
| void ExpandDims::SetDim(int dim) { this->primitive_->value.AsExpandDims()->dim = dim; } | |||
| int ExpandDims::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | |||
| if (this->primitive_ == nullptr) { | |||
| this->primitive_ = new (std::nothrow) schema::PrimitiveT; | |||
| if (this->primitive_ == nullptr) { | |||
| MS_LOG(ERROR) << "new primitiveT failed"; | |||
| return RET_ERROR; | |||
| } | |||
| this->primitive_->value.type = schema::PrimitiveType_ExpandDims; | |||
| } | |||
| if (this->primitive_->value.type != schema::PrimitiveType_ExpandDims) { | |||
| MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; | |||
| delete this->primitive_; | |||
| return RET_ERROR; | |||
| } | |||
| if (this->primitive_->value.value == nullptr) { | |||
| auto attr = new (std::nothrow) schema::ExpandDimsT(); | |||
| if (attr == nullptr) { | |||
| delete this->primitive_; | |||
| MS_LOG(ERROR) << "new primitiveT value failed"; | |||
| return RET_ERROR; | |||
| } | |||
| // use axis instead of dim | |||
| if (inputs[1]->isa<ValueNode>()) { | |||
| auto axis_tensor = inputs[1]->cast<ValueNodePtr>(); | |||
| int axis = GetValue<int>(axis_tensor->value()); | |||
| attr->dim = axis; | |||
| } else { | |||
| MS_LOG(ERROR) << "input axis is not value node."; | |||
| delete this->primitive_; | |||
| delete attr; | |||
| return RET_ERROR; | |||
| } | |||
| this->primitive_->value.value = attr; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| #else | |||
| int ExpandDims::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | |||
| MS_ASSERT(nullptr != primitive); | |||
| @@ -56,9 +93,6 @@ int ExpandDims::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> | |||
| MS_ASSERT(input != nullptr); | |||
| auto output = outputs_.front(); | |||
| MS_ASSERT(output != nullptr); | |||
| if (inputs_.size() != kSingleNum) { | |||
| MS_LOG(ERROR) << "input size is invalid"; | |||
| } | |||
| if (outputs_.size() != kSingleNum) { | |||
| MS_LOG(ERROR) << "output size is invalid"; | |||
| } | |||
| @@ -31,6 +31,7 @@ class ExpandDims : public PrimitiveC { | |||
| ExpandDims() = default; | |||
| explicit ExpandDims(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | |||
| void SetDim(int dim); | |||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | |||
| #else | |||
| ExpandDims() = default; | |||
| @@ -31,7 +31,35 @@ int Gather::GetBatchDims() const { return this->primitive_->value.AsGather()->ba | |||
| void Gather::SetAxis(int axis) { this->primitive_->value.AsGather()->axis = axis; } | |||
| void Gather::SetBatchDims(int batch_dims) { this->primitive_->value.AsGather()->batchDims = batch_dims; } | |||
| int Gather::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | |||
| if (this->primitive_ == nullptr) { | |||
| this->primitive_ = new (std::nothrow) schema::PrimitiveT; | |||
| if (this->primitive_ == nullptr) { | |||
| MS_LOG(ERROR) << "new primitive error"; | |||
| return RET_ERROR; | |||
| } | |||
| this->primitive_->value.type = schema::PrimitiveType_Gather; | |||
| } | |||
| if (this->primitive_->value.type != schema::PrimitiveType_Gather) { | |||
| MS_LOG(ERROR) << "Gather primitive value type : " << schema::EnumNamePrimitiveType(primitive_->value.type) | |||
| << "is not equal" << schema::EnumNamePrimitiveType(schema::PrimitiveType_Gather); | |||
| delete this->primitive_; | |||
| return RET_ERROR; | |||
| } | |||
| if (this->primitive_->value.value == nullptr) { | |||
| auto gather_attr = new (std::nothrow) schema::GatherT(); | |||
| if (gather_attr == nullptr) { | |||
| MS_LOG(ERROR) << "new primitive value.value error"; | |||
| delete this->primitive_; | |||
| delete gather_attr; | |||
| return RET_ERROR; | |||
| } | |||
| gather_attr->axis = GetValue<int>(prim.GetAttr("axis")); | |||
| gather_attr->batchDims = GetValue<int>(prim.GetAttr("batchDims")); | |||
| this->primitive_->value.value = gather_attr; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| #else | |||
| int Gather::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | |||
| MS_ASSERT(nullptr != primitive); | |||
| @@ -33,6 +33,7 @@ class Gather : public PrimitiveC { | |||
| explicit Gather(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | |||
| void SetAxis(int axis); | |||
| void SetBatchDims(int batch_dims); | |||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | |||
| #else | |||
| Gather() = default; | |||
| @@ -0,0 +1,75 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/ops/oneslike.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| int OnesLike::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | |||
| if (this->primitive_ == nullptr) { | |||
| this->primitive_ = new (std::nothrow) schema::PrimitiveT; | |||
| if (this->primitive_ == nullptr) { | |||
| MS_LOG(ERROR) << "new primitive error"; | |||
| return RET_ERROR; | |||
| } | |||
| this->primitive_->value.type = schema::PrimitiveType_OnesLike; | |||
| } | |||
| if (this->primitive_->value.type != schema::PrimitiveType_OnesLike) { | |||
| MS_LOG(ERROR) << "PrimitiveType_OnesLike primitive value type : " | |||
| << schema::EnumNamePrimitiveType(primitive_->value.type) << "is not equal" | |||
| << schema::EnumNamePrimitiveType(schema::PrimitiveType_OnesLike); | |||
| delete this->primitive_; | |||
| return RET_ERROR; | |||
| } | |||
| if (this->primitive_->value.value == nullptr) { | |||
| this->primitive_->value.value = new (std::nothrow) schema::OnesLikeT(); | |||
| if (this->primitive_->value.value == nullptr) { | |||
| MS_LOG(ERROR) << "new primitiveT value failed"; | |||
| delete this->primitive_; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| #else | |||
| int OnesLike::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | |||
| MS_ASSERT(nullptr != primitive); | |||
| MS_ASSERT(nullptr != fbb); | |||
| auto attr = primitive->value_as_OnesLike(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "value_as_OnesLike return nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| auto val_offset = schema::CreateOnesLike(*fbb); | |||
| auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_OnesLike, val_offset.o); | |||
| fbb->Finish(prim_offset); | |||
| return RET_OK; | |||
| } | |||
| #endif | |||
| int OnesLike::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | |||
| Tensor *x = inputs_[0]; | |||
| Tensor *out = outputs_[0]; | |||
| std::vector<int> x_shape = x->shape(); | |||
| std::vector<int> output_shape(x_shape.size()); | |||
| output_shape.assign(x_shape.begin(), x_shape.end()); | |||
| out->set_shape(output_shape); | |||
| out->SetFormat(x->GetFormat()); | |||
| out->set_data_type(x->data_type()); | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,42 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <vector> | |||
| #include <set> | |||
| #include <cmath> | |||
| #include "src/ops/primitive_c.h" | |||
| #ifndef LITE_SRC_OPS_ONESLIKE_H_ | |||
| #define LITE_SRC_OPS_ONESLIKE_H_ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class OnesLike : public PrimitiveC { | |||
| public: | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| MS_DECLARE_PARENT(OnesLike, PrimitiveC); | |||
| OnesLike() = default; | |||
| explicit OnesLike(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | |||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | |||
| #else | |||
| OnesLike() = default; | |||
| int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | |||
| #endif | |||
| int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // LITE_SRC_OPS_ONESLIKE_H_ | |||
| @@ -31,6 +31,51 @@ void Power::SetPower(float power) { this->primitive_->value.AsPower()->power = p | |||
| void Power::SetScale(float scale) { this->primitive_->value.AsPower()->scale = scale; } | |||
| void Power::SetShift(float shift) { this->primitive_->value.AsPower()->shift = shift; } | |||
| int Power::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | |||
| if (this->primitive_ == nullptr) { | |||
| this->primitive_ = new (std::nothrow) schema::PrimitiveT; | |||
| if (this->primitive_ == nullptr) { | |||
| MS_LOG(ERROR) << "new primitiveT failed"; | |||
| return RET_ERROR; | |||
| } | |||
| this->primitive_->value.type = schema::PrimitiveType_Power; | |||
| } | |||
| if (this->primitive_->value.type != schema::PrimitiveType_Power) { | |||
| MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; | |||
| delete this->primitive_; | |||
| return RET_ERROR; | |||
| } | |||
| if (this->primitive_->value.value == nullptr) { | |||
| auto attr = new (std::nothrow) schema::PowerT(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new primitiveT value failed"; | |||
| delete this->primitive_; | |||
| return RET_ERROR; | |||
| } | |||
| if (prim.GetAttr("scale") == nullptr) { | |||
| MS_LOG(WARNING) << "get scale failed"; | |||
| attr->scale = 1.0f; | |||
| } else { | |||
| attr->scale = GetValue<float>(prim.GetAttr("scale")); | |||
| } | |||
| if (prim.GetAttr("power") == nullptr) { | |||
| MS_LOG(WARNING) << "get power failed"; | |||
| attr->power = 1.0f; | |||
| } else { | |||
| attr->power = GetValue<float>(prim.GetAttr("power")); | |||
| } | |||
| if (prim.GetAttr("shift") == nullptr) { | |||
| MS_LOG(WARNING) << "get shift failed"; | |||
| attr->shift = 0; | |||
| } else { | |||
| attr->shift = GetValue<float>(prim.GetAttr("shift")); | |||
| } | |||
| this->primitive_->value.value = attr; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| #else | |||
| float Power::GetPower() const { return this->primitive_->value_as_Power()->power(); } | |||
| @@ -34,6 +34,7 @@ class Power : public PrimitiveC { | |||
| void SetPower(float power); | |||
| void SetScale(float scale); | |||
| void SetShift(float shift); | |||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | |||
| #else | |||
| Power() = default; | |||
| @@ -145,6 +145,8 @@ | |||
| #include "src/ops/identity.h" | |||
| #include "src/ops/instance_norm.h" | |||
| #include "src/ops/while.h" | |||
| #include "src/ops/oneslike.h" | |||
| #include "src/ops/unsorted_segment_sum.h" | |||
| #ifdef SUPPORT_TRAIN | |||
| #include "src/ops/neg_grad.h" | |||
| @@ -165,6 +167,10 @@ | |||
| #include "src/ops/sgd.h" | |||
| #include "src/ops/adam.h" | |||
| #include "src/ops/assign.h" | |||
| #include "src/ops/control_depend.h" | |||
| #include "src/ops/assign_add.h" | |||
| #include "src/ops/binary_cross_entropy.h" | |||
| #include "src/ops/binary_cross_entropy_grad.h" | |||
| #endif | |||
| #endif | |||
| @@ -504,6 +510,18 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std: | |||
| return NewPrimitiveC<While>(prim, inputs, quantType); | |||
| } else if (op_type == "OneHot") { | |||
| return NewPrimitiveC<OneHot>(prim, inputs, quantType); | |||
| } else if (op_type == "GatherV2") { | |||
| return NewPrimitiveC<Gather>(prim, inputs, quantType); | |||
| } else if (op_type == "OnesLike") { | |||
| return NewPrimitiveC<OnesLike>(prim, inputs, quantType); | |||
| } else if (op_type == "Pow") { | |||
| return NewPrimitiveC<Power>(prim, inputs, quantType); | |||
| } else if (op_type == "Sub") { | |||
| return NewPrimitiveC<Sub>(prim, inputs, quantType); | |||
| } else if (op_type == "ExpandDims") { | |||
| return NewPrimitiveC<ExpandDims>(prim, inputs, quantType); | |||
| } else if (op_type == "UnsortedSegmentSum") { | |||
| return NewPrimitiveC<UnsortedSegmentSum>(prim, inputs, quantType); | |||
| #ifdef SUPPORT_TRAIN | |||
| } else if (op_type == "SoftmaxCrossEntropyWithLogits") { | |||
| @@ -514,6 +532,8 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std: | |||
| return NewPrimitiveC<ApplyMomentum>(prim, inputs, quantType); | |||
| } else if (op_type == "Depend") { | |||
| return NewPrimitiveC<Depend>(prim, inputs, quantType); | |||
| } else if (op_type == "ControlDepend") { | |||
| return NewPrimitiveC<ControlDepend>(prim, inputs, quantType); | |||
| } else if ((op_type == "ReluGrad" || op_type == "ReLU6Grad" || op_type == "SigmoidGrad" || | |||
| op_type == "HSigmoidGrad" || op_type == "HSwishGrad")) { | |||
| return NewPrimitiveC<ActivationGrad>(prim, inputs, quantType); | |||
| @@ -539,6 +559,12 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std: | |||
| return NewPrimitiveC<Adam>(prim, inputs, quantType); | |||
| } else if (op_type == "Assign") { | |||
| return NewPrimitiveC<Assign>(prim, inputs, quantType); | |||
| } else if (op_type == "AssignAdd") { | |||
| return NewPrimitiveC<AssignAdd>(prim, inputs, quantType); | |||
| } else if (op_type == "BinaryCrossEntropy") { | |||
| return NewPrimitiveC<BinaryCrossEntropy>(prim, inputs, quantType); | |||
| } else if (op_type == "BinaryCrossEntropyGrad") { | |||
| return NewPrimitiveC<BinaryCrossEntropyGrad>(prim, inputs, quantType); | |||
| #else | |||
| } else if (op_type == "Conv2DBackpropInput") { | |||
| return NewPrimitiveC<DeConv2D>(prim, inputs, quantType); | |||
| @@ -830,6 +856,8 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) { | |||
| return new PowerGrad(primitive); | |||
| case schema::PrimitiveType_Depend: | |||
| return new Depend(primitive); | |||
| case schema::PrimitiveType_ControlDepend: | |||
| return new ControlDepend(primitive); | |||
| case schema::PrimitiveType_FlattenGrad: | |||
| return new FlattenGrad(primitive); | |||
| case schema::PrimitiveType_NegGrad: | |||
| @@ -842,6 +870,17 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) { | |||
| return new Adam(primitive); | |||
| case schema::PrimitiveType_Assign: | |||
| return new Assign(primitive); | |||
| case schema::PrimitiveType_AssignAdd: | |||
| return new AssignAdd(primitive); | |||
| case schema::PrimitiveType_OnesLike: | |||
| return new OnesLike(primitive); | |||
| case schema::PrimitiveType_UnsortedSegmentSum: | |||
| return new UnsortedSegmentSum(primitive); | |||
| case schema::PrimitiveType_BinaryCrossEntropyGrad: | |||
| return new BinaryCrossEntropyGrad(primitive); | |||
| case schema::PrimitiveType_BinaryCrossEntropy: | |||
| return new BinaryCrossEntropy(primitive); | |||
| #endif | |||
| default: | |||
| MS_LOG(ERROR) << "Unsupported primitive type in Create : " << schema::EnumNamePrimitiveType(op_type); | |||
| @@ -86,6 +86,9 @@ int Reduce::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inp | |||
| MS_ASSERT(elem != nullptr); | |||
| attr->axes.emplace_back(elem->value()); | |||
| } | |||
| } else { | |||
| int axes_item = GetValue<int>(value); | |||
| attr->axes.push_back(axes_item); | |||
| } | |||
| } | |||
| } | |||
| @@ -62,6 +62,9 @@ int Reshape::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &in | |||
| MS_ASSERT(elem != nullptr); | |||
| attr->shape.emplace_back(static_cast<int>(elem->value())); | |||
| } | |||
| } else { | |||
| int dim = GetValue<int>(val); | |||
| attr->shape = {dim}; | |||
| } | |||
| } | |||
| if (attr == nullptr) { | |||
| @@ -46,12 +46,14 @@ int Squeeze::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &in | |||
| MS_LOG(ERROR) << "new primitiveT value failed"; | |||
| return RET_ERROR; | |||
| } | |||
| attr->axis = GetValue<std::vector<int>>(prim.GetAttr("axis")); | |||
| this->primitive_->value.value = attr; | |||
| if (this->primitive_->value.value == nullptr) { | |||
| MS_LOG(ERROR) << "primitive value is nullptr"; | |||
| return RET_ERROR; | |||
| if (prim.GetAttr("axis") == nullptr) { | |||
| MS_LOG(WARNING) << "get axis failed"; | |||
| attr->axis = {0}; | |||
| } else { | |||
| int axis = GetValue<int>(prim.GetAttr("axis")); | |||
| attr->axis = {axis}; | |||
| } | |||
| this->primitive_->value.value = attr; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| @@ -29,6 +29,34 @@ void Sub::SetActivationType(int activation_type) { | |||
| this->primitive_->value.AsSub()->activationType = (schema::ActivationType)activation_type; | |||
| } | |||
| int Sub::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | |||
| if (this->primitive_ == nullptr) { | |||
| this->primitive_ = new (std::nothrow) schema::PrimitiveT; | |||
| if (this->primitive_ == nullptr) { | |||
| MS_LOG(ERROR) << "new primitiveT failed"; | |||
| return RET_ERROR; | |||
| } | |||
| this->primitive_->value.type = schema::PrimitiveType_Sub; | |||
| } | |||
| if (this->primitive_->value.type != schema::PrimitiveType_Sub) { | |||
| MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; | |||
| delete this->primitive_; | |||
| return RET_ERROR; | |||
| } | |||
| if (this->primitive_->value.value == nullptr) { | |||
| auto attr = new (std::nothrow) schema::SubT(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new primitiveT value failed"; | |||
| delete this->primitive_; | |||
| return RET_ERROR; | |||
| } | |||
| // todo: confirm the activationType | |||
| attr->activationType = schema::ActivationType_NO_ACTIVATION; | |||
| this->primitive_->value.value = attr; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| #else | |||
| int Sub::GetActivationType() const { return this->primitive_->value_as_Sub()->activationType(); } | |||
| @@ -32,6 +32,7 @@ class Sub : public Arithmetic { | |||
| Sub() = default; | |||
| explicit Sub(schema::PrimitiveT *primitive) : Arithmetic(primitive) {} | |||
| void SetActivationType(int activation_type); | |||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | |||
| #else | |||
| Sub() = default; | |||
| @@ -52,6 +52,12 @@ int Tile::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &input | |||
| MS_LOG(ERROR) << "new primitiveT value failed"; | |||
| return RET_ERROR; | |||
| } | |||
| if (prim.GetAttr("dims") == nullptr) { | |||
| MS_LOG(WARNING) << "get dims failed"; | |||
| attr->dims = {1}; | |||
| } else { | |||
| attr->dims = GetValue<std::vector<int>>(prim.GetAttr("dims")); | |||
| } | |||
| if (inputs.size() == kAnfPopulaterTwo) { | |||
| auto inputNode = inputs[kAnfPopulaterOne]; | |||
| MS_ASSERT(inputNode != nullptr); | |||
| @@ -68,6 +74,9 @@ int Tile::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &input | |||
| MS_ASSERT(elem != nullptr); | |||
| attr->multiples.emplace_back(elem->value()); | |||
| } | |||
| } else { | |||
| int multiple = GetValue<int>(value); | |||
| attr->multiples = {multiple}; | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,100 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <memory> | |||
| #include "src/ops/unsorted_segment_sum.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| int UnsortedSegmentSum::GetNumSegments() const { return this->primitive_->value.AsUnsortedSegmentSum()->numSegments; } | |||
| int UnsortedSegmentSum::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | |||
| if (this->primitive_ == nullptr) { | |||
| this->primitive_ = new (std::nothrow) schema::PrimitiveT; | |||
| if (this->primitive_ == nullptr) { | |||
| MS_LOG(ERROR) << "new primitive error"; | |||
| return RET_ERROR; | |||
| } | |||
| this->primitive_->value.type = schema::PrimitiveType_UnsortedSegmentSum; | |||
| } | |||
| if (this->primitive_->value.type != schema::PrimitiveType_UnsortedSegmentSum) { | |||
| MS_LOG(ERROR) << "UnSortedSegmentSum primitive value type : " | |||
| << schema::EnumNamePrimitiveType(primitive_->value.type) << "is not equal" | |||
| << schema::EnumNamePrimitiveType(schema::PrimitiveType_UnsortedSegmentSum); | |||
| delete this->primitive_; | |||
| return RET_ERROR; | |||
| } | |||
| if (this->primitive_->value.value == nullptr) { | |||
| std::unique_ptr<schema::UnsortedSegmentSumT> attr = std::make_unique<schema::UnsortedSegmentSumT>(); | |||
| if (inputs[2]->isa<ValueNode>()) { | |||
| ValuePtr value = inputs[2]->cast<ValueNodePtr>()->value(); | |||
| attr->numSegments = GetValue<int>(value); | |||
| this->primitive_->value.value = attr.release(); | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| #else | |||
| int UnsortedSegmentSum::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | |||
| MS_ASSERT(nullptr != primitive); | |||
| MS_ASSERT(nullptr != fbb); | |||
| auto attr = primitive->value_as_UnsortedSegmentSum(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "value_as_UnsortedSegmentSum return nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| int num_segments = attr->numSegments(); | |||
| auto val_offset = schema::CreateUnsortedSegmentSum(*fbb, num_segments); | |||
| auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_UnsortedSegmentSum, val_offset.o); | |||
| fbb->Finish(prim_offset); | |||
| return RET_OK; | |||
| } | |||
| int UnsortedSegmentSum::GetNumSegments() const { | |||
| int ret = this->primitive_->value_as_UnsortedSegmentSum()->numSegments(); | |||
| return ret; | |||
| } | |||
| #endif | |||
| int UnsortedSegmentSum::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | |||
| // check inputs and outputs | |||
| if (inputs_.size() != 3) { | |||
| MS_LOG(ERROR) << "invalid inputs numbers"; | |||
| return RET_ERROR; | |||
| } | |||
| if (outputs_.size() != 1) { | |||
| MS_LOG(ERROR) << "invalid outputs numbers"; | |||
| return RET_ERROR; | |||
| } | |||
| Tensor *out = outputs_.front(); | |||
| Tensor *x = inputs_.front(); | |||
| Tensor *segment_id = inputs_[1]; | |||
| std::vector<int> x_shape = x->shape(); | |||
| std::vector<int> segment_id_shape = segment_id->shape(); | |||
| int num_segments = GetNumSegments(); | |||
| std::vector<int> output_shape; | |||
| output_shape.push_back(num_segments); | |||
| for (int index = segment_id_shape.size(); index < static_cast<int>(x_shape.size()); index++) { | |||
| output_shape.push_back(x_shape[index]); | |||
| } | |||
| out->set_shape(output_shape); | |||
| out->SetFormat(x->GetFormat()); | |||
| out->set_data_type(x->data_type()); | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,44 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <vector> | |||
| #include <set> | |||
| #include <cmath> | |||
| #include "src/ops/primitive_c.h" | |||
| #ifndef LITE_SRC_OPS_UNSORTED_SEGMENT_SUM_H_ | |||
| #define LITE_SRC_OPS_UNSORTED_SEGMENT_SUM_H_ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class UnsortedSegmentSum : public PrimitiveC { | |||
| public: | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| MS_DECLARE_PARENT(UnsortedSegmentSum, PrimitiveC); | |||
| UnsortedSegmentSum() = default; | |||
| explicit UnsortedSegmentSum(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | |||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | |||
| int GetNumSegments() const; | |||
| #else | |||
| UnsortedSegmentSum() = default; | |||
| int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | |||
| int GetNumSegments() const; | |||
| #endif | |||
| int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // LITE_SRC_OPS_UNSORTED_SEGMENT_SUM_H_ | |||
| @@ -36,6 +36,9 @@ | |||
| #include "src/ops/bn_grad.h" | |||
| #include "nnacl/fp32_grad/batch_norm.h" | |||
| #include "src/ops/adam.h" | |||
| #include "src/ops/oneslike.h" | |||
| #include "src/ops/binary_cross_entropy.h" | |||
| #include "src/ops/binary_cross_entropy_grad.h" | |||
| namespace mindspore::kernel { | |||
| @@ -76,6 +79,30 @@ OpParameter *PopulateApplyMomentumParameter(const mindspore::lite::PrimitiveC *p | |||
| return reinterpret_cast<OpParameter *>(p); | |||
| } | |||
| OpParameter *PopulateBCEParameter(const mindspore::lite::PrimitiveC *primitive) { | |||
| int32_t *reduction = reinterpret_cast<int32_t *>(malloc(sizeof(int32_t))); | |||
| if (reduction == nullptr) { | |||
| MS_LOG(ERROR) << "malloc reduction failed."; | |||
| return nullptr; | |||
| } | |||
| auto param = | |||
| reinterpret_cast<mindspore::lite::BinaryCrossEntropy *>(const_cast<mindspore::lite::PrimitiveC *>(primitive)); | |||
| *reduction = param->GetReduction(); | |||
| return reinterpret_cast<OpParameter *>(reduction); | |||
| } | |||
| OpParameter *PopulateBCEGradParameter(const mindspore::lite::PrimitiveC *primitive) { | |||
| int32_t *reduction = reinterpret_cast<int32_t *>(malloc(sizeof(int32_t))); | |||
| if (reduction == nullptr) { | |||
| MS_LOG(ERROR) << "malloc reduction failed."; | |||
| return nullptr; | |||
| } | |||
| auto param = | |||
| reinterpret_cast<mindspore::lite::BinaryCrossEntropyGrad *>(const_cast<mindspore::lite::PrimitiveC *>(primitive)); | |||
| *reduction = param->GetReduction(); | |||
| return reinterpret_cast<OpParameter *>(reduction); | |||
| } | |||
| OpParameter *PopulateAdamParameter(const mindspore::lite::PrimitiveC *primitive) { | |||
| if (primitive == nullptr) { | |||
| MS_LOG(ERROR) << "Primitive is nullptr when populating parameter for op."; | |||
| @@ -396,6 +423,13 @@ void PopulateTrainParameters() { | |||
| lite::Registry BNGradParameterRegistry(schema::PrimitiveType_BNGrad, PopulateBNGradParameter); | |||
| lite::Registry AdamParameterRegistry(schema::PrimitiveType_Adam, PopulateAdamParameter); | |||
| lite::Registry AssignParameterRegistry(schema::PrimitiveType_Assign, DefaultPopulateParameter); | |||
| lite::Registry AssignAddParameterRegistry(schema::PrimitiveType_AssignAdd, DefaultPopulateParameter); | |||
| lite::Registry BinaryCrossEntropyParameterRegistry(schema::PrimitiveType_BinaryCrossEntropy, PopulateBCEParameter); | |||
| lite::Registry BinaryCrossEntropyGradParameterRegistry(schema::PrimitiveType_BinaryCrossEntropyGrad, | |||
| PopulateBCEGradParameter); | |||
| lite::Registry OnesLikeParameterRegistry(schema::PrimitiveType_OnesLike, DefaultPopulateParameter); | |||
| lite::Registry UnsortedSegmentSumParameterRegistry(schema::PrimitiveType_UnsortedSegmentSum, | |||
| DefaultPopulateParameter); | |||
| } | |||
| } // namespace mindspore::kernel | |||
| @@ -69,7 +69,8 @@ void AnfExporter::RemoveIfDepend(const CNodePtr &cnode) { | |||
| continue; | |||
| } | |||
| auto dependNode = utils::cast<CNodePtr>(inputNode); | |||
| if (IsPrimitiveCNode(dependNode, schema::PrimitiveType_Depend)) { | |||
| if (IsPrimitiveCNode(dependNode, schema::PrimitiveType_Depend) || | |||
| IsPrimitiveCNode(dependNode, schema::PrimitiveType_ControlDepend)) { | |||
| hasDepend = true; | |||
| for (size_t j = 1; j < dependNode->inputs().size(); ++j) { | |||
| AnfNodePtr dependInputNode = dependNode->input(j); | |||
| @@ -209,6 +210,7 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph, bool kee | |||
| if ((primitive_c->Type() == schema::PrimitiveType_TupleGetItem) || | |||
| #ifdef SUPPORT_TRAIN | |||
| (primitive_c->Type() == schema::PrimitiveType_Depend) || | |||
| (primitive_c->Type() == schema::PrimitiveType_ControlDepend) || | |||
| #endif | |||
| (primitive_c->Type() == schema::PrimitiveType_MakeTuple)) { | |||
| continue; | |||
| @@ -402,7 +404,9 @@ int AnfExporter::ConvertInputValueNode(std::shared_ptr<AnfNode> input_anode, | |||
| paramTensor->dims = {1}; | |||
| paramTensor->nodeType = schema::NodeType::NodeType_ValueNode; | |||
| auto data = value->cast<mindspore::Int32ImmPtr>(); | |||
| paramTensor->data.emplace_back(data->value()); | |||
| int real_data = GetValue<int32_t>(data); | |||
| paramTensor->data.resize(sizeof(int32_t)); | |||
| memcpy(paramTensor->data.data(), &real_data, sizeof(int32_t)); | |||
| node_id_map_[valueNode->fullname_with_scope()] = meta_graphT->allTensors.size(); | |||
| output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size()); | |||
| meta_graphT->allTensors.emplace_back(std::move(paramTensor)); | |||
| @@ -418,6 +422,14 @@ int AnfExporter::ConvertInputValueNode(std::shared_ptr<AnfNode> input_anode, | |||
| node_id_map_[valueNode->fullname_with_scope()] = meta_graphT->allTensors.size(); | |||
| output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size()); | |||
| meta_graphT->allTensors.emplace_back(std::move(paramTensor)); | |||
| } else if (value->isa<mindspore::Int>()) { | |||
| paramTensor->dataType = kNumberTypeInt32; | |||
| paramTensor->dims = {1}; | |||
| paramTensor->nodeType = schema::NodeType_ValueNode; | |||
| paramTensor->data.emplace_back(kNumberTypeInt32); | |||
| node_id_map_[valueNode->fullname_with_scope()] = meta_graphT->allTensors.size(); | |||
| output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size()); | |||
| meta_graphT->allTensors.emplace_back(std::move(paramTensor)); | |||
| } else if (value->isa<mindspore::ValueSequeue>()) { | |||
| #ifndef SUPPORT_TRAIN | |||
| MS_LOG(DEBUG) << "Value type is ValueSequence."; | |||
| @@ -456,6 +468,18 @@ int AnfExporter::ConvertInputValueNode(std::shared_ptr<AnfNode> input_anode, | |||
| MS_LOG(ERROR) << "Value type is ValueSequence not supported - " << valueAbstract->type_name() << "."; | |||
| } | |||
| #endif | |||
| } else if (value->isa<mindspore::BoolImm>()) { | |||
| auto valueAbstract = valueNode->abstract(); | |||
| auto abstractScalar = utils::cast<abstract::AbstractScalarPtr>(valueAbstract); | |||
| auto typePtr = abstractScalar->GetTypeTrack(); | |||
| paramTensor->dataType = typePtr->type_id(); | |||
| paramTensor->dims = {1}; | |||
| paramTensor->nodeType = schema::NodeType_ValueNode; | |||
| auto data = value->cast<mindspore::BoolImmPtr>(); | |||
| paramTensor->data.emplace_back(data->value()); | |||
| node_id_map_[valueNode->fullname_with_scope()] = meta_graphT->allTensors.size(); | |||
| output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size()); | |||
| meta_graphT->allTensors.emplace_back(std::move(paramTensor)); | |||
| } else if (value->isa<Number>()) { | |||
| MS_LOG(INFO) << "Value is a number."; | |||
| return RET_OK; | |||