Merge pull request !7674 from yangjie159/sync_codetags/v1.1.0
| @@ -227,6 +227,12 @@ union PrimitiveType { | |||||
| Identity, | Identity, | ||||
| LayerNorm, | LayerNorm, | ||||
| While, | While, | ||||
| ControlDepend, | |||||
| UnsortedSegmentSum, | |||||
| AssignAdd, | |||||
| OnesLike, | |||||
| BinaryCrossEntropyGrad, | |||||
| BinaryCrossEntropy | |||||
| } | } | ||||
| enum QuantType: int { | enum QuantType: int { | ||||
| @@ -966,6 +966,8 @@ table Adam { | |||||
| table Assign { | table Assign { | ||||
| } | } | ||||
| table AssignAdd { | |||||
| } | |||||
| table Where{ | table Where{ | ||||
| condition: [bool]; | condition: [bool]; | ||||
| @@ -1010,6 +1012,9 @@ table ToFormat { | |||||
| table Depend { | table Depend { | ||||
| } | } | ||||
| table ControlDepend { | |||||
| } | |||||
| table Return { | table Return { | ||||
| } | } | ||||
| @@ -1108,3 +1113,18 @@ table While { | |||||
| bodySubgraphIndex : int; | bodySubgraphIndex : int; | ||||
| } | } | ||||
| table UnsortedSegmentSum { | |||||
| numSegments : int; | |||||
| } | |||||
| table OnesLike { | |||||
| } | |||||
| table BinaryCrossEntropy { | |||||
| reduction : int = 1; | |||||
| } | |||||
| table BinaryCrossEntropyGrad { | |||||
| reduction : int = 1; | |||||
| } | |||||
| @@ -73,7 +73,7 @@ Registry AdamRegistry(schema::PrimitiveType_Adam, AdamCreator); | |||||
| int Adam::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs) { | int Adam::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs) { | ||||
| if (10 != inputs.size()) { | if (10 != inputs.size()) { | ||||
| MS_LOG(ERROR) << "Adam should have at least 8 input tensors"; | |||||
| MS_LOG(ERROR) << "Adam should have at 10 input tensors"; | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| @@ -0,0 +1,81 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "src/ops/assign_add.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| #ifdef PRIMITIVE_WRITEABLE | |||||
| int AssignAdd::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | |||||
| if (this->primitive_ == nullptr) { | |||||
| this->primitive_ = new (std::nothrow) schema::PrimitiveT; | |||||
| if (this->primitive_ == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitive error"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| this->primitive_->value.type = schema::PrimitiveType_AssignAdd; | |||||
| } | |||||
| if (this->primitive_->value.type != schema::PrimitiveType_AssignAdd) { | |||||
| MS_LOG(ERROR) << "PrimitiveType_AssignAdd primitive value type : " | |||||
| << schema::EnumNamePrimitiveType(primitive_->value.type) << "is not equal" | |||||
| << schema::EnumNamePrimitiveType(schema::PrimitiveType_AssignAdd); | |||||
| delete this->primitive_; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (this->primitive_->value.value == nullptr) { | |||||
| this->primitive_->value.value = new (std::nothrow) schema::AssignAddT(); | |||||
| if (this->primitive_->value.value == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitiveT value failed"; | |||||
| delete this->primitive_; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| #else | |||||
| int AssignAdd::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | |||||
| MS_ASSERT(nullptr != primitive); | |||||
| MS_ASSERT(nullptr != fbb); | |||||
| auto attr = primitive->value_as_AssignAdd(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "value_as_AssignAdd return nullptr"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto val_offset = schema::CreateAssignAdd(*fbb); | |||||
| auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_AssignAdd, val_offset.o); | |||||
| fbb->Finish(prim_offset); | |||||
| return RET_OK; | |||||
| } | |||||
| #endif | |||||
| int AssignAdd::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | |||||
| Tensor *x = inputs_[0]; | |||||
| Tensor *y = inputs_[1]; | |||||
| Tensor *out = outputs_[0]; | |||||
| std::vector<int> x_shape = x->shape(); | |||||
| if (x->data_type() != y->data_type()) { | |||||
| MS_LOG(ERROR) << "no matched shape of x and y"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| std::vector<int> output_shape(x_shape.size()); | |||||
| for (int i = 0; i < static_cast<int>(x_shape.size()); i++) { | |||||
| output_shape[i] = x_shape[i]; | |||||
| } | |||||
| out->set_shape(output_shape); | |||||
| out->SetFormat(x->GetFormat()); | |||||
| out->set_data_type(x->data_type()); | |||||
| return RET_OK; | |||||
| } | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,41 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <vector> | |||||
| #include <set> | |||||
| #include <cmath> | |||||
| #include "src/ops/primitive_c.h" | |||||
| #ifndef LITE_SRC_OPS_ASSIGN_ADD_H_ | |||||
| #define LITE_SRC_OPS_ASSIGN_ADD_H_ | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| class AssignAdd : public PrimitiveC { | |||||
| public: | |||||
| #ifdef PRIMITIVE_WRITEABLE | |||||
| MS_DECLARE_PARENT(AssignAdd, PrimitiveC); | |||||
| AssignAdd() = default; | |||||
| explicit AssignAdd(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | |||||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | |||||
| #else | |||||
| AssignAdd() = default; | |||||
| int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | |||||
| #endif | |||||
| int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override; | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // LITE_SRC_OPS_ASSIGN_ADD_H_ | |||||
| @@ -47,7 +47,12 @@ int BiasAdd::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &in | |||||
| MS_LOG(ERROR) << "new primitiveT value failed"; | MS_LOG(ERROR) << "new primitiveT value failed"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| attr->axis = {0}; | |||||
| if (prim.GetAttr("axis") == nullptr) { | |||||
| MS_LOG(WARNING) << "get axis failed"; | |||||
| attr->axis = {1}; | |||||
| } else { | |||||
| attr->axis = GetValue<std::vector<int>>(prim.GetAttr("axis")); | |||||
| } | |||||
| this->primitive_->value.value = attr; | this->primitive_->value.value = attr; | ||||
| if (this->primitive_->value.value == nullptr) { | if (this->primitive_->value.value == nullptr) { | ||||
| MS_LOG(ERROR) << "primitive value is nullptr"; | MS_LOG(ERROR) << "primitive value is nullptr"; | ||||
| @@ -0,0 +1,106 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <string> | |||||
| #include "src/ops/binary_cross_entropy.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| #ifdef PRIMITIVE_WRITEABLE | |||||
| int BinaryCrossEntropy::GetReduction() const { return this->primitive_->value.AsBinaryCrossEntropy()->reduction; } | |||||
| int BinaryCrossEntropy::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | |||||
| if (this->primitive_ == nullptr) { | |||||
| this->primitive_ = new (std::nothrow) schema::PrimitiveT; | |||||
| if (this->primitive_ == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitive error"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| this->primitive_->value.type = schema::PrimitiveType_BinaryCrossEntropy; | |||||
| } | |||||
| if (this->primitive_->value.type != schema::PrimitiveType_BinaryCrossEntropy) { | |||||
| MS_LOG(ERROR) << "PrimitiveType_BinaryCrossEntropy primitive value type : " | |||||
| << schema::EnumNamePrimitiveType(primitive_->value.type) << "is not equal" | |||||
| << schema::EnumNamePrimitiveType(schema::PrimitiveType_BinaryCrossEntropy); | |||||
| delete this->primitive_; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (this->primitive_->value.value == nullptr) { | |||||
| schema::BinaryCrossEntropyT *attr = new (std::nothrow) schema::BinaryCrossEntropyT(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new binary cross entropy attr failed!"; | |||||
| delete this->primitive_; | |||||
| return RET_ERROR; | |||||
| } | |||||
| // default is mean | |||||
| string reduction = "mean"; | |||||
| if (prim.GetAttr("reduction") == nullptr) { | |||||
| MS_LOG(ERROR) << "get reduction failed!"; | |||||
| delete this->primitive_; | |||||
| delete attr; | |||||
| return RET_ERROR; | |||||
| } else { | |||||
| reduction = GetValue<string>(prim.GetAttr("reduction")); | |||||
| } | |||||
| if (reduction == "none") { | |||||
| attr->reduction = 0; | |||||
| } else if (reduction == "sum") { | |||||
| attr->reduction = 2; | |||||
| } else { | |||||
| // default is mean | |||||
| attr->reduction = 1; | |||||
| } | |||||
| this->primitive_->value.value = attr; | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| #else | |||||
| int BinaryCrossEntropy::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | |||||
| MS_ASSERT(nullptr != primitive); | |||||
| MS_ASSERT(nullptr != fbb); | |||||
| auto attr = primitive->value_as_BinaryCrossEntropy(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "value_as_BinaryCrossEntropy return nullptr"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| int reduction = attr->reduction(); | |||||
| auto val_offset = schema::CreateBinaryCrossEntropy(*fbb, reduction); | |||||
| auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_BinaryCrossEntropy, val_offset.o); | |||||
| fbb->Finish(prim_offset); | |||||
| return RET_OK; | |||||
| } | |||||
| int BinaryCrossEntropy::GetReduction() const { return this->primitive_->value_as_BinaryCrossEntropy()->reduction(); } | |||||
| #endif | |||||
| int BinaryCrossEntropy::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | |||||
| Tensor *x = inputs_[0]; | |||||
| Tensor *out = outputs_[0]; | |||||
| out->SetFormat(x->GetFormat()); | |||||
| out->set_data_type(x->data_type()); | |||||
| int reduction = GetReduction(); | |||||
| if (reduction == 1 || reduction == 2) { | |||||
| out->set_shape({1}); | |||||
| } else { | |||||
| std::vector<int> x_shape = x->shape(); | |||||
| std::vector<int> output_shape(x_shape.size()); | |||||
| output_shape.assign(x_shape.begin(), x_shape.end()); | |||||
| out->set_shape(output_shape); | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,49 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <vector> | |||||
| #include <set> | |||||
| #include <cmath> | |||||
| #include "src/ops/primitive_c.h" | |||||
| #ifndef LITE_SRC_OPS_BINARYCROSSENTROPY_H_ | |||||
| #define LITE_SRC_OPS_BINARYCROSSENTROPY_H_ | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| class BinaryCrossEntropy : public PrimitiveC { | |||||
| public: | |||||
| #ifdef PRIMITIVE_WRITEABLE | |||||
| MS_DECLARE_PARENT(BinaryCrossEntropy, PrimitiveC); | |||||
| BinaryCrossEntropy() = default; | |||||
| explicit BinaryCrossEntropy(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | |||||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | |||||
| int GetReduction() const; | |||||
| #else | |||||
| BinaryCrossEntropy() = default; | |||||
| int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | |||||
| int GetReduction() const; | |||||
| #endif | |||||
| int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override; | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // LITE_SRC_OPS_BINARYCROSSENTROPY_H_ | |||||
| @@ -0,0 +1,108 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <string> | |||||
| #include "src/ops/binary_cross_entropy_grad.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| #ifdef PRIMITIVE_WRITEABLE | |||||
| int BinaryCrossEntropyGrad::GetReduction() const { | |||||
| return this->primitive_->value.AsBinaryCrossEntropyGrad()->reduction; | |||||
| } | |||||
| int BinaryCrossEntropyGrad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | |||||
| if (this->primitive_ == nullptr) { | |||||
| this->primitive_ = new (std::nothrow) schema::PrimitiveT; | |||||
| if (this->primitive_ == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitive error"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| this->primitive_->value.type = schema::PrimitiveType_BinaryCrossEntropyGrad; | |||||
| } | |||||
| if (this->primitive_->value.type != schema::PrimitiveType_BinaryCrossEntropyGrad) { | |||||
| MS_LOG(ERROR) << "PrimitiveType_BinaryCrossEntropyGrad primitive value type : " | |||||
| << schema::EnumNamePrimitiveType(primitive_->value.type) << "is not equal" | |||||
| << schema::EnumNamePrimitiveType(schema::PrimitiveType_BinaryCrossEntropyGrad); | |||||
| delete this->primitive_; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (this->primitive_->value.value == nullptr) { | |||||
| schema::BinaryCrossEntropyGradT *attr = new (std::nothrow) schema::BinaryCrossEntropyGradT(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new binary cross entropy attr failed!"; | |||||
| delete this->primitive_; | |||||
| return RET_ERROR; | |||||
| } | |||||
| // default is mean | |||||
| string reduction = "mean"; | |||||
| if (prim.GetAttr("reduction") == nullptr) { | |||||
| MS_LOG(ERROR) << "get reduction failed!"; | |||||
| delete this->primitive_; | |||||
| delete attr; | |||||
| return RET_ERROR; | |||||
| } else { | |||||
| reduction = GetValue<string>(prim.GetAttr("reduction")); | |||||
| } | |||||
| if (reduction == "none") { | |||||
| attr->reduction = 0; | |||||
| } else if (reduction == "sum") { | |||||
| attr->reduction = 2; | |||||
| } else { | |||||
| // default is mean | |||||
| attr->reduction = 1; | |||||
| } | |||||
| this->primitive_->value.value = attr; | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| #else | |||||
| int BinaryCrossEntropyGrad::UnPackToFlatBuilder(const schema::Primitive *primitive, | |||||
| flatbuffers::FlatBufferBuilder *fbb) { | |||||
| MS_ASSERT(nullptr != primitive); | |||||
| MS_ASSERT(nullptr != fbb); | |||||
| auto attr = primitive->value_as_BinaryCrossEntropyGrad(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "value_as_BinaryCrossEntropyGrad return nullptr"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| int reduction = attr->reduction(); | |||||
| auto val_offset = schema::CreateBinaryCrossEntropyGrad(*fbb, reduction); | |||||
| auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_BinaryCrossEntropyGrad, val_offset.o); | |||||
| fbb->Finish(prim_offset); | |||||
| return RET_OK; | |||||
| } | |||||
| int BinaryCrossEntropyGrad::GetReduction() const { | |||||
| return this->primitive_->value_as_BinaryCrossEntropyGrad()->reduction(); | |||||
| } | |||||
| #endif | |||||
| int BinaryCrossEntropyGrad::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | |||||
| Tensor *x = inputs_[0]; | |||||
| Tensor *out = outputs_[0]; | |||||
| out->SetFormat(x->GetFormat()); | |||||
| out->set_data_type(x->data_type()); | |||||
| std::vector<int> x_shape = x->shape(); | |||||
| std::vector<int> output_shape(x_shape.size()); | |||||
| output_shape.assign(x_shape.begin(), x_shape.end()); | |||||
| out->set_shape(output_shape); | |||||
| return RET_OK; | |||||
| } | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,49 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <vector> | |||||
| #include <set> | |||||
| #include <cmath> | |||||
| #include "src/ops/primitive_c.h" | |||||
| #ifndef LITE_SRC_OPS_BINARY_CROSS_ENTROPY_GRAD_H_ | |||||
| #define LITE_SRC_OPS_BINARY_CROSS_ENTROPY_GRAD_H_ | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| class BinaryCrossEntropyGrad : public PrimitiveC { | |||||
| public: | |||||
| #ifdef PRIMITIVE_WRITEABLE | |||||
| MS_DECLARE_PARENT(BinaryCrossEntropyGrad, PrimitiveC); | |||||
| BinaryCrossEntropyGrad() = default; | |||||
| explicit BinaryCrossEntropyGrad(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | |||||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | |||||
| int GetReduction() const; | |||||
| #else | |||||
| BinaryCrossEntropyGrad() = default; | |||||
| int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | |||||
| int GetReduction() const; | |||||
| #endif | |||||
| int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override; | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // LITE_SRC_OPS_BINARY_CROSS_ENTROPY_GRAD_H_ | |||||
| @@ -0,0 +1,59 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "src/ops/control_depend.h" | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| #ifdef PRIMITIVE_WRITEABLE | |||||
| int ControlDepend::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | |||||
| if (this->primitive_ == nullptr) { | |||||
| this->primitive_ = new (std::nothrow) schema::PrimitiveT; | |||||
| if (this->primitive_ == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitiveT failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| this->primitive_->value.type = schema::PrimitiveType_ControlDepend; | |||||
| } | |||||
| if (this->primitive_->value.type != schema::PrimitiveType_ControlDepend) { | |||||
| MS_LOG(ERROR) << "primitive_ type is error:" << this->primitive_->value.type; | |||||
| delete this->primitive_; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (this->primitive_->value.value == nullptr) { | |||||
| auto attr = new (std::nothrow)(schema::ControlDependT); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "attr is nullptr"; | |||||
| delete this->primitive_; | |||||
| return RET_ERROR; | |||||
| } | |||||
| this->primitive_->value.value = attr; | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| #else | |||||
| int ControlDepend::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | |||||
| MS_ASSERT(nullptr != primitive); | |||||
| MS_ASSERT(nullptr != fbb); | |||||
| auto val_offset = schema::CreateControlDepend(*fbb); | |||||
| auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_ControlDepend, val_offset.o); | |||||
| fbb->Finish(prim_offset); | |||||
| return RET_OK; | |||||
| } | |||||
| #endif | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,40 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef LITE_MINDSPORE_LITE_SRC_OPS_CONTROL_DEPEND_H_ | |||||
| #define LITE_MINDSPORE_LITE_SRC_OPS_CONTROL_DEPEND_H_ | |||||
| #include <vector> | |||||
| #include "src/ops/primitive_c.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| class ControlDepend : public PrimitiveC { | |||||
| public: | |||||
| #ifdef PRIMITIVE_WRITEABLE | |||||
| MS_DECLARE_PARENT(ControlDepend, PrimitiveC); | |||||
| ControlDepend() = default; | |||||
| explicit ControlDepend(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | |||||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | |||||
| #else | |||||
| ControlDepend() = default; | |||||
| int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | |||||
| #endif | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // LITE_MINDSPORE_LITE_SRC_OPS_CONTROL_DEPEND_H_ | |||||
| @@ -27,6 +27,43 @@ int ExpandDims::GetDim() const { return this->primitive_->value.AsExpandDims()-> | |||||
| void ExpandDims::SetDim(int dim) { this->primitive_->value.AsExpandDims()->dim = dim; } | void ExpandDims::SetDim(int dim) { this->primitive_->value.AsExpandDims()->dim = dim; } | ||||
| int ExpandDims::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | |||||
| if (this->primitive_ == nullptr) { | |||||
| this->primitive_ = new (std::nothrow) schema::PrimitiveT; | |||||
| if (this->primitive_ == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitiveT failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| this->primitive_->value.type = schema::PrimitiveType_ExpandDims; | |||||
| } | |||||
| if (this->primitive_->value.type != schema::PrimitiveType_ExpandDims) { | |||||
| MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; | |||||
| delete this->primitive_; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (this->primitive_->value.value == nullptr) { | |||||
| auto attr = new (std::nothrow) schema::ExpandDimsT(); | |||||
| if (attr == nullptr) { | |||||
| delete this->primitive_; | |||||
| MS_LOG(ERROR) << "new primitiveT value failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| // use axis instead of dim | |||||
| if (inputs[1]->isa<ValueNode>()) { | |||||
| auto axis_tensor = inputs[1]->cast<ValueNodePtr>(); | |||||
| int axis = GetValue<int>(axis_tensor->value()); | |||||
| attr->dim = axis; | |||||
| } else { | |||||
| MS_LOG(ERROR) << "input axis is not value node."; | |||||
| delete this->primitive_; | |||||
| delete attr; | |||||
| return RET_ERROR; | |||||
| } | |||||
| this->primitive_->value.value = attr; | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| #else | #else | ||||
| int ExpandDims::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | int ExpandDims::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | ||||
| MS_ASSERT(nullptr != primitive); | MS_ASSERT(nullptr != primitive); | ||||
| @@ -56,9 +93,6 @@ int ExpandDims::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> | |||||
| MS_ASSERT(input != nullptr); | MS_ASSERT(input != nullptr); | ||||
| auto output = outputs_.front(); | auto output = outputs_.front(); | ||||
| MS_ASSERT(output != nullptr); | MS_ASSERT(output != nullptr); | ||||
| if (inputs_.size() != kSingleNum) { | |||||
| MS_LOG(ERROR) << "input size is invalid"; | |||||
| } | |||||
| if (outputs_.size() != kSingleNum) { | if (outputs_.size() != kSingleNum) { | ||||
| MS_LOG(ERROR) << "output size is invalid"; | MS_LOG(ERROR) << "output size is invalid"; | ||||
| } | } | ||||
| @@ -31,6 +31,7 @@ class ExpandDims : public PrimitiveC { | |||||
| ExpandDims() = default; | ExpandDims() = default; | ||||
| explicit ExpandDims(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | explicit ExpandDims(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | ||||
| void SetDim(int dim); | void SetDim(int dim); | ||||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | |||||
| #else | #else | ||||
| ExpandDims() = default; | ExpandDims() = default; | ||||
| @@ -31,7 +31,35 @@ int Gather::GetBatchDims() const { return this->primitive_->value.AsGather()->ba | |||||
| void Gather::SetAxis(int axis) { this->primitive_->value.AsGather()->axis = axis; } | void Gather::SetAxis(int axis) { this->primitive_->value.AsGather()->axis = axis; } | ||||
| void Gather::SetBatchDims(int batch_dims) { this->primitive_->value.AsGather()->batchDims = batch_dims; } | void Gather::SetBatchDims(int batch_dims) { this->primitive_->value.AsGather()->batchDims = batch_dims; } | ||||
| int Gather::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | |||||
| if (this->primitive_ == nullptr) { | |||||
| this->primitive_ = new (std::nothrow) schema::PrimitiveT; | |||||
| if (this->primitive_ == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitive error"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| this->primitive_->value.type = schema::PrimitiveType_Gather; | |||||
| } | |||||
| if (this->primitive_->value.type != schema::PrimitiveType_Gather) { | |||||
| MS_LOG(ERROR) << "Gather primitive value type : " << schema::EnumNamePrimitiveType(primitive_->value.type) | |||||
| << "is not equal" << schema::EnumNamePrimitiveType(schema::PrimitiveType_Gather); | |||||
| delete this->primitive_; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (this->primitive_->value.value == nullptr) { | |||||
| auto gather_attr = new (std::nothrow) schema::GatherT(); | |||||
| if (gather_attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitive value.value error"; | |||||
| delete this->primitive_; | |||||
| delete gather_attr; | |||||
| return RET_ERROR; | |||||
| } | |||||
| gather_attr->axis = GetValue<int>(prim.GetAttr("axis")); | |||||
| gather_attr->batchDims = GetValue<int>(prim.GetAttr("batchDims")); | |||||
| this->primitive_->value.value = gather_attr; | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| #else | #else | ||||
| int Gather::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | int Gather::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | ||||
| MS_ASSERT(nullptr != primitive); | MS_ASSERT(nullptr != primitive); | ||||
| @@ -33,6 +33,7 @@ class Gather : public PrimitiveC { | |||||
| explicit Gather(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | explicit Gather(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | ||||
| void SetAxis(int axis); | void SetAxis(int axis); | ||||
| void SetBatchDims(int batch_dims); | void SetBatchDims(int batch_dims); | ||||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | |||||
| #else | #else | ||||
| Gather() = default; | Gather() = default; | ||||
| @@ -0,0 +1,75 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "src/ops/oneslike.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| #ifdef PRIMITIVE_WRITEABLE | |||||
| int OnesLike::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | |||||
| if (this->primitive_ == nullptr) { | |||||
| this->primitive_ = new (std::nothrow) schema::PrimitiveT; | |||||
| if (this->primitive_ == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitive error"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| this->primitive_->value.type = schema::PrimitiveType_OnesLike; | |||||
| } | |||||
| if (this->primitive_->value.type != schema::PrimitiveType_OnesLike) { | |||||
| MS_LOG(ERROR) << "PrimitiveType_OnesLike primitive value type : " | |||||
| << schema::EnumNamePrimitiveType(primitive_->value.type) << "is not equal" | |||||
| << schema::EnumNamePrimitiveType(schema::PrimitiveType_OnesLike); | |||||
| delete this->primitive_; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (this->primitive_->value.value == nullptr) { | |||||
| this->primitive_->value.value = new (std::nothrow) schema::OnesLikeT(); | |||||
| if (this->primitive_->value.value == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitiveT value failed"; | |||||
| delete this->primitive_; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| #else | |||||
| int OnesLike::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | |||||
| MS_ASSERT(nullptr != primitive); | |||||
| MS_ASSERT(nullptr != fbb); | |||||
| auto attr = primitive->value_as_OnesLike(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "value_as_OnesLike return nullptr"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto val_offset = schema::CreateOnesLike(*fbb); | |||||
| auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_OnesLike, val_offset.o); | |||||
| fbb->Finish(prim_offset); | |||||
| return RET_OK; | |||||
| } | |||||
| #endif | |||||
| int OnesLike::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | |||||
| Tensor *x = inputs_[0]; | |||||
| Tensor *out = outputs_[0]; | |||||
| std::vector<int> x_shape = x->shape(); | |||||
| std::vector<int> output_shape(x_shape.size()); | |||||
| output_shape.assign(x_shape.begin(), x_shape.end()); | |||||
| out->set_shape(output_shape); | |||||
| out->SetFormat(x->GetFormat()); | |||||
| out->set_data_type(x->data_type()); | |||||
| return RET_OK; | |||||
| } | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,42 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <vector> | |||||
| #include <set> | |||||
| #include <cmath> | |||||
| #include "src/ops/primitive_c.h" | |||||
| #ifndef LITE_SRC_OPS_ONESLIKE_H_ | |||||
| #define LITE_SRC_OPS_ONESLIKE_H_ | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| class OnesLike : public PrimitiveC { | |||||
| public: | |||||
| #ifdef PRIMITIVE_WRITEABLE | |||||
| MS_DECLARE_PARENT(OnesLike, PrimitiveC); | |||||
| OnesLike() = default; | |||||
| explicit OnesLike(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | |||||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | |||||
| #else | |||||
| OnesLike() = default; | |||||
| int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | |||||
| #endif | |||||
| int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override; | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // LITE_SRC_OPS_ONESLIKE_H_ | |||||
| @@ -31,6 +31,51 @@ void Power::SetPower(float power) { this->primitive_->value.AsPower()->power = p | |||||
| void Power::SetScale(float scale) { this->primitive_->value.AsPower()->scale = scale; } | void Power::SetScale(float scale) { this->primitive_->value.AsPower()->scale = scale; } | ||||
| void Power::SetShift(float shift) { this->primitive_->value.AsPower()->shift = shift; } | void Power::SetShift(float shift) { this->primitive_->value.AsPower()->shift = shift; } | ||||
| int Power::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | |||||
| if (this->primitive_ == nullptr) { | |||||
| this->primitive_ = new (std::nothrow) schema::PrimitiveT; | |||||
| if (this->primitive_ == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitiveT failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| this->primitive_->value.type = schema::PrimitiveType_Power; | |||||
| } | |||||
| if (this->primitive_->value.type != schema::PrimitiveType_Power) { | |||||
| MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; | |||||
| delete this->primitive_; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (this->primitive_->value.value == nullptr) { | |||||
| auto attr = new (std::nothrow) schema::PowerT(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitiveT value failed"; | |||||
| delete this->primitive_; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (prim.GetAttr("scale") == nullptr) { | |||||
| MS_LOG(WARNING) << "get scale failed"; | |||||
| attr->scale = 1.0f; | |||||
| } else { | |||||
| attr->scale = GetValue<float>(prim.GetAttr("scale")); | |||||
| } | |||||
| if (prim.GetAttr("power") == nullptr) { | |||||
| MS_LOG(WARNING) << "get power failed"; | |||||
| attr->power = 1.0f; | |||||
| } else { | |||||
| attr->power = GetValue<float>(prim.GetAttr("power")); | |||||
| } | |||||
| if (prim.GetAttr("shift") == nullptr) { | |||||
| MS_LOG(WARNING) << "get shift failed"; | |||||
| attr->shift = 0; | |||||
| } else { | |||||
| attr->shift = GetValue<float>(prim.GetAttr("shift")); | |||||
| } | |||||
| this->primitive_->value.value = attr; | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| #else | #else | ||||
| float Power::GetPower() const { return this->primitive_->value_as_Power()->power(); } | float Power::GetPower() const { return this->primitive_->value_as_Power()->power(); } | ||||
| @@ -34,6 +34,7 @@ class Power : public PrimitiveC { | |||||
| void SetPower(float power); | void SetPower(float power); | ||||
| void SetScale(float scale); | void SetScale(float scale); | ||||
| void SetShift(float shift); | void SetShift(float shift); | ||||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | |||||
| #else | #else | ||||
| Power() = default; | Power() = default; | ||||
| @@ -145,6 +145,8 @@ | |||||
| #include "src/ops/identity.h" | #include "src/ops/identity.h" | ||||
| #include "src/ops/instance_norm.h" | #include "src/ops/instance_norm.h" | ||||
| #include "src/ops/while.h" | #include "src/ops/while.h" | ||||
| #include "src/ops/oneslike.h" | |||||
| #include "src/ops/unsorted_segment_sum.h" | |||||
| #ifdef SUPPORT_TRAIN | #ifdef SUPPORT_TRAIN | ||||
| #include "src/ops/neg_grad.h" | #include "src/ops/neg_grad.h" | ||||
| @@ -165,6 +167,10 @@ | |||||
| #include "src/ops/sgd.h" | #include "src/ops/sgd.h" | ||||
| #include "src/ops/adam.h" | #include "src/ops/adam.h" | ||||
| #include "src/ops/assign.h" | #include "src/ops/assign.h" | ||||
| #include "src/ops/control_depend.h" | |||||
| #include "src/ops/assign_add.h" | |||||
| #include "src/ops/binary_cross_entropy.h" | |||||
| #include "src/ops/binary_cross_entropy_grad.h" | |||||
| #endif | #endif | ||||
| #endif | #endif | ||||
| @@ -504,6 +510,18 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std: | |||||
| return NewPrimitiveC<While>(prim, inputs, quantType); | return NewPrimitiveC<While>(prim, inputs, quantType); | ||||
| } else if (op_type == "OneHot") { | } else if (op_type == "OneHot") { | ||||
| return NewPrimitiveC<OneHot>(prim, inputs, quantType); | return NewPrimitiveC<OneHot>(prim, inputs, quantType); | ||||
| } else if (op_type == "GatherV2") { | |||||
| return NewPrimitiveC<Gather>(prim, inputs, quantType); | |||||
| } else if (op_type == "OnesLike") { | |||||
| return NewPrimitiveC<OnesLike>(prim, inputs, quantType); | |||||
| } else if (op_type == "Pow") { | |||||
| return NewPrimitiveC<Power>(prim, inputs, quantType); | |||||
| } else if (op_type == "Sub") { | |||||
| return NewPrimitiveC<Sub>(prim, inputs, quantType); | |||||
| } else if (op_type == "ExpandDims") { | |||||
| return NewPrimitiveC<ExpandDims>(prim, inputs, quantType); | |||||
| } else if (op_type == "UnsortedSegmentSum") { | |||||
| return NewPrimitiveC<UnsortedSegmentSum>(prim, inputs, quantType); | |||||
| #ifdef SUPPORT_TRAIN | #ifdef SUPPORT_TRAIN | ||||
| } else if (op_type == "SoftmaxCrossEntropyWithLogits") { | } else if (op_type == "SoftmaxCrossEntropyWithLogits") { | ||||
| @@ -514,6 +532,8 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std: | |||||
| return NewPrimitiveC<ApplyMomentum>(prim, inputs, quantType); | return NewPrimitiveC<ApplyMomentum>(prim, inputs, quantType); | ||||
| } else if (op_type == "Depend") { | } else if (op_type == "Depend") { | ||||
| return NewPrimitiveC<Depend>(prim, inputs, quantType); | return NewPrimitiveC<Depend>(prim, inputs, quantType); | ||||
| } else if (op_type == "ControlDepend") { | |||||
| return NewPrimitiveC<ControlDepend>(prim, inputs, quantType); | |||||
| } else if ((op_type == "ReluGrad" || op_type == "ReLU6Grad" || op_type == "SigmoidGrad" || | } else if ((op_type == "ReluGrad" || op_type == "ReLU6Grad" || op_type == "SigmoidGrad" || | ||||
| op_type == "HSigmoidGrad" || op_type == "HSwishGrad")) { | op_type == "HSigmoidGrad" || op_type == "HSwishGrad")) { | ||||
| return NewPrimitiveC<ActivationGrad>(prim, inputs, quantType); | return NewPrimitiveC<ActivationGrad>(prim, inputs, quantType); | ||||
| @@ -539,6 +559,12 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std: | |||||
| return NewPrimitiveC<Adam>(prim, inputs, quantType); | return NewPrimitiveC<Adam>(prim, inputs, quantType); | ||||
| } else if (op_type == "Assign") { | } else if (op_type == "Assign") { | ||||
| return NewPrimitiveC<Assign>(prim, inputs, quantType); | return NewPrimitiveC<Assign>(prim, inputs, quantType); | ||||
| } else if (op_type == "AssignAdd") { | |||||
| return NewPrimitiveC<AssignAdd>(prim, inputs, quantType); | |||||
| } else if (op_type == "BinaryCrossEntropy") { | |||||
| return NewPrimitiveC<BinaryCrossEntropy>(prim, inputs, quantType); | |||||
| } else if (op_type == "BinaryCrossEntropyGrad") { | |||||
| return NewPrimitiveC<BinaryCrossEntropyGrad>(prim, inputs, quantType); | |||||
| #else | #else | ||||
| } else if (op_type == "Conv2DBackpropInput") { | } else if (op_type == "Conv2DBackpropInput") { | ||||
| return NewPrimitiveC<DeConv2D>(prim, inputs, quantType); | return NewPrimitiveC<DeConv2D>(prim, inputs, quantType); | ||||
| @@ -830,6 +856,8 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) { | |||||
| return new PowerGrad(primitive); | return new PowerGrad(primitive); | ||||
| case schema::PrimitiveType_Depend: | case schema::PrimitiveType_Depend: | ||||
| return new Depend(primitive); | return new Depend(primitive); | ||||
| case schema::PrimitiveType_ControlDepend: | |||||
| return new ControlDepend(primitive); | |||||
| case schema::PrimitiveType_FlattenGrad: | case schema::PrimitiveType_FlattenGrad: | ||||
| return new FlattenGrad(primitive); | return new FlattenGrad(primitive); | ||||
| case schema::PrimitiveType_NegGrad: | case schema::PrimitiveType_NegGrad: | ||||
| @@ -842,6 +870,17 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) { | |||||
| return new Adam(primitive); | return new Adam(primitive); | ||||
| case schema::PrimitiveType_Assign: | case schema::PrimitiveType_Assign: | ||||
| return new Assign(primitive); | return new Assign(primitive); | ||||
| case schema::PrimitiveType_AssignAdd: | |||||
| return new AssignAdd(primitive); | |||||
| case schema::PrimitiveType_OnesLike: | |||||
| return new OnesLike(primitive); | |||||
| case schema::PrimitiveType_UnsortedSegmentSum: | |||||
| return new UnsortedSegmentSum(primitive); | |||||
| case schema::PrimitiveType_BinaryCrossEntropyGrad: | |||||
| return new BinaryCrossEntropyGrad(primitive); | |||||
| case schema::PrimitiveType_BinaryCrossEntropy: | |||||
| return new BinaryCrossEntropy(primitive); | |||||
| #endif | #endif | ||||
| default: | default: | ||||
| MS_LOG(ERROR) << "Unsupported primitive type in Create : " << schema::EnumNamePrimitiveType(op_type); | MS_LOG(ERROR) << "Unsupported primitive type in Create : " << schema::EnumNamePrimitiveType(op_type); | ||||
| @@ -86,6 +86,9 @@ int Reduce::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inp | |||||
| MS_ASSERT(elem != nullptr); | MS_ASSERT(elem != nullptr); | ||||
| attr->axes.emplace_back(elem->value()); | attr->axes.emplace_back(elem->value()); | ||||
| } | } | ||||
| } else { | |||||
| int axes_item = GetValue<int>(value); | |||||
| attr->axes.push_back(axes_item); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -62,6 +62,9 @@ int Reshape::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &in | |||||
| MS_ASSERT(elem != nullptr); | MS_ASSERT(elem != nullptr); | ||||
| attr->shape.emplace_back(static_cast<int>(elem->value())); | attr->shape.emplace_back(static_cast<int>(elem->value())); | ||||
| } | } | ||||
| } else { | |||||
| int dim = GetValue<int>(val); | |||||
| attr->shape = {dim}; | |||||
| } | } | ||||
| } | } | ||||
| if (attr == nullptr) { | if (attr == nullptr) { | ||||
| @@ -46,12 +46,14 @@ int Squeeze::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &in | |||||
| MS_LOG(ERROR) << "new primitiveT value failed"; | MS_LOG(ERROR) << "new primitiveT value failed"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| attr->axis = GetValue<std::vector<int>>(prim.GetAttr("axis")); | |||||
| this->primitive_->value.value = attr; | |||||
| if (this->primitive_->value.value == nullptr) { | |||||
| MS_LOG(ERROR) << "primitive value is nullptr"; | |||||
| return RET_ERROR; | |||||
| if (prim.GetAttr("axis") == nullptr) { | |||||
| MS_LOG(WARNING) << "get axis failed"; | |||||
| attr->axis = {0}; | |||||
| } else { | |||||
| int axis = GetValue<int>(prim.GetAttr("axis")); | |||||
| attr->axis = {axis}; | |||||
| } | } | ||||
| this->primitive_->value.value = attr; | |||||
| } | } | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -29,6 +29,34 @@ void Sub::SetActivationType(int activation_type) { | |||||
| this->primitive_->value.AsSub()->activationType = (schema::ActivationType)activation_type; | this->primitive_->value.AsSub()->activationType = (schema::ActivationType)activation_type; | ||||
| } | } | ||||
| int Sub::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | |||||
| if (this->primitive_ == nullptr) { | |||||
| this->primitive_ = new (std::nothrow) schema::PrimitiveT; | |||||
| if (this->primitive_ == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitiveT failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| this->primitive_->value.type = schema::PrimitiveType_Sub; | |||||
| } | |||||
| if (this->primitive_->value.type != schema::PrimitiveType_Sub) { | |||||
| MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; | |||||
| delete this->primitive_; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (this->primitive_->value.value == nullptr) { | |||||
| auto attr = new (std::nothrow) schema::SubT(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitiveT value failed"; | |||||
| delete this->primitive_; | |||||
| return RET_ERROR; | |||||
| } | |||||
| // todo: confirm the activationType | |||||
| attr->activationType = schema::ActivationType_NO_ACTIVATION; | |||||
| this->primitive_->value.value = attr; | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| #else | #else | ||||
| int Sub::GetActivationType() const { return this->primitive_->value_as_Sub()->activationType(); } | int Sub::GetActivationType() const { return this->primitive_->value_as_Sub()->activationType(); } | ||||
| @@ -32,6 +32,7 @@ class Sub : public Arithmetic { | |||||
| Sub() = default; | Sub() = default; | ||||
| explicit Sub(schema::PrimitiveT *primitive) : Arithmetic(primitive) {} | explicit Sub(schema::PrimitiveT *primitive) : Arithmetic(primitive) {} | ||||
| void SetActivationType(int activation_type); | void SetActivationType(int activation_type); | ||||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | |||||
| #else | #else | ||||
| Sub() = default; | Sub() = default; | ||||
| @@ -52,6 +52,12 @@ int Tile::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &input | |||||
| MS_LOG(ERROR) << "new primitiveT value failed"; | MS_LOG(ERROR) << "new primitiveT value failed"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| if (prim.GetAttr("dims") == nullptr) { | |||||
| MS_LOG(WARNING) << "get dims failed"; | |||||
| attr->dims = {1}; | |||||
| } else { | |||||
| attr->dims = GetValue<std::vector<int>>(prim.GetAttr("dims")); | |||||
| } | |||||
| if (inputs.size() == kAnfPopulaterTwo) { | if (inputs.size() == kAnfPopulaterTwo) { | ||||
| auto inputNode = inputs[kAnfPopulaterOne]; | auto inputNode = inputs[kAnfPopulaterOne]; | ||||
| MS_ASSERT(inputNode != nullptr); | MS_ASSERT(inputNode != nullptr); | ||||
| @@ -68,6 +74,9 @@ int Tile::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &input | |||||
| MS_ASSERT(elem != nullptr); | MS_ASSERT(elem != nullptr); | ||||
| attr->multiples.emplace_back(elem->value()); | attr->multiples.emplace_back(elem->value()); | ||||
| } | } | ||||
| } else { | |||||
| int multiple = GetValue<int>(value); | |||||
| attr->multiples = {multiple}; | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -0,0 +1,100 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <memory> | |||||
| #include "src/ops/unsorted_segment_sum.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| #ifdef PRIMITIVE_WRITEABLE | |||||
| int UnsortedSegmentSum::GetNumSegments() const { return this->primitive_->value.AsUnsortedSegmentSum()->numSegments; } | |||||
| int UnsortedSegmentSum::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | |||||
| if (this->primitive_ == nullptr) { | |||||
| this->primitive_ = new (std::nothrow) schema::PrimitiveT; | |||||
| if (this->primitive_ == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitive error"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| this->primitive_->value.type = schema::PrimitiveType_UnsortedSegmentSum; | |||||
| } | |||||
| if (this->primitive_->value.type != schema::PrimitiveType_UnsortedSegmentSum) { | |||||
| MS_LOG(ERROR) << "UnSortedSegmentSum primitive value type : " | |||||
| << schema::EnumNamePrimitiveType(primitive_->value.type) << "is not equal" | |||||
| << schema::EnumNamePrimitiveType(schema::PrimitiveType_UnsortedSegmentSum); | |||||
| delete this->primitive_; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (this->primitive_->value.value == nullptr) { | |||||
| std::unique_ptr<schema::UnsortedSegmentSumT> attr = std::make_unique<schema::UnsortedSegmentSumT>(); | |||||
| if (inputs[2]->isa<ValueNode>()) { | |||||
| ValuePtr value = inputs[2]->cast<ValueNodePtr>()->value(); | |||||
| attr->numSegments = GetValue<int>(value); | |||||
| this->primitive_->value.value = attr.release(); | |||||
| } | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| #else | |||||
| int UnsortedSegmentSum::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | |||||
| MS_ASSERT(nullptr != primitive); | |||||
| MS_ASSERT(nullptr != fbb); | |||||
| auto attr = primitive->value_as_UnsortedSegmentSum(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "value_as_UnsortedSegmentSum return nullptr"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| int num_segments = attr->numSegments(); | |||||
| auto val_offset = schema::CreateUnsortedSegmentSum(*fbb, num_segments); | |||||
| auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_UnsortedSegmentSum, val_offset.o); | |||||
| fbb->Finish(prim_offset); | |||||
| return RET_OK; | |||||
| } | |||||
| int UnsortedSegmentSum::GetNumSegments() const { | |||||
| int ret = this->primitive_->value_as_UnsortedSegmentSum()->numSegments(); | |||||
| return ret; | |||||
| } | |||||
| #endif | |||||
| int UnsortedSegmentSum::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | |||||
| // check inputs and outputs | |||||
| if (inputs_.size() != 3) { | |||||
| MS_LOG(ERROR) << "invalid inputs numbers"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (outputs_.size() != 1) { | |||||
| MS_LOG(ERROR) << "invalid outputs numbers"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| Tensor *out = outputs_.front(); | |||||
| Tensor *x = inputs_.front(); | |||||
| Tensor *segment_id = inputs_[1]; | |||||
| std::vector<int> x_shape = x->shape(); | |||||
| std::vector<int> segment_id_shape = segment_id->shape(); | |||||
| int num_segments = GetNumSegments(); | |||||
| std::vector<int> output_shape; | |||||
| output_shape.push_back(num_segments); | |||||
| for (int index = segment_id_shape.size(); index < static_cast<int>(x_shape.size()); index++) { | |||||
| output_shape.push_back(x_shape[index]); | |||||
| } | |||||
| out->set_shape(output_shape); | |||||
| out->SetFormat(x->GetFormat()); | |||||
| out->set_data_type(x->data_type()); | |||||
| return RET_OK; | |||||
| } | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,44 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <vector> | |||||
| #include <set> | |||||
| #include <cmath> | |||||
| #include "src/ops/primitive_c.h" | |||||
| #ifndef LITE_SRC_OPS_UNSORTED_SEGMENT_SUM_H_ | |||||
| #define LITE_SRC_OPS_UNSORTED_SEGMENT_SUM_H_ | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| class UnsortedSegmentSum : public PrimitiveC { | |||||
| public: | |||||
| #ifdef PRIMITIVE_WRITEABLE | |||||
| MS_DECLARE_PARENT(UnsortedSegmentSum, PrimitiveC); | |||||
| UnsortedSegmentSum() = default; | |||||
| explicit UnsortedSegmentSum(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | |||||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | |||||
| int GetNumSegments() const; | |||||
| #else | |||||
| UnsortedSegmentSum() = default; | |||||
| int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | |||||
| int GetNumSegments() const; | |||||
| #endif | |||||
| int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override; | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // LITE_SRC_OPS_UNSORTED_SEGMENT_SUM_H_ | |||||
| @@ -36,6 +36,9 @@ | |||||
| #include "src/ops/bn_grad.h" | #include "src/ops/bn_grad.h" | ||||
| #include "nnacl/fp32_grad/batch_norm.h" | #include "nnacl/fp32_grad/batch_norm.h" | ||||
| #include "src/ops/adam.h" | #include "src/ops/adam.h" | ||||
| #include "src/ops/oneslike.h" | |||||
| #include "src/ops/binary_cross_entropy.h" | |||||
| #include "src/ops/binary_cross_entropy_grad.h" | |||||
| namespace mindspore::kernel { | namespace mindspore::kernel { | ||||
| @@ -76,6 +79,30 @@ OpParameter *PopulateApplyMomentumParameter(const mindspore::lite::PrimitiveC *p | |||||
| return reinterpret_cast<OpParameter *>(p); | return reinterpret_cast<OpParameter *>(p); | ||||
| } | } | ||||
| OpParameter *PopulateBCEParameter(const mindspore::lite::PrimitiveC *primitive) { | |||||
| int32_t *reduction = reinterpret_cast<int32_t *>(malloc(sizeof(int32_t))); | |||||
| if (reduction == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc reduction failed."; | |||||
| return nullptr; | |||||
| } | |||||
| auto param = | |||||
| reinterpret_cast<mindspore::lite::BinaryCrossEntropy *>(const_cast<mindspore::lite::PrimitiveC *>(primitive)); | |||||
| *reduction = param->GetReduction(); | |||||
| return reinterpret_cast<OpParameter *>(reduction); | |||||
| } | |||||
| OpParameter *PopulateBCEGradParameter(const mindspore::lite::PrimitiveC *primitive) { | |||||
| int32_t *reduction = reinterpret_cast<int32_t *>(malloc(sizeof(int32_t))); | |||||
| if (reduction == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc reduction failed."; | |||||
| return nullptr; | |||||
| } | |||||
| auto param = | |||||
| reinterpret_cast<mindspore::lite::BinaryCrossEntropyGrad *>(const_cast<mindspore::lite::PrimitiveC *>(primitive)); | |||||
| *reduction = param->GetReduction(); | |||||
| return reinterpret_cast<OpParameter *>(reduction); | |||||
| } | |||||
| OpParameter *PopulateAdamParameter(const mindspore::lite::PrimitiveC *primitive) { | OpParameter *PopulateAdamParameter(const mindspore::lite::PrimitiveC *primitive) { | ||||
| if (primitive == nullptr) { | if (primitive == nullptr) { | ||||
| MS_LOG(ERROR) << "Primitive is nullptr when populating parameter for op."; | MS_LOG(ERROR) << "Primitive is nullptr when populating parameter for op."; | ||||
| @@ -396,6 +423,13 @@ void PopulateTrainParameters() { | |||||
| lite::Registry BNGradParameterRegistry(schema::PrimitiveType_BNGrad, PopulateBNGradParameter); | lite::Registry BNGradParameterRegistry(schema::PrimitiveType_BNGrad, PopulateBNGradParameter); | ||||
| lite::Registry AdamParameterRegistry(schema::PrimitiveType_Adam, PopulateAdamParameter); | lite::Registry AdamParameterRegistry(schema::PrimitiveType_Adam, PopulateAdamParameter); | ||||
| lite::Registry AssignParameterRegistry(schema::PrimitiveType_Assign, DefaultPopulateParameter); | lite::Registry AssignParameterRegistry(schema::PrimitiveType_Assign, DefaultPopulateParameter); | ||||
| lite::Registry AssignAddParameterRegistry(schema::PrimitiveType_AssignAdd, DefaultPopulateParameter); | |||||
| lite::Registry BinaryCrossEntropyParameterRegistry(schema::PrimitiveType_BinaryCrossEntropy, PopulateBCEParameter); | |||||
| lite::Registry BinaryCrossEntropyGradParameterRegistry(schema::PrimitiveType_BinaryCrossEntropyGrad, | |||||
| PopulateBCEGradParameter); | |||||
| lite::Registry OnesLikeParameterRegistry(schema::PrimitiveType_OnesLike, DefaultPopulateParameter); | |||||
| lite::Registry UnsortedSegmentSumParameterRegistry(schema::PrimitiveType_UnsortedSegmentSum, | |||||
| DefaultPopulateParameter); | |||||
| } | } | ||||
| } // namespace mindspore::kernel | } // namespace mindspore::kernel | ||||
| @@ -69,7 +69,8 @@ void AnfExporter::RemoveIfDepend(const CNodePtr &cnode) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| auto dependNode = utils::cast<CNodePtr>(inputNode); | auto dependNode = utils::cast<CNodePtr>(inputNode); | ||||
| if (IsPrimitiveCNode(dependNode, schema::PrimitiveType_Depend)) { | |||||
| if (IsPrimitiveCNode(dependNode, schema::PrimitiveType_Depend) || | |||||
| IsPrimitiveCNode(dependNode, schema::PrimitiveType_ControlDepend)) { | |||||
| hasDepend = true; | hasDepend = true; | ||||
| for (size_t j = 1; j < dependNode->inputs().size(); ++j) { | for (size_t j = 1; j < dependNode->inputs().size(); ++j) { | ||||
| AnfNodePtr dependInputNode = dependNode->input(j); | AnfNodePtr dependInputNode = dependNode->input(j); | ||||
| @@ -209,6 +210,7 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph, bool kee | |||||
| if ((primitive_c->Type() == schema::PrimitiveType_TupleGetItem) || | if ((primitive_c->Type() == schema::PrimitiveType_TupleGetItem) || | ||||
| #ifdef SUPPORT_TRAIN | #ifdef SUPPORT_TRAIN | ||||
| (primitive_c->Type() == schema::PrimitiveType_Depend) || | (primitive_c->Type() == schema::PrimitiveType_Depend) || | ||||
| (primitive_c->Type() == schema::PrimitiveType_ControlDepend) || | |||||
| #endif | #endif | ||||
| (primitive_c->Type() == schema::PrimitiveType_MakeTuple)) { | (primitive_c->Type() == schema::PrimitiveType_MakeTuple)) { | ||||
| continue; | continue; | ||||
| @@ -402,7 +404,9 @@ int AnfExporter::ConvertInputValueNode(std::shared_ptr<AnfNode> input_anode, | |||||
| paramTensor->dims = {1}; | paramTensor->dims = {1}; | ||||
| paramTensor->nodeType = schema::NodeType::NodeType_ValueNode; | paramTensor->nodeType = schema::NodeType::NodeType_ValueNode; | ||||
| auto data = value->cast<mindspore::Int32ImmPtr>(); | auto data = value->cast<mindspore::Int32ImmPtr>(); | ||||
| paramTensor->data.emplace_back(data->value()); | |||||
| int real_data = GetValue<int32_t>(data); | |||||
| paramTensor->data.resize(sizeof(int32_t)); | |||||
| memcpy(paramTensor->data.data(), &real_data, sizeof(int32_t)); | |||||
| node_id_map_[valueNode->fullname_with_scope()] = meta_graphT->allTensors.size(); | node_id_map_[valueNode->fullname_with_scope()] = meta_graphT->allTensors.size(); | ||||
| output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size()); | output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size()); | ||||
| meta_graphT->allTensors.emplace_back(std::move(paramTensor)); | meta_graphT->allTensors.emplace_back(std::move(paramTensor)); | ||||
| @@ -418,6 +422,14 @@ int AnfExporter::ConvertInputValueNode(std::shared_ptr<AnfNode> input_anode, | |||||
| node_id_map_[valueNode->fullname_with_scope()] = meta_graphT->allTensors.size(); | node_id_map_[valueNode->fullname_with_scope()] = meta_graphT->allTensors.size(); | ||||
| output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size()); | output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size()); | ||||
| meta_graphT->allTensors.emplace_back(std::move(paramTensor)); | meta_graphT->allTensors.emplace_back(std::move(paramTensor)); | ||||
| } else if (value->isa<mindspore::Int>()) { | |||||
| paramTensor->dataType = kNumberTypeInt32; | |||||
| paramTensor->dims = {1}; | |||||
| paramTensor->nodeType = schema::NodeType_ValueNode; | |||||
| paramTensor->data.emplace_back(kNumberTypeInt32); | |||||
| node_id_map_[valueNode->fullname_with_scope()] = meta_graphT->allTensors.size(); | |||||
| output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size()); | |||||
| meta_graphT->allTensors.emplace_back(std::move(paramTensor)); | |||||
| } else if (value->isa<mindspore::ValueSequeue>()) { | } else if (value->isa<mindspore::ValueSequeue>()) { | ||||
| #ifndef SUPPORT_TRAIN | #ifndef SUPPORT_TRAIN | ||||
| MS_LOG(DEBUG) << "Value type is ValueSequence."; | MS_LOG(DEBUG) << "Value type is ValueSequence."; | ||||
| @@ -456,6 +468,18 @@ int AnfExporter::ConvertInputValueNode(std::shared_ptr<AnfNode> input_anode, | |||||
| MS_LOG(ERROR) << "Value type is ValueSequence not supported - " << valueAbstract->type_name() << "."; | MS_LOG(ERROR) << "Value type is ValueSequence not supported - " << valueAbstract->type_name() << "."; | ||||
| } | } | ||||
| #endif | #endif | ||||
| } else if (value->isa<mindspore::BoolImm>()) { | |||||
| auto valueAbstract = valueNode->abstract(); | |||||
| auto abstractScalar = utils::cast<abstract::AbstractScalarPtr>(valueAbstract); | |||||
| auto typePtr = abstractScalar->GetTypeTrack(); | |||||
| paramTensor->dataType = typePtr->type_id(); | |||||
| paramTensor->dims = {1}; | |||||
| paramTensor->nodeType = schema::NodeType_ValueNode; | |||||
| auto data = value->cast<mindspore::BoolImmPtr>(); | |||||
| paramTensor->data.emplace_back(data->value()); | |||||
| node_id_map_[valueNode->fullname_with_scope()] = meta_graphT->allTensors.size(); | |||||
| output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size()); | |||||
| meta_graphT->allTensors.emplace_back(std::move(paramTensor)); | |||||
| } else if (value->isa<Number>()) { | } else if (value->isa<Number>()) { | ||||
| MS_LOG(INFO) << "Value is a number."; | MS_LOG(INFO) << "Value is a number."; | ||||
| return RET_OK; | return RET_OK; | ||||