From f732c7352fe7477668ade9439b67e224eecdb0d9 Mon Sep 17 00:00:00 2001 From: yangjie159 Date: Tue, 27 Oct 2020 19:05:35 +0800 Subject: [PATCH] sync code to support train --- mindspore/lite/schema/model.fbs | 6 + mindspore/lite/schema/ops.fbs | 20 ++++ mindspore/lite/src/ops/adam.cc | 2 +- mindspore/lite/src/ops/assign_add.cc | 81 +++++++++++++ mindspore/lite/src/ops/assign_add.h | 41 +++++++ mindspore/lite/src/ops/bias_add.cc | 7 +- .../lite/src/ops/binary_cross_entropy.cc | 106 +++++++++++++++++ mindspore/lite/src/ops/binary_cross_entropy.h | 49 ++++++++ .../lite/src/ops/binary_cross_entropy_grad.cc | 108 ++++++++++++++++++ .../lite/src/ops/binary_cross_entropy_grad.h | 49 ++++++++ mindspore/lite/src/ops/control_depend.cc | 59 ++++++++++ mindspore/lite/src/ops/control_depend.h | 40 +++++++ mindspore/lite/src/ops/expand_dims.cc | 40 ++++++- mindspore/lite/src/ops/expand_dims.h | 1 + mindspore/lite/src/ops/gather.cc | 30 ++++- mindspore/lite/src/ops/gather.h | 1 + mindspore/lite/src/ops/oneslike.cc | 75 ++++++++++++ mindspore/lite/src/ops/oneslike.h | 42 +++++++ mindspore/lite/src/ops/power.cc | 45 ++++++++ mindspore/lite/src/ops/power.h | 1 + mindspore/lite/src/ops/primitive_c.cc | 39 +++++++ mindspore/lite/src/ops/reduce.cc | 3 + mindspore/lite/src/ops/reshape.cc | 3 + mindspore/lite/src/ops/squeeze.cc | 12 +- mindspore/lite/src/ops/sub.cc | 28 +++++ mindspore/lite/src/ops/sub.h | 1 + mindspore/lite/src/ops/tile.cc | 9 ++ .../lite/src/ops/unsorted_segment_sum.cc | 100 ++++++++++++++++ mindspore/lite/src/ops/unsorted_segment_sum.h | 44 +++++++ .../src/train/train_populate_parameter.cc | 34 ++++++ .../lite/tools/anf_exporter/anf_exporter.cc | 28 ++++- 31 files changed, 1091 insertions(+), 13 deletions(-) create mode 100644 mindspore/lite/src/ops/assign_add.cc create mode 100644 mindspore/lite/src/ops/assign_add.h create mode 100644 mindspore/lite/src/ops/binary_cross_entropy.cc create mode 100644 mindspore/lite/src/ops/binary_cross_entropy.h create mode 100644 mindspore/lite/src/ops/binary_cross_entropy_grad.cc create mode 100644 mindspore/lite/src/ops/binary_cross_entropy_grad.h create mode 100644 mindspore/lite/src/ops/control_depend.cc create mode 100644 mindspore/lite/src/ops/control_depend.h create mode 100644 mindspore/lite/src/ops/oneslike.cc create mode 100644 mindspore/lite/src/ops/oneslike.h create mode 100644 mindspore/lite/src/ops/unsorted_segment_sum.cc create mode 100644 mindspore/lite/src/ops/unsorted_segment_sum.h diff --git a/mindspore/lite/schema/model.fbs b/mindspore/lite/schema/model.fbs index 3c032268fe..98bc3f5e81 100644 --- a/mindspore/lite/schema/model.fbs +++ b/mindspore/lite/schema/model.fbs @@ -227,6 +227,12 @@ union PrimitiveType { Identity, LayerNorm, While, + ControlDepend, + UnsortedSegmentSum, + AssignAdd, + OnesLike, + BinaryCrossEntropyGrad, + BinaryCrossEntropy } enum QuantType: int { diff --git a/mindspore/lite/schema/ops.fbs b/mindspore/lite/schema/ops.fbs index d8bb3233b1..2e83136990 100644 --- a/mindspore/lite/schema/ops.fbs +++ b/mindspore/lite/schema/ops.fbs @@ -966,6 +966,8 @@ table Adam { table Assign { } +table AssignAdd { +} table Where{ condition: [bool]; @@ -1010,6 +1012,9 @@ table ToFormat { table Depend { } +table ControlDepend { +} + table Return { } @@ -1108,3 +1113,18 @@ table While { bodySubgraphIndex : int; } +table UnsortedSegmentSum { + numSegments : int; +} + +table OnesLike { + +} + +table BinaryCrossEntropy { + reduction : int = 1; +} + +table BinaryCrossEntropyGrad { + reduction : int = 1; +} diff --git a/mindspore/lite/src/ops/adam.cc b/mindspore/lite/src/ops/adam.cc index 63a140b6cf..33b09ae829 100644 --- a/mindspore/lite/src/ops/adam.cc +++ b/mindspore/lite/src/ops/adam.cc @@ -73,7 +73,7 @@ Registry AdamRegistry(schema::PrimitiveType_Adam, AdamCreator); int Adam::InferShape(std::vector inputs, std::vector outputs) { if (10 != inputs.size()) { - MS_LOG(ERROR) << "Adam should have at least 8 input tensors"; + MS_LOG(ERROR) << "Adam should have at 10 input tensors"; return RET_ERROR; } diff --git a/mindspore/lite/src/ops/assign_add.cc b/mindspore/lite/src/ops/assign_add.cc new file mode 100644 index 0000000000..882c86a176 --- /dev/null +++ b/mindspore/lite/src/ops/assign_add.cc @@ -0,0 +1,81 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/assign_add.h" +namespace mindspore { +namespace lite { +#ifdef PRIMITIVE_WRITEABLE +int AssignAdd::UnPackAttr(const Primitive &prim, const std::vector &inputs) { + if (this->primitive_ == nullptr) { + this->primitive_ = new (std::nothrow) schema::PrimitiveT; + if (this->primitive_ == nullptr) { + MS_LOG(ERROR) << "new primitive error"; + return RET_ERROR; + } + this->primitive_->value.type = schema::PrimitiveType_AssignAdd; + } + if (this->primitive_->value.type != schema::PrimitiveType_AssignAdd) { + MS_LOG(ERROR) << "PrimitiveType_AssignAdd primitive value type : " + << schema::EnumNamePrimitiveType(primitive_->value.type) << "is not equal" + << schema::EnumNamePrimitiveType(schema::PrimitiveType_AssignAdd); + delete this->primitive_; + return RET_ERROR; + } + if (this->primitive_->value.value == nullptr) { + this->primitive_->value.value = new (std::nothrow) schema::AssignAddT(); + if (this->primitive_->value.value == nullptr) { + MS_LOG(ERROR) << "new primitiveT value failed"; + delete this->primitive_; + return RET_ERROR; + } + } + return RET_OK; +} +#else +int AssignAdd::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_AssignAdd(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_AssignAdd return nullptr"; + return RET_ERROR; + } + auto val_offset = schema::CreateAssignAdd(*fbb); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_AssignAdd, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} +#endif +int AssignAdd::InferShape(std::vector inputs_, std::vector outputs_) { + Tensor *x = inputs_[0]; + Tensor *y = inputs_[1]; + Tensor *out = outputs_[0]; + std::vector x_shape = x->shape(); + if (x->data_type() != y->data_type()) { + MS_LOG(ERROR) << "no matched shape of x and y"; + return RET_ERROR; + } + std::vector output_shape(x_shape.size()); + for (int i = 0; i < static_cast(x_shape.size()); i++) { + output_shape[i] = x_shape[i]; + } + out->set_shape(output_shape); + out->SetFormat(x->GetFormat()); + out->set_data_type(x->data_type()); + return RET_OK; +} +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/ops/assign_add.h b/mindspore/lite/src/ops/assign_add.h new file mode 100644 index 0000000000..b956165ee3 --- /dev/null +++ b/mindspore/lite/src/ops/assign_add.h @@ -0,0 +1,41 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include "src/ops/primitive_c.h" +#ifndef LITE_SRC_OPS_ASSIGN_ADD_H_ +#define LITE_SRC_OPS_ASSIGN_ADD_H_ +namespace mindspore { +namespace lite { +class AssignAdd : public PrimitiveC { + public: +#ifdef PRIMITIVE_WRITEABLE + MS_DECLARE_PARENT(AssignAdd, PrimitiveC); + AssignAdd() = default; + explicit AssignAdd(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} + int UnPackAttr(const Primitive &prim, const std::vector &inputs) override; +#else + AssignAdd() = default; + + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; +#endif + int InferShape(std::vector inputs_, std::vector outputs_) override; +}; +} // namespace lite +} // namespace mindspore +#endif // LITE_SRC_OPS_ASSIGN_ADD_H_ diff --git a/mindspore/lite/src/ops/bias_add.cc b/mindspore/lite/src/ops/bias_add.cc index 899597ea41..9368316635 100644 --- a/mindspore/lite/src/ops/bias_add.cc +++ b/mindspore/lite/src/ops/bias_add.cc @@ -47,7 +47,12 @@ int BiasAdd::UnPackAttr(const Primitive &prim, const std::vector &in MS_LOG(ERROR) << "new primitiveT value failed"; return RET_ERROR; } - attr->axis = {0}; + if (prim.GetAttr("axis") == nullptr) { + MS_LOG(WARNING) << "get axis failed"; + attr->axis = {1}; + } else { + attr->axis = GetValue>(prim.GetAttr("axis")); + } this->primitive_->value.value = attr; if (this->primitive_->value.value == nullptr) { MS_LOG(ERROR) << "primitive value is nullptr"; diff --git a/mindspore/lite/src/ops/binary_cross_entropy.cc b/mindspore/lite/src/ops/binary_cross_entropy.cc new file mode 100644 index 0000000000..73cd933553 --- /dev/null +++ b/mindspore/lite/src/ops/binary_cross_entropy.cc @@ -0,0 +1,106 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "src/ops/binary_cross_entropy.h" + +namespace mindspore { +namespace lite { +#ifdef PRIMITIVE_WRITEABLE +int BinaryCrossEntropy::GetReduction() const { return this->primitive_->value.AsBinaryCrossEntropy()->reduction; } + +int BinaryCrossEntropy::UnPackAttr(const Primitive &prim, const std::vector &inputs) { + if (this->primitive_ == nullptr) { + this->primitive_ = new (std::nothrow) schema::PrimitiveT; + if (this->primitive_ == nullptr) { + MS_LOG(ERROR) << "new primitive error"; + return RET_ERROR; + } + this->primitive_->value.type = schema::PrimitiveType_BinaryCrossEntropy; + } + if (this->primitive_->value.type != schema::PrimitiveType_BinaryCrossEntropy) { + MS_LOG(ERROR) << "PrimitiveType_BinaryCrossEntropy primitive value type : " + << schema::EnumNamePrimitiveType(primitive_->value.type) << "is not equal" + << schema::EnumNamePrimitiveType(schema::PrimitiveType_BinaryCrossEntropy); + delete this->primitive_; + return RET_ERROR; + } + if (this->primitive_->value.value == nullptr) { + schema::BinaryCrossEntropyT *attr = new (std::nothrow) schema::BinaryCrossEntropyT(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new binary cross entropy attr failed!"; + delete this->primitive_; + return RET_ERROR; + } + // default is mean + string reduction = "mean"; + if (prim.GetAttr("reduction") == nullptr) { + MS_LOG(ERROR) << "get reduction failed!"; + delete this->primitive_; + delete attr; + return RET_ERROR; + } else { + reduction = GetValue(prim.GetAttr("reduction")); + } + if (reduction == "none") { + attr->reduction = 0; + } else if (reduction == "sum") { + attr->reduction = 2; + } else { + // default is mean + attr->reduction = 1; + } + this->primitive_->value.value = attr; + } + + return RET_OK; +} +#else +int BinaryCrossEntropy::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_BinaryCrossEntropy(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_BinaryCrossEntropy return nullptr"; + return RET_ERROR; + } + int reduction = attr->reduction(); + auto val_offset = schema::CreateBinaryCrossEntropy(*fbb, reduction); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_BinaryCrossEntropy, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} + +int BinaryCrossEntropy::GetReduction() const { return this->primitive_->value_as_BinaryCrossEntropy()->reduction(); } +#endif +int BinaryCrossEntropy::InferShape(std::vector inputs_, std::vector outputs_) { + Tensor *x = inputs_[0]; + Tensor *out = outputs_[0]; + out->SetFormat(x->GetFormat()); + out->set_data_type(x->data_type()); + int reduction = GetReduction(); + if (reduction == 1 || reduction == 2) { + out->set_shape({1}); + } else { + std::vector x_shape = x->shape(); + std::vector output_shape(x_shape.size()); + output_shape.assign(x_shape.begin(), x_shape.end()); + out->set_shape(output_shape); + } + return RET_OK; +} +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/ops/binary_cross_entropy.h b/mindspore/lite/src/ops/binary_cross_entropy.h new file mode 100644 index 0000000000..75e0e1224f --- /dev/null +++ b/mindspore/lite/src/ops/binary_cross_entropy.h @@ -0,0 +1,49 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include "src/ops/primitive_c.h" + +#ifndef LITE_SRC_OPS_BINARYCROSSENTROPY_H_ +#define LITE_SRC_OPS_BINARYCROSSENTROPY_H_ +namespace mindspore { +namespace lite { +class BinaryCrossEntropy : public PrimitiveC { + public: +#ifdef PRIMITIVE_WRITEABLE + MS_DECLARE_PARENT(BinaryCrossEntropy, PrimitiveC); + + BinaryCrossEntropy() = default; + + explicit BinaryCrossEntropy(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} + + int UnPackAttr(const Primitive &prim, const std::vector &inputs) override; + + int GetReduction() const; +#else + BinaryCrossEntropy() = default; + + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; + + int GetReduction() const; +#endif + int InferShape(std::vector inputs_, std::vector outputs_) override; +}; +} // namespace lite +} // namespace mindspore +#endif // LITE_SRC_OPS_BINARYCROSSENTROPY_H_ diff --git a/mindspore/lite/src/ops/binary_cross_entropy_grad.cc b/mindspore/lite/src/ops/binary_cross_entropy_grad.cc new file mode 100644 index 0000000000..0fe7739992 --- /dev/null +++ b/mindspore/lite/src/ops/binary_cross_entropy_grad.cc @@ -0,0 +1,108 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "src/ops/binary_cross_entropy_grad.h" + +namespace mindspore { +namespace lite { +#ifdef PRIMITIVE_WRITEABLE + +int BinaryCrossEntropyGrad::GetReduction() const { + return this->primitive_->value.AsBinaryCrossEntropyGrad()->reduction; +} + +int BinaryCrossEntropyGrad::UnPackAttr(const Primitive &prim, const std::vector &inputs) { + if (this->primitive_ == nullptr) { + this->primitive_ = new (std::nothrow) schema::PrimitiveT; + if (this->primitive_ == nullptr) { + MS_LOG(ERROR) << "new primitive error"; + return RET_ERROR; + } + this->primitive_->value.type = schema::PrimitiveType_BinaryCrossEntropyGrad; + } + if (this->primitive_->value.type != schema::PrimitiveType_BinaryCrossEntropyGrad) { + MS_LOG(ERROR) << "PrimitiveType_BinaryCrossEntropyGrad primitive value type : " + << schema::EnumNamePrimitiveType(primitive_->value.type) << "is not equal" + << schema::EnumNamePrimitiveType(schema::PrimitiveType_BinaryCrossEntropyGrad); + delete this->primitive_; + return RET_ERROR; + } + if (this->primitive_->value.value == nullptr) { + schema::BinaryCrossEntropyGradT *attr = new (std::nothrow) schema::BinaryCrossEntropyGradT(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new binary cross entropy attr failed!"; + delete this->primitive_; + return RET_ERROR; + } + // default is mean + string reduction = "mean"; + if (prim.GetAttr("reduction") == nullptr) { + MS_LOG(ERROR) << "get reduction failed!"; + delete this->primitive_; + delete attr; + return RET_ERROR; + } else { + reduction = GetValue(prim.GetAttr("reduction")); + } + + if (reduction == "none") { + attr->reduction = 0; + } else if (reduction == "sum") { + attr->reduction = 2; + } else { + // default is mean + attr->reduction = 1; + } + this->primitive_->value.value = attr; + } + + return RET_OK; +} +#else +int BinaryCrossEntropyGrad::UnPackToFlatBuilder(const schema::Primitive *primitive, + flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_BinaryCrossEntropyGrad(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_BinaryCrossEntropyGrad return nullptr"; + return RET_ERROR; + } + int reduction = attr->reduction(); + auto val_offset = schema::CreateBinaryCrossEntropyGrad(*fbb, reduction); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_BinaryCrossEntropyGrad, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} + +int BinaryCrossEntropyGrad::GetReduction() const { + return this->primitive_->value_as_BinaryCrossEntropyGrad()->reduction(); +} +#endif +int BinaryCrossEntropyGrad::InferShape(std::vector inputs_, std::vector outputs_) { + Tensor *x = inputs_[0]; + Tensor *out = outputs_[0]; + out->SetFormat(x->GetFormat()); + out->set_data_type(x->data_type()); + std::vector x_shape = x->shape(); + std::vector output_shape(x_shape.size()); + output_shape.assign(x_shape.begin(), x_shape.end()); + out->set_shape(output_shape); + return RET_OK; +} +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/ops/binary_cross_entropy_grad.h b/mindspore/lite/src/ops/binary_cross_entropy_grad.h new file mode 100644 index 0000000000..2c900090ef --- /dev/null +++ b/mindspore/lite/src/ops/binary_cross_entropy_grad.h @@ -0,0 +1,49 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include "src/ops/primitive_c.h" + +#ifndef LITE_SRC_OPS_BINARY_CROSS_ENTROPY_GRAD_H_ +#define LITE_SRC_OPS_BINARY_CROSS_ENTROPY_GRAD_H_ +namespace mindspore { +namespace lite { +class BinaryCrossEntropyGrad : public PrimitiveC { + public: +#ifdef PRIMITIVE_WRITEABLE + MS_DECLARE_PARENT(BinaryCrossEntropyGrad, PrimitiveC); + + BinaryCrossEntropyGrad() = default; + + explicit BinaryCrossEntropyGrad(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} + + int UnPackAttr(const Primitive &prim, const std::vector &inputs) override; + + int GetReduction() const; +#else + BinaryCrossEntropyGrad() = default; + + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; + + int GetReduction() const; +#endif + int InferShape(std::vector inputs_, std::vector outputs_) override; +}; +} // namespace lite +} // namespace mindspore +#endif // LITE_SRC_OPS_BINARY_CROSS_ENTROPY_GRAD_H_ diff --git a/mindspore/lite/src/ops/control_depend.cc b/mindspore/lite/src/ops/control_depend.cc new file mode 100644 index 0000000000..f4acf6428b --- /dev/null +++ b/mindspore/lite/src/ops/control_depend.cc @@ -0,0 +1,59 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/ops/control_depend.h" +#include +#include + +namespace mindspore { +namespace lite { +#ifdef PRIMITIVE_WRITEABLE +int ControlDepend::UnPackAttr(const Primitive &prim, const std::vector &inputs) { + if (this->primitive_ == nullptr) { + this->primitive_ = new (std::nothrow) schema::PrimitiveT; + if (this->primitive_ == nullptr) { + MS_LOG(ERROR) << "new primitiveT failed"; + return RET_ERROR; + } + this->primitive_->value.type = schema::PrimitiveType_ControlDepend; + } + if (this->primitive_->value.type != schema::PrimitiveType_ControlDepend) { + MS_LOG(ERROR) << "primitive_ type is error:" << this->primitive_->value.type; + delete this->primitive_; + return RET_ERROR; + } + if (this->primitive_->value.value == nullptr) { + auto attr = new (std::nothrow)(schema::ControlDependT); + if (attr == nullptr) { + MS_LOG(ERROR) << "attr is nullptr"; + delete this->primitive_; + return RET_ERROR; + } + this->primitive_->value.value = attr; + } + return RET_OK; +} +#else +int ControlDepend::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto val_offset = schema::CreateControlDepend(*fbb); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_ControlDepend, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} +#endif +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/ops/control_depend.h b/mindspore/lite/src/ops/control_depend.h new file mode 100644 index 0000000000..6a8f6b8079 --- /dev/null +++ b/mindspore/lite/src/ops/control_depend.h @@ -0,0 +1,40 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LITE_MINDSPORE_LITE_SRC_OPS_CONTROL_DEPEND_H_ +#define LITE_MINDSPORE_LITE_SRC_OPS_CONTROL_DEPEND_H_ + +#include +#include "src/ops/primitive_c.h" + +namespace mindspore { +namespace lite { +class ControlDepend : public PrimitiveC { + public: +#ifdef PRIMITIVE_WRITEABLE + MS_DECLARE_PARENT(ControlDepend, PrimitiveC); + ControlDepend() = default; + explicit ControlDepend(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} + int UnPackAttr(const Primitive &prim, const std::vector &inputs) override; +#else + ControlDepend() = default; + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; +#endif +}; +} // namespace lite +} // namespace mindspore + +#endif // LITE_MINDSPORE_LITE_SRC_OPS_CONTROL_DEPEND_H_ diff --git a/mindspore/lite/src/ops/expand_dims.cc b/mindspore/lite/src/ops/expand_dims.cc index 01e8d5c66f..8d4608cb2a 100644 --- a/mindspore/lite/src/ops/expand_dims.cc +++ b/mindspore/lite/src/ops/expand_dims.cc @@ -27,6 +27,43 @@ int ExpandDims::GetDim() const { return this->primitive_->value.AsExpandDims()-> void ExpandDims::SetDim(int dim) { this->primitive_->value.AsExpandDims()->dim = dim; } +int ExpandDims::UnPackAttr(const Primitive &prim, const std::vector &inputs) { + if (this->primitive_ == nullptr) { + this->primitive_ = new (std::nothrow) schema::PrimitiveT; + if (this->primitive_ == nullptr) { + MS_LOG(ERROR) << "new primitiveT failed"; + return RET_ERROR; + } + this->primitive_->value.type = schema::PrimitiveType_ExpandDims; + } + if (this->primitive_->value.type != schema::PrimitiveType_ExpandDims) { + MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; + delete this->primitive_; + return RET_ERROR; + } + if (this->primitive_->value.value == nullptr) { + auto attr = new (std::nothrow) schema::ExpandDimsT(); + if (attr == nullptr) { + delete this->primitive_; + MS_LOG(ERROR) << "new primitiveT value failed"; + return RET_ERROR; + } + // use axis instead of dim + if (inputs[1]->isa()) { + auto axis_tensor = inputs[1]->cast(); + int axis = GetValue(axis_tensor->value()); + attr->dim = axis; + } else { + MS_LOG(ERROR) << "input axis is not value node."; + delete this->primitive_; + delete attr; + return RET_ERROR; + } + this->primitive_->value.value = attr; + } + return RET_OK; +} + #else int ExpandDims::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { MS_ASSERT(nullptr != primitive); @@ -56,9 +93,6 @@ int ExpandDims::InferShape(std::vector inputs_, std::vector MS_ASSERT(input != nullptr); auto output = outputs_.front(); MS_ASSERT(output != nullptr); - if (inputs_.size() != kSingleNum) { - MS_LOG(ERROR) << "input size is invalid"; - } if (outputs_.size() != kSingleNum) { MS_LOG(ERROR) << "output size is invalid"; } diff --git a/mindspore/lite/src/ops/expand_dims.h b/mindspore/lite/src/ops/expand_dims.h index b6f5ece401..17488517f9 100644 --- a/mindspore/lite/src/ops/expand_dims.h +++ b/mindspore/lite/src/ops/expand_dims.h @@ -31,6 +31,7 @@ class ExpandDims : public PrimitiveC { ExpandDims() = default; explicit ExpandDims(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} void SetDim(int dim); + int UnPackAttr(const Primitive &prim, const std::vector &inputs) override; #else ExpandDims() = default; diff --git a/mindspore/lite/src/ops/gather.cc b/mindspore/lite/src/ops/gather.cc index bd18ecc21b..fbfa07c5a2 100644 --- a/mindspore/lite/src/ops/gather.cc +++ b/mindspore/lite/src/ops/gather.cc @@ -31,7 +31,35 @@ int Gather::GetBatchDims() const { return this->primitive_->value.AsGather()->ba void Gather::SetAxis(int axis) { this->primitive_->value.AsGather()->axis = axis; } void Gather::SetBatchDims(int batch_dims) { this->primitive_->value.AsGather()->batchDims = batch_dims; } - +int Gather::UnPackAttr(const Primitive &prim, const std::vector &inputs) { + if (this->primitive_ == nullptr) { + this->primitive_ = new (std::nothrow) schema::PrimitiveT; + if (this->primitive_ == nullptr) { + MS_LOG(ERROR) << "new primitive error"; + return RET_ERROR; + } + this->primitive_->value.type = schema::PrimitiveType_Gather; + } + if (this->primitive_->value.type != schema::PrimitiveType_Gather) { + MS_LOG(ERROR) << "Gather primitive value type : " << schema::EnumNamePrimitiveType(primitive_->value.type) + << "is not equal" << schema::EnumNamePrimitiveType(schema::PrimitiveType_Gather); + delete this->primitive_; + return RET_ERROR; + } + if (this->primitive_->value.value == nullptr) { + auto gather_attr = new (std::nothrow) schema::GatherT(); + if (gather_attr == nullptr) { + MS_LOG(ERROR) << "new primitive value.value error"; + delete this->primitive_; + delete gather_attr; + return RET_ERROR; + } + gather_attr->axis = GetValue(prim.GetAttr("axis")); + gather_attr->batchDims = GetValue(prim.GetAttr("batchDims")); + this->primitive_->value.value = gather_attr; + } + return RET_OK; +} #else int Gather::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { MS_ASSERT(nullptr != primitive); diff --git a/mindspore/lite/src/ops/gather.h b/mindspore/lite/src/ops/gather.h index 554a9c46dc..d398391bb7 100644 --- a/mindspore/lite/src/ops/gather.h +++ b/mindspore/lite/src/ops/gather.h @@ -33,6 +33,7 @@ class Gather : public PrimitiveC { explicit Gather(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} void SetAxis(int axis); void SetBatchDims(int batch_dims); + int UnPackAttr(const Primitive &prim, const std::vector &inputs) override; #else Gather() = default; diff --git a/mindspore/lite/src/ops/oneslike.cc b/mindspore/lite/src/ops/oneslike.cc new file mode 100644 index 0000000000..166a5326d2 --- /dev/null +++ b/mindspore/lite/src/ops/oneslike.cc @@ -0,0 +1,75 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/oneslike.h" + +namespace mindspore { +namespace lite { +#ifdef PRIMITIVE_WRITEABLE +int OnesLike::UnPackAttr(const Primitive &prim, const std::vector &inputs) { + if (this->primitive_ == nullptr) { + this->primitive_ = new (std::nothrow) schema::PrimitiveT; + if (this->primitive_ == nullptr) { + MS_LOG(ERROR) << "new primitive error"; + return RET_ERROR; + } + this->primitive_->value.type = schema::PrimitiveType_OnesLike; + } + if (this->primitive_->value.type != schema::PrimitiveType_OnesLike) { + MS_LOG(ERROR) << "PrimitiveType_OnesLike primitive value type : " + << schema::EnumNamePrimitiveType(primitive_->value.type) << "is not equal" + << schema::EnumNamePrimitiveType(schema::PrimitiveType_OnesLike); + delete this->primitive_; + return RET_ERROR; + } + if (this->primitive_->value.value == nullptr) { + this->primitive_->value.value = new (std::nothrow) schema::OnesLikeT(); + if (this->primitive_->value.value == nullptr) { + MS_LOG(ERROR) << "new primitiveT value failed"; + delete this->primitive_; + return RET_ERROR; + } + } + return RET_OK; +} +#else +int OnesLike::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_OnesLike(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_OnesLike return nullptr"; + return RET_ERROR; + } + auto val_offset = schema::CreateOnesLike(*fbb); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_OnesLike, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} +#endif +int OnesLike::InferShape(std::vector inputs_, std::vector outputs_) { + Tensor *x = inputs_[0]; + Tensor *out = outputs_[0]; + std::vector x_shape = x->shape(); + std::vector output_shape(x_shape.size()); + output_shape.assign(x_shape.begin(), x_shape.end()); + out->set_shape(output_shape); + out->SetFormat(x->GetFormat()); + out->set_data_type(x->data_type()); + return RET_OK; +} +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/ops/oneslike.h b/mindspore/lite/src/ops/oneslike.h new file mode 100644 index 0000000000..dd6cbeadc7 --- /dev/null +++ b/mindspore/lite/src/ops/oneslike.h @@ -0,0 +1,42 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include "src/ops/primitive_c.h" + +#ifndef LITE_SRC_OPS_ONESLIKE_H_ +#define LITE_SRC_OPS_ONESLIKE_H_ +namespace mindspore { +namespace lite { +class OnesLike : public PrimitiveC { + public: +#ifdef PRIMITIVE_WRITEABLE + MS_DECLARE_PARENT(OnesLike, PrimitiveC); + OnesLike() = default; + explicit OnesLike(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} + int UnPackAttr(const Primitive &prim, const std::vector &inputs) override; +#else + OnesLike() = default; + + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; +#endif + int InferShape(std::vector inputs_, std::vector outputs_) override; +}; +} // namespace lite +} // namespace mindspore +#endif // LITE_SRC_OPS_ONESLIKE_H_ diff --git a/mindspore/lite/src/ops/power.cc b/mindspore/lite/src/ops/power.cc index 0e2c9a6ad5..7a59d4544e 100644 --- a/mindspore/lite/src/ops/power.cc +++ b/mindspore/lite/src/ops/power.cc @@ -31,6 +31,51 @@ void Power::SetPower(float power) { this->primitive_->value.AsPower()->power = p void Power::SetScale(float scale) { this->primitive_->value.AsPower()->scale = scale; } void Power::SetShift(float shift) { this->primitive_->value.AsPower()->shift = shift; } +int Power::UnPackAttr(const Primitive &prim, const std::vector &inputs) { + if (this->primitive_ == nullptr) { + this->primitive_ = new (std::nothrow) schema::PrimitiveT; + if (this->primitive_ == nullptr) { + MS_LOG(ERROR) << "new primitiveT failed"; + return RET_ERROR; + } + this->primitive_->value.type = schema::PrimitiveType_Power; + } + if (this->primitive_->value.type != schema::PrimitiveType_Power) { + MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; + delete this->primitive_; + return RET_ERROR; + } + if (this->primitive_->value.value == nullptr) { + auto attr = new (std::nothrow) schema::PowerT(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new primitiveT value failed"; + delete this->primitive_; + return RET_ERROR; + } + + if (prim.GetAttr("scale") == nullptr) { + MS_LOG(WARNING) << "get scale failed"; + attr->scale = 1.0f; + } else { + attr->scale = GetValue(prim.GetAttr("scale")); + } + if (prim.GetAttr("power") == nullptr) { + MS_LOG(WARNING) << "get power failed"; + attr->power = 1.0f; + } else { + attr->power = GetValue(prim.GetAttr("power")); + } + if (prim.GetAttr("shift") == nullptr) { + MS_LOG(WARNING) << "get shift failed"; + attr->shift = 0; + } else { + attr->shift = GetValue(prim.GetAttr("shift")); + } + this->primitive_->value.value = attr; + } + return RET_OK; +} + #else float Power::GetPower() const { return this->primitive_->value_as_Power()->power(); } diff --git a/mindspore/lite/src/ops/power.h b/mindspore/lite/src/ops/power.h index 1f30589991..d0c9f001ec 100644 --- a/mindspore/lite/src/ops/power.h +++ b/mindspore/lite/src/ops/power.h @@ -34,6 +34,7 @@ class Power : public PrimitiveC { void SetPower(float power); void SetScale(float scale); void SetShift(float shift); + int UnPackAttr(const Primitive &prim, const std::vector &inputs) override; #else Power() = default; diff --git a/mindspore/lite/src/ops/primitive_c.cc b/mindspore/lite/src/ops/primitive_c.cc index 773531d7fd..f25bbc911d 100644 --- a/mindspore/lite/src/ops/primitive_c.cc +++ b/mindspore/lite/src/ops/primitive_c.cc @@ -145,6 +145,8 @@ #include "src/ops/identity.h" #include "src/ops/instance_norm.h" #include "src/ops/while.h" +#include "src/ops/oneslike.h" +#include "src/ops/unsorted_segment_sum.h" #ifdef SUPPORT_TRAIN #include "src/ops/neg_grad.h" @@ -165,6 +167,10 @@ #include "src/ops/sgd.h" #include "src/ops/adam.h" #include "src/ops/assign.h" +#include "src/ops/control_depend.h" +#include "src/ops/assign_add.h" +#include "src/ops/binary_cross_entropy.h" +#include "src/ops/binary_cross_entropy_grad.h" #endif #endif @@ -504,6 +510,18 @@ std::shared_ptr PrimitiveC::Create(const Primitive &prim, const std: return NewPrimitiveC(prim, inputs, quantType); } else if (op_type == "OneHot") { return NewPrimitiveC(prim, inputs, quantType); + } else if (op_type == "GatherV2") { + return NewPrimitiveC(prim, inputs, quantType); + } else if (op_type == "OnesLike") { + return NewPrimitiveC(prim, inputs, quantType); + } else if (op_type == "Pow") { + return NewPrimitiveC(prim, inputs, quantType); + } else if (op_type == "Sub") { + return NewPrimitiveC(prim, inputs, quantType); + } else if (op_type == "ExpandDims") { + return NewPrimitiveC(prim, inputs, quantType); + } else if (op_type == "UnsortedSegmentSum") { + return NewPrimitiveC(prim, inputs, quantType); #ifdef SUPPORT_TRAIN } else if (op_type == "SoftmaxCrossEntropyWithLogits") { @@ -514,6 +532,8 @@ std::shared_ptr PrimitiveC::Create(const Primitive &prim, const std: return NewPrimitiveC(prim, inputs, quantType); } else if (op_type == "Depend") { return NewPrimitiveC(prim, inputs, quantType); + } else if (op_type == "ControlDepend") { + return NewPrimitiveC(prim, inputs, quantType); } else if ((op_type == "ReluGrad" || op_type == "ReLU6Grad" || op_type == "SigmoidGrad" || op_type == "HSigmoidGrad" || op_type == "HSwishGrad")) { return NewPrimitiveC(prim, inputs, quantType); @@ -539,6 +559,12 @@ std::shared_ptr PrimitiveC::Create(const Primitive &prim, const std: return NewPrimitiveC(prim, inputs, quantType); } else if (op_type == "Assign") { return NewPrimitiveC(prim, inputs, quantType); + } else if (op_type == "AssignAdd") { + return NewPrimitiveC(prim, inputs, quantType); + } else if (op_type == "BinaryCrossEntropy") { + return NewPrimitiveC(prim, inputs, quantType); + } else if (op_type == "BinaryCrossEntropyGrad") { + return NewPrimitiveC(prim, inputs, quantType); #else } else if (op_type == "Conv2DBackpropInput") { return NewPrimitiveC(prim, inputs, quantType); @@ -830,6 +856,8 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) { return new PowerGrad(primitive); case schema::PrimitiveType_Depend: return new Depend(primitive); + case schema::PrimitiveType_ControlDepend: + return new ControlDepend(primitive); case schema::PrimitiveType_FlattenGrad: return new FlattenGrad(primitive); case schema::PrimitiveType_NegGrad: @@ -842,6 +870,17 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) { return new Adam(primitive); case schema::PrimitiveType_Assign: return new Assign(primitive); + case schema::PrimitiveType_AssignAdd: + return new AssignAdd(primitive); + case schema::PrimitiveType_OnesLike: + return new OnesLike(primitive); + case schema::PrimitiveType_UnsortedSegmentSum: + return new UnsortedSegmentSum(primitive); + case schema::PrimitiveType_BinaryCrossEntropyGrad: + return new BinaryCrossEntropyGrad(primitive); + case schema::PrimitiveType_BinaryCrossEntropy: + return new BinaryCrossEntropy(primitive); + #endif default: MS_LOG(ERROR) << "Unsupported primitive type in Create : " << schema::EnumNamePrimitiveType(op_type); diff --git a/mindspore/lite/src/ops/reduce.cc b/mindspore/lite/src/ops/reduce.cc index 9aaee7543a..03c5eea50b 100644 --- a/mindspore/lite/src/ops/reduce.cc +++ b/mindspore/lite/src/ops/reduce.cc @@ -86,6 +86,9 @@ int Reduce::UnPackAttr(const Primitive &prim, const std::vector &inp MS_ASSERT(elem != nullptr); attr->axes.emplace_back(elem->value()); } + } else { + int axes_item = GetValue(value); + attr->axes.push_back(axes_item); } } } diff --git a/mindspore/lite/src/ops/reshape.cc b/mindspore/lite/src/ops/reshape.cc index 9e6ca197a0..c9171eb731 100644 --- a/mindspore/lite/src/ops/reshape.cc +++ b/mindspore/lite/src/ops/reshape.cc @@ -62,6 +62,9 @@ int Reshape::UnPackAttr(const Primitive &prim, const std::vector &in MS_ASSERT(elem != nullptr); attr->shape.emplace_back(static_cast(elem->value())); } + } else { + int dim = GetValue(val); + attr->shape = {dim}; } } if (attr == nullptr) { diff --git a/mindspore/lite/src/ops/squeeze.cc b/mindspore/lite/src/ops/squeeze.cc index 32b281a07b..700bee9494 100644 --- a/mindspore/lite/src/ops/squeeze.cc +++ b/mindspore/lite/src/ops/squeeze.cc @@ -46,12 +46,14 @@ int Squeeze::UnPackAttr(const Primitive &prim, const std::vector &in MS_LOG(ERROR) << "new primitiveT value failed"; return RET_ERROR; } - attr->axis = GetValue>(prim.GetAttr("axis")); - this->primitive_->value.value = attr; - if (this->primitive_->value.value == nullptr) { - MS_LOG(ERROR) << "primitive value is nullptr"; - return RET_ERROR; + if (prim.GetAttr("axis") == nullptr) { + MS_LOG(WARNING) << "get axis failed"; + attr->axis = {0}; + } else { + int axis = GetValue(prim.GetAttr("axis")); + attr->axis = {axis}; } + this->primitive_->value.value = attr; } return RET_OK; } diff --git a/mindspore/lite/src/ops/sub.cc b/mindspore/lite/src/ops/sub.cc index 70ba6465a2..4b4282f38b 100644 --- a/mindspore/lite/src/ops/sub.cc +++ b/mindspore/lite/src/ops/sub.cc @@ -29,6 +29,34 @@ void Sub::SetActivationType(int activation_type) { this->primitive_->value.AsSub()->activationType = (schema::ActivationType)activation_type; } +int Sub::UnPackAttr(const Primitive &prim, const std::vector &inputs) { + if (this->primitive_ == nullptr) { + this->primitive_ = new (std::nothrow) schema::PrimitiveT; + if (this->primitive_ == nullptr) { + MS_LOG(ERROR) << "new primitiveT failed"; + return RET_ERROR; + } + this->primitive_->value.type = schema::PrimitiveType_Sub; + } + if (this->primitive_->value.type != schema::PrimitiveType_Sub) { + MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; + delete this->primitive_; + return RET_ERROR; + } + if (this->primitive_->value.value == nullptr) { + auto attr = new (std::nothrow) schema::SubT(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new primitiveT value failed"; + delete this->primitive_; + return RET_ERROR; + } + // todo: confirm the activationType + attr->activationType = schema::ActivationType_NO_ACTIVATION; + this->primitive_->value.value = attr; + } + return RET_OK; +} + #else int Sub::GetActivationType() const { return this->primitive_->value_as_Sub()->activationType(); } diff --git a/mindspore/lite/src/ops/sub.h b/mindspore/lite/src/ops/sub.h index 92780efbc7..375a816695 100644 --- a/mindspore/lite/src/ops/sub.h +++ b/mindspore/lite/src/ops/sub.h @@ -32,6 +32,7 @@ class Sub : public Arithmetic { Sub() = default; explicit Sub(schema::PrimitiveT *primitive) : Arithmetic(primitive) {} void SetActivationType(int activation_type); + int UnPackAttr(const Primitive &prim, const std::vector &inputs) override; #else Sub() = default; diff --git a/mindspore/lite/src/ops/tile.cc b/mindspore/lite/src/ops/tile.cc index f9d64b5f21..d4f8c53da1 100644 --- a/mindspore/lite/src/ops/tile.cc +++ b/mindspore/lite/src/ops/tile.cc @@ -52,6 +52,12 @@ int Tile::UnPackAttr(const Primitive &prim, const std::vector &input MS_LOG(ERROR) << "new primitiveT value failed"; return RET_ERROR; } + if (prim.GetAttr("dims") == nullptr) { + MS_LOG(WARNING) << "get dims failed"; + attr->dims = {1}; + } else { + attr->dims = GetValue>(prim.GetAttr("dims")); + } if (inputs.size() == kAnfPopulaterTwo) { auto inputNode = inputs[kAnfPopulaterOne]; MS_ASSERT(inputNode != nullptr); @@ -68,6 +74,9 @@ int Tile::UnPackAttr(const Primitive &prim, const std::vector &input MS_ASSERT(elem != nullptr); attr->multiples.emplace_back(elem->value()); } + } else { + int multiple = GetValue(value); + attr->multiples = {multiple}; } } } diff --git a/mindspore/lite/src/ops/unsorted_segment_sum.cc b/mindspore/lite/src/ops/unsorted_segment_sum.cc new file mode 100644 index 0000000000..f9b4573b4c --- /dev/null +++ b/mindspore/lite/src/ops/unsorted_segment_sum.cc @@ -0,0 +1,100 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "src/ops/unsorted_segment_sum.h" + +namespace mindspore { +namespace lite { +#ifdef PRIMITIVE_WRITEABLE + +int UnsortedSegmentSum::GetNumSegments() const { return this->primitive_->value.AsUnsortedSegmentSum()->numSegments; } + +int UnsortedSegmentSum::UnPackAttr(const Primitive &prim, const std::vector &inputs) { + if (this->primitive_ == nullptr) { + this->primitive_ = new (std::nothrow) schema::PrimitiveT; + if (this->primitive_ == nullptr) { + MS_LOG(ERROR) << "new primitive error"; + return RET_ERROR; + } + this->primitive_->value.type = schema::PrimitiveType_UnsortedSegmentSum; + } + if (this->primitive_->value.type != schema::PrimitiveType_UnsortedSegmentSum) { + MS_LOG(ERROR) << "UnSortedSegmentSum primitive value type : " + << schema::EnumNamePrimitiveType(primitive_->value.type) << "is not equal" + << schema::EnumNamePrimitiveType(schema::PrimitiveType_UnsortedSegmentSum); + delete this->primitive_; + return RET_ERROR; + } + if (this->primitive_->value.value == nullptr) { + std::unique_ptr attr = std::make_unique(); + if (inputs[2]->isa()) { + ValuePtr value = inputs[2]->cast()->value(); + attr->numSegments = GetValue(value); + this->primitive_->value.value = attr.release(); + } + } + return RET_OK; +} +#else +int UnsortedSegmentSum::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_UnsortedSegmentSum(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_UnsortedSegmentSum return nullptr"; + return RET_ERROR; + } + int num_segments = attr->numSegments(); + auto val_offset = schema::CreateUnsortedSegmentSum(*fbb, num_segments); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_UnsortedSegmentSum, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} + +int UnsortedSegmentSum::GetNumSegments() const { + int ret = this->primitive_->value_as_UnsortedSegmentSum()->numSegments(); + return ret; +} +#endif +int UnsortedSegmentSum::InferShape(std::vector inputs_, std::vector outputs_) { + // check inputs and outputs + if (inputs_.size() != 3) { + MS_LOG(ERROR) << "invalid inputs numbers"; + return RET_ERROR; + } + if (outputs_.size() != 1) { + MS_LOG(ERROR) << "invalid outputs numbers"; + return RET_ERROR; + } + Tensor *out = outputs_.front(); + Tensor *x = inputs_.front(); + Tensor *segment_id = inputs_[1]; + std::vector x_shape = x->shape(); + std::vector segment_id_shape = segment_id->shape(); + int num_segments = GetNumSegments(); + std::vector output_shape; + output_shape.push_back(num_segments); + for (int index = segment_id_shape.size(); index < static_cast(x_shape.size()); index++) { + output_shape.push_back(x_shape[index]); + } + out->set_shape(output_shape); + out->SetFormat(x->GetFormat()); + out->set_data_type(x->data_type()); + return RET_OK; +} +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/ops/unsorted_segment_sum.h b/mindspore/lite/src/ops/unsorted_segment_sum.h new file mode 100644 index 0000000000..176c33bcad --- /dev/null +++ b/mindspore/lite/src/ops/unsorted_segment_sum.h @@ -0,0 +1,44 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include "src/ops/primitive_c.h" +#ifndef LITE_SRC_OPS_UNSORTED_SEGMENT_SUM_H_ +#define LITE_SRC_OPS_UNSORTED_SEGMENT_SUM_H_ +namespace mindspore { +namespace lite { +class UnsortedSegmentSum : public PrimitiveC { + public: +#ifdef PRIMITIVE_WRITEABLE + MS_DECLARE_PARENT(UnsortedSegmentSum, PrimitiveC); + UnsortedSegmentSum() = default; + explicit UnsortedSegmentSum(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} + int UnPackAttr(const Primitive &prim, const std::vector &inputs) override; + int GetNumSegments() const; +#else + UnsortedSegmentSum() = default; + + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; + + int GetNumSegments() const; +#endif + int InferShape(std::vector inputs_, std::vector outputs_) override; +}; +} // namespace lite +} // namespace mindspore +#endif // LITE_SRC_OPS_UNSORTED_SEGMENT_SUM_H_ diff --git a/mindspore/lite/src/train/train_populate_parameter.cc b/mindspore/lite/src/train/train_populate_parameter.cc index e0ba28a391..07298534eb 100644 --- a/mindspore/lite/src/train/train_populate_parameter.cc +++ b/mindspore/lite/src/train/train_populate_parameter.cc @@ -36,6 +36,9 @@ #include "src/ops/bn_grad.h" #include "nnacl/fp32_grad/batch_norm.h" #include "src/ops/adam.h" +#include "src/ops/oneslike.h" +#include "src/ops/binary_cross_entropy.h" +#include "src/ops/binary_cross_entropy_grad.h" namespace mindspore::kernel { @@ -76,6 +79,30 @@ OpParameter *PopulateApplyMomentumParameter(const mindspore::lite::PrimitiveC *p return reinterpret_cast(p); } +OpParameter *PopulateBCEParameter(const mindspore::lite::PrimitiveC *primitive) { + int32_t *reduction = reinterpret_cast(malloc(sizeof(int32_t))); + if (reduction == nullptr) { + MS_LOG(ERROR) << "malloc reduction failed."; + return nullptr; + } + auto param = + reinterpret_cast(const_cast(primitive)); + *reduction = param->GetReduction(); + return reinterpret_cast(reduction); +} + +OpParameter *PopulateBCEGradParameter(const mindspore::lite::PrimitiveC *primitive) { + int32_t *reduction = reinterpret_cast(malloc(sizeof(int32_t))); + if (reduction == nullptr) { + MS_LOG(ERROR) << "malloc reduction failed."; + return nullptr; + } + auto param = + reinterpret_cast(const_cast(primitive)); + *reduction = param->GetReduction(); + return reinterpret_cast(reduction); +} + OpParameter *PopulateAdamParameter(const mindspore::lite::PrimitiveC *primitive) { if (primitive == nullptr) { MS_LOG(ERROR) << "Primitive is nullptr when populating parameter for op."; @@ -396,6 +423,13 @@ void PopulateTrainParameters() { lite::Registry BNGradParameterRegistry(schema::PrimitiveType_BNGrad, PopulateBNGradParameter); lite::Registry AdamParameterRegistry(schema::PrimitiveType_Adam, PopulateAdamParameter); lite::Registry AssignParameterRegistry(schema::PrimitiveType_Assign, DefaultPopulateParameter); + lite::Registry AssignAddParameterRegistry(schema::PrimitiveType_AssignAdd, DefaultPopulateParameter); + lite::Registry BinaryCrossEntropyParameterRegistry(schema::PrimitiveType_BinaryCrossEntropy, PopulateBCEParameter); + lite::Registry BinaryCrossEntropyGradParameterRegistry(schema::PrimitiveType_BinaryCrossEntropyGrad, + PopulateBCEGradParameter); + lite::Registry OnesLikeParameterRegistry(schema::PrimitiveType_OnesLike, DefaultPopulateParameter); + lite::Registry UnsortedSegmentSumParameterRegistry(schema::PrimitiveType_UnsortedSegmentSum, + DefaultPopulateParameter); } } // namespace mindspore::kernel diff --git a/mindspore/lite/tools/anf_exporter/anf_exporter.cc b/mindspore/lite/tools/anf_exporter/anf_exporter.cc index 8f4279102b..82ae88b150 100644 --- a/mindspore/lite/tools/anf_exporter/anf_exporter.cc +++ b/mindspore/lite/tools/anf_exporter/anf_exporter.cc @@ -69,7 +69,8 @@ void AnfExporter::RemoveIfDepend(const CNodePtr &cnode) { continue; } auto dependNode = utils::cast(inputNode); - if (IsPrimitiveCNode(dependNode, schema::PrimitiveType_Depend)) { + if (IsPrimitiveCNode(dependNode, schema::PrimitiveType_Depend) || + IsPrimitiveCNode(dependNode, schema::PrimitiveType_ControlDepend)) { hasDepend = true; for (size_t j = 1; j < dependNode->inputs().size(); ++j) { AnfNodePtr dependInputNode = dependNode->input(j); @@ -209,6 +210,7 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph, bool kee if ((primitive_c->Type() == schema::PrimitiveType_TupleGetItem) || #ifdef SUPPORT_TRAIN (primitive_c->Type() == schema::PrimitiveType_Depend) || + (primitive_c->Type() == schema::PrimitiveType_ControlDepend) || #endif (primitive_c->Type() == schema::PrimitiveType_MakeTuple)) { continue; @@ -402,7 +404,9 @@ int AnfExporter::ConvertInputValueNode(std::shared_ptr input_anode, paramTensor->dims = {1}; paramTensor->nodeType = schema::NodeType::NodeType_ValueNode; auto data = value->cast(); - paramTensor->data.emplace_back(data->value()); + int real_data = GetValue(data); + paramTensor->data.resize(sizeof(int32_t)); + memcpy(paramTensor->data.data(), &real_data, sizeof(int32_t)); node_id_map_[valueNode->fullname_with_scope()] = meta_graphT->allTensors.size(); output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size()); meta_graphT->allTensors.emplace_back(std::move(paramTensor)); @@ -418,6 +422,14 @@ int AnfExporter::ConvertInputValueNode(std::shared_ptr input_anode, node_id_map_[valueNode->fullname_with_scope()] = meta_graphT->allTensors.size(); output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size()); meta_graphT->allTensors.emplace_back(std::move(paramTensor)); + } else if (value->isa()) { + paramTensor->dataType = kNumberTypeInt32; + paramTensor->dims = {1}; + paramTensor->nodeType = schema::NodeType_ValueNode; + paramTensor->data.emplace_back(kNumberTypeInt32); + node_id_map_[valueNode->fullname_with_scope()] = meta_graphT->allTensors.size(); + output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size()); + meta_graphT->allTensors.emplace_back(std::move(paramTensor)); } else if (value->isa()) { #ifndef SUPPORT_TRAIN MS_LOG(DEBUG) << "Value type is ValueSequence."; @@ -456,6 +468,18 @@ int AnfExporter::ConvertInputValueNode(std::shared_ptr input_anode, MS_LOG(ERROR) << "Value type is ValueSequence not supported - " << valueAbstract->type_name() << "."; } #endif + } else if (value->isa()) { + auto valueAbstract = valueNode->abstract(); + auto abstractScalar = utils::cast(valueAbstract); + auto typePtr = abstractScalar->GetTypeTrack(); + paramTensor->dataType = typePtr->type_id(); + paramTensor->dims = {1}; + paramTensor->nodeType = schema::NodeType_ValueNode; + auto data = value->cast(); + paramTensor->data.emplace_back(data->value()); + node_id_map_[valueNode->fullname_with_scope()] = meta_graphT->allTensors.size(); + output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size()); + meta_graphT->allTensors.emplace_back(std::move(paramTensor)); } else if (value->isa()) { MS_LOG(INFO) << "Value is a number."; return RET_OK;