Browse Source

!7674 sync code to support train

Merge pull request !7674 from yangjie159/sync_code
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
43f08e3817
31 changed files with 1091 additions and 13 deletions
  1. +6
    -0
      mindspore/lite/schema/model.fbs
  2. +20
    -0
      mindspore/lite/schema/ops.fbs
  3. +1
    -1
      mindspore/lite/src/ops/adam.cc
  4. +81
    -0
      mindspore/lite/src/ops/assign_add.cc
  5. +41
    -0
      mindspore/lite/src/ops/assign_add.h
  6. +6
    -1
      mindspore/lite/src/ops/bias_add.cc
  7. +106
    -0
      mindspore/lite/src/ops/binary_cross_entropy.cc
  8. +49
    -0
      mindspore/lite/src/ops/binary_cross_entropy.h
  9. +108
    -0
      mindspore/lite/src/ops/binary_cross_entropy_grad.cc
  10. +49
    -0
      mindspore/lite/src/ops/binary_cross_entropy_grad.h
  11. +59
    -0
      mindspore/lite/src/ops/control_depend.cc
  12. +40
    -0
      mindspore/lite/src/ops/control_depend.h
  13. +37
    -3
      mindspore/lite/src/ops/expand_dims.cc
  14. +1
    -0
      mindspore/lite/src/ops/expand_dims.h
  15. +29
    -1
      mindspore/lite/src/ops/gather.cc
  16. +1
    -0
      mindspore/lite/src/ops/gather.h
  17. +75
    -0
      mindspore/lite/src/ops/oneslike.cc
  18. +42
    -0
      mindspore/lite/src/ops/oneslike.h
  19. +45
    -0
      mindspore/lite/src/ops/power.cc
  20. +1
    -0
      mindspore/lite/src/ops/power.h
  21. +39
    -0
      mindspore/lite/src/ops/primitive_c.cc
  22. +3
    -0
      mindspore/lite/src/ops/reduce.cc
  23. +3
    -0
      mindspore/lite/src/ops/reshape.cc
  24. +7
    -5
      mindspore/lite/src/ops/squeeze.cc
  25. +28
    -0
      mindspore/lite/src/ops/sub.cc
  26. +1
    -0
      mindspore/lite/src/ops/sub.h
  27. +9
    -0
      mindspore/lite/src/ops/tile.cc
  28. +100
    -0
      mindspore/lite/src/ops/unsorted_segment_sum.cc
  29. +44
    -0
      mindspore/lite/src/ops/unsorted_segment_sum.h
  30. +34
    -0
      mindspore/lite/src/train/train_populate_parameter.cc
  31. +26
    -2
      mindspore/lite/tools/anf_exporter/anf_exporter.cc

+ 6
- 0
mindspore/lite/schema/model.fbs View File

@@ -227,6 +227,12 @@ union PrimitiveType {
Identity,
LayerNorm,
While,
ControlDepend,
UnsortedSegmentSum,
AssignAdd,
OnesLike,
BinaryCrossEntropyGrad,
BinaryCrossEntropy
}

enum QuantType: int {


+ 20
- 0
mindspore/lite/schema/ops.fbs View File

@@ -966,6 +966,8 @@ table Adam {
table Assign {
}

table AssignAdd {
}

table Where{
condition: [bool];
@@ -1010,6 +1012,9 @@ table ToFormat {
table Depend {
}

table ControlDepend {
}

table Return {
}

@@ -1108,3 +1113,18 @@ table While {
bodySubgraphIndex : int;
}

table UnsortedSegmentSum {
numSegments : int;
}

table OnesLike {

}

// Binary cross-entropy loss.
// reduction encoding: 0 = "none", 1 = "mean" (default), 2 = "sum".
table BinaryCrossEntropy {
    reduction : int = 1;
}

// Gradient of BinaryCrossEntropy; "reduction" uses the same 0/1/2 encoding.
table BinaryCrossEntropyGrad {
    reduction : int = 1;
}

+ 1
- 1
mindspore/lite/src/ops/adam.cc View File

@@ -73,7 +73,7 @@ Registry AdamRegistry(schema::PrimitiveType_Adam, AdamCreator);

int Adam::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs) {
if (10 != inputs.size()) {
MS_LOG(ERROR) << "Adam should have at least 8 input tensors";
MS_LOG(ERROR) << "Adam should have 10 input tensors";
return RET_ERROR;
}



+ 81
- 0
mindspore/lite/src/ops/assign_add.cc View File

@@ -0,0 +1,81 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "src/ops/assign_add.h"
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
// Populates primitive_ as an AssignAdd primitive when importing from the ANF
// graph. Lazily allocates the PrimitiveT wrapper and its attribute-less
// AssignAddT payload. `prim` and `inputs` are unused — AssignAdd carries no
// attributes. Returns RET_OK on success, RET_ERROR on allocation failure or
// if a pre-existing primitive has a different type.
int AssignAdd::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
  if (this->primitive_ == nullptr) {
    this->primitive_ = new (std::nothrow) schema::PrimitiveT;
    if (this->primitive_ == nullptr) {
      MS_LOG(ERROR) << "new primitive error";
      return RET_ERROR;
    }
    this->primitive_->value.type = schema::PrimitiveType_AssignAdd;
  }
  if (this->primitive_->value.type != schema::PrimitiveType_AssignAdd) {
    // A wrapper of another op type means the caller wired this object up wrong.
    MS_LOG(ERROR) << "PrimitiveType_AssignAdd primitive value type :  "
                  << schema::EnumNamePrimitiveType(primitive_->value.type) << "is not equal"
                  << schema::EnumNamePrimitiveType(schema::PrimitiveType_AssignAdd);
    delete this->primitive_;
    return RET_ERROR;
  }
  if (this->primitive_->value.value == nullptr) {
    // The empty attribute table must still exist for serialization.
    this->primitive_->value.value = new (std::nothrow) schema::AssignAddT();
    if (this->primitive_->value.value == nullptr) {
      MS_LOG(ERROR) << "new primitiveT value failed";
      delete this->primitive_;
      return RET_ERROR;
    }
  }
  return RET_OK;
}
#else
// Re-serializes an AssignAdd primitive into `fbb` (read-only flatbuffer
// build). Returns RET_ERROR if `primitive` does not actually hold an
// AssignAdd value.
int AssignAdd::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
  MS_ASSERT(nullptr != primitive);
  MS_ASSERT(nullptr != fbb);
  auto attr = primitive->value_as_AssignAdd();
  if (attr == nullptr) {
    MS_LOG(ERROR) << "value_as_AssignAdd return nullptr";
    return RET_ERROR;
  }
  auto val_offset = schema::CreateAssignAdd(*fbb);
  auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_AssignAdd, val_offset.o);
  fbb->Finish(prim_offset);
  return RET_OK;
}
#endif
// Shape inference for AssignAdd: the output mirrors the first input's shape,
// format and data type. Requires two inputs (x, y) of matching data type and
// one output.
// Fixed: the error message claimed a shape mismatch while the condition
// checks data types; also added a guard against malformed tensor lists
// before dereferencing them.
int AssignAdd::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
  if (inputs_.size() < 2 || outputs_.empty()) {
    MS_LOG(ERROR) << "AssignAdd should have 2 inputs and 1 output";
    return RET_ERROR;
  }
  Tensor *x = inputs_[0];
  Tensor *y = inputs_[1];
  Tensor *out = outputs_[0];
  if (x->data_type() != y->data_type()) {
    MS_LOG(ERROR) << "data types of x and y do not match";
    return RET_ERROR;
  }
  // Output is shaped exactly like x (element-wise accumulation).
  out->set_shape(x->shape());
  out->SetFormat(x->GetFormat());
  out->set_data_type(x->data_type());
  return RET_OK;
}
} // namespace lite
} // namespace mindspore

+ 41
- 0
mindspore/lite/src/ops/assign_add.h View File

@@ -0,0 +1,41 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

// Fixed: the include guard must precede the #include directives; previously
// the headers were processed before the guard, defeating its purpose on
// repeated inclusion.
#ifndef LITE_SRC_OPS_ASSIGN_ADD_H_
#define LITE_SRC_OPS_ASSIGN_ADD_H_

#include <vector>
#include <set>
#include <cmath>
#include "src/ops/primitive_c.h"

namespace mindspore {
namespace lite {
// AssignAdd primitive (presumably x += y — verify against the kernel);
// carries no attributes.
class AssignAdd : public PrimitiveC {
 public:
#ifdef PRIMITIVE_WRITEABLE
  MS_DECLARE_PARENT(AssignAdd, PrimitiveC);
  AssignAdd() = default;
  explicit AssignAdd(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
  // Builds the (empty) AssignAddT payload from the ANF primitive.
  int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
#else
  AssignAdd() = default;

  // Re-serializes the primitive into a flatbuffer builder.
  int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
  int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
};
}  // namespace lite
}  // namespace mindspore
#endif  // LITE_SRC_OPS_ASSIGN_ADD_H_

+ 6
- 1
mindspore/lite/src/ops/bias_add.cc View File

@@ -47,7 +47,12 @@ int BiasAdd::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &in
MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR;
}
attr->axis = {0};
if (prim.GetAttr("axis") == nullptr) {
MS_LOG(WARNING) << "get axis failed";
attr->axis = {1};
} else {
attr->axis = GetValue<std::vector<int>>(prim.GetAttr("axis"));
}
this->primitive_->value.value = attr;
if (this->primitive_->value.value == nullptr) {
MS_LOG(ERROR) << "primitive value is nullptr";


+ 106
- 0
mindspore/lite/src/ops/binary_cross_entropy.cc View File

@@ -0,0 +1,106 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <string>
#include "src/ops/binary_cross_entropy.h"

namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
// Reads the reduction mode (0 = none, 1 = mean, 2 = sum) from the writable
// primitive.
int BinaryCrossEntropy::GetReduction() const { return this->primitive_->value.AsBinaryCrossEntropy()->reduction; }

// Populates primitive_ from the ANF primitive, translating the string
// "reduction" attribute into the integer encoding above. Returns RET_OK on
// success; RET_ERROR on allocation failure, a type mismatch, or a missing
// "reduction" attribute.
int BinaryCrossEntropy::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
  if (this->primitive_ == nullptr) {
    this->primitive_ = new (std::nothrow) schema::PrimitiveT;
    if (this->primitive_ == nullptr) {
      MS_LOG(ERROR) << "new primitive error";
      return RET_ERROR;
    }
    this->primitive_->value.type = schema::PrimitiveType_BinaryCrossEntropy;
  }
  if (this->primitive_->value.type != schema::PrimitiveType_BinaryCrossEntropy) {
    MS_LOG(ERROR) << "PrimitiveType_BinaryCrossEntropy primitive value type :  "
                  << schema::EnumNamePrimitiveType(primitive_->value.type) << "is not equal"
                  << schema::EnumNamePrimitiveType(schema::PrimitiveType_BinaryCrossEntropy);
    delete this->primitive_;
    return RET_ERROR;
  }
  if (this->primitive_->value.value == nullptr) {
    schema::BinaryCrossEntropyT *attr = new (std::nothrow) schema::BinaryCrossEntropyT();
    if (attr == nullptr) {
      MS_LOG(ERROR) << "new binary cross entropy attr failed!";
      delete this->primitive_;
      return RET_ERROR;
    }
    // default is mean
    string reduction = "mean";
    if (prim.GetAttr("reduction") == nullptr) {
      MS_LOG(ERROR) << "get reduction failed!";
      delete this->primitive_;
      delete attr;
      return RET_ERROR;
    } else {
      reduction = GetValue<string>(prim.GetAttr("reduction"));
    }
    // Map the string attribute onto the schema's integer encoding.
    if (reduction == "none") {
      attr->reduction = 0;
    } else if (reduction == "sum") {
      attr->reduction = 2;
    } else {
      // default is mean
      attr->reduction = 1;
    }
    this->primitive_->value.value = attr;
  }

  return RET_OK;
}
#else
// Re-serializes a BinaryCrossEntropy primitive into `fbb`, copying the
// reduction attribute. Returns RET_ERROR if `primitive` does not actually
// hold a BinaryCrossEntropy value.
int BinaryCrossEntropy::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
  MS_ASSERT(nullptr != primitive);
  MS_ASSERT(nullptr != fbb);
  auto attr = primitive->value_as_BinaryCrossEntropy();
  if (attr == nullptr) {
    MS_LOG(ERROR) << "value_as_BinaryCrossEntropy return nullptr";
    return RET_ERROR;
  }
  int reduction = attr->reduction();
  auto val_offset = schema::CreateBinaryCrossEntropy(*fbb, reduction);
  auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_BinaryCrossEntropy, val_offset.o);
  fbb->Finish(prim_offset);
  return RET_OK;
}

// Read-only accessor for the reduction mode in the flatbuffer-backed build.
int BinaryCrossEntropy::GetReduction() const { return this->primitive_->value_as_BinaryCrossEntropy()->reduction(); }
#endif
// Shape inference for BinaryCrossEntropy. With "mean" (1) or "sum" (2)
// reduction the loss collapses to a single element; with "none" (0) it keeps
// the shape of the first input tensor.
int BinaryCrossEntropy::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
  Tensor *input = inputs_[0];
  Tensor *output = outputs_[0];
  output->SetFormat(input->GetFormat());
  output->set_data_type(input->data_type());
  const int reduction_mode = GetReduction();
  const bool reduces_to_scalar = (reduction_mode == 1) || (reduction_mode == 2);
  if (reduces_to_scalar) {
    output->set_shape({1});
  } else {
    output->set_shape(input->shape());
  }
  return RET_OK;
}
} // namespace lite
} // namespace mindspore

+ 49
- 0
mindspore/lite/src/ops/binary_cross_entropy.h View File

@@ -0,0 +1,49 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

// Fixed: the include guard now wraps the #include directives as well; they
// were previously processed before the guard, which defeats its purpose.
#ifndef LITE_SRC_OPS_BINARYCROSSENTROPY_H_
#define LITE_SRC_OPS_BINARYCROSSENTROPY_H_

#include <vector>
#include <set>
#include <cmath>
#include "src/ops/primitive_c.h"

namespace mindspore {
namespace lite {
// Binary cross-entropy loss primitive; the "reduction" attribute selects
// none (0) / mean (1, default) / sum (2).
class BinaryCrossEntropy : public PrimitiveC {
 public:
#ifdef PRIMITIVE_WRITEABLE
  MS_DECLARE_PARENT(BinaryCrossEntropy, PrimitiveC);

  BinaryCrossEntropy() = default;

  explicit BinaryCrossEntropy(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}

  // Translates the ANF string "reduction" attribute into the integer encoding.
  int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;

  // Returns the reduction mode (0 = none, 1 = mean, 2 = sum).
  int GetReduction() const;
#else
  BinaryCrossEntropy() = default;

  int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;

  int GetReduction() const;
#endif
  int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
};
}  // namespace lite
}  // namespace mindspore
#endif  // LITE_SRC_OPS_BINARYCROSSENTROPY_H_

+ 108
- 0
mindspore/lite/src/ops/binary_cross_entropy_grad.cc View File

@@ -0,0 +1,108 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <string>
#include "src/ops/binary_cross_entropy_grad.h"

namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE

// Reads the reduction mode (0 = none, 1 = mean, 2 = sum) from the writable
// primitive; it mirrors the forward BinaryCrossEntropy op's encoding.
int BinaryCrossEntropyGrad::GetReduction() const {
  return this->primitive_->value.AsBinaryCrossEntropyGrad()->reduction;
}

// Populates primitive_ from the ANF primitive, translating the string
// "reduction" attribute into the integer encoding above. Returns RET_OK on
// success; RET_ERROR on allocation failure, a type mismatch, or a missing
// "reduction" attribute.
int BinaryCrossEntropyGrad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
  if (this->primitive_ == nullptr) {
    this->primitive_ = new (std::nothrow) schema::PrimitiveT;
    if (this->primitive_ == nullptr) {
      MS_LOG(ERROR) << "new primitive error";
      return RET_ERROR;
    }
    this->primitive_->value.type = schema::PrimitiveType_BinaryCrossEntropyGrad;
  }
  if (this->primitive_->value.type != schema::PrimitiveType_BinaryCrossEntropyGrad) {
    MS_LOG(ERROR) << "PrimitiveType_BinaryCrossEntropyGrad primitive value type :  "
                  << schema::EnumNamePrimitiveType(primitive_->value.type) << "is not equal"
                  << schema::EnumNamePrimitiveType(schema::PrimitiveType_BinaryCrossEntropyGrad);
    delete this->primitive_;
    return RET_ERROR;
  }
  if (this->primitive_->value.value == nullptr) {
    schema::BinaryCrossEntropyGradT *attr = new (std::nothrow) schema::BinaryCrossEntropyGradT();
    if (attr == nullptr) {
      MS_LOG(ERROR) << "new binary cross entropy attr failed!";
      delete this->primitive_;
      return RET_ERROR;
    }
    // default is mean
    string reduction = "mean";
    if (prim.GetAttr("reduction") == nullptr) {
      MS_LOG(ERROR) << "get reduction failed!";
      delete this->primitive_;
      delete attr;
      return RET_ERROR;
    } else {
      reduction = GetValue<string>(prim.GetAttr("reduction"));
    }

    // Map the string attribute onto the schema's integer encoding.
    if (reduction == "none") {
      attr->reduction = 0;
    } else if (reduction == "sum") {
      attr->reduction = 2;
    } else {
      // default is mean
      attr->reduction = 1;
    }
    this->primitive_->value.value = attr;
  }

  return RET_OK;
}
#else
// Re-serializes a BinaryCrossEntropyGrad primitive into `fbb`, copying the
// reduction attribute. Returns RET_ERROR if `primitive` does not actually
// hold a BinaryCrossEntropyGrad value.
int BinaryCrossEntropyGrad::UnPackToFlatBuilder(const schema::Primitive *primitive,
                                                flatbuffers::FlatBufferBuilder *fbb) {
  MS_ASSERT(nullptr != primitive);
  MS_ASSERT(nullptr != fbb);
  auto attr = primitive->value_as_BinaryCrossEntropyGrad();
  if (attr == nullptr) {
    MS_LOG(ERROR) << "value_as_BinaryCrossEntropyGrad return nullptr";
    return RET_ERROR;
  }
  int reduction = attr->reduction();
  auto val_offset = schema::CreateBinaryCrossEntropyGrad(*fbb, reduction);
  auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_BinaryCrossEntropyGrad, val_offset.o);
  fbb->Finish(prim_offset);
  return RET_OK;
}

// Read-only accessor for the reduction mode in the flatbuffer-backed build.
int BinaryCrossEntropyGrad::GetReduction() const {
  return this->primitive_->value_as_BinaryCrossEntropyGrad()->reduction();
}
#endif
// Shape inference for BinaryCrossEntropyGrad: the gradient is element-wise
// with respect to the first input, so the output clones that input's format,
// data type and shape.
int BinaryCrossEntropyGrad::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
  Tensor *input = inputs_[0];
  Tensor *grad_out = outputs_[0];
  grad_out->SetFormat(input->GetFormat());
  grad_out->set_data_type(input->data_type());
  grad_out->set_shape(input->shape());
  return RET_OK;
}
} // namespace lite
} // namespace mindspore

+ 49
- 0
mindspore/lite/src/ops/binary_cross_entropy_grad.h View File

@@ -0,0 +1,49 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

// Fixed: the include guard now wraps the #include directives as well; they
// were previously processed before the guard, which defeats its purpose.
#ifndef LITE_SRC_OPS_BINARY_CROSS_ENTROPY_GRAD_H_
#define LITE_SRC_OPS_BINARY_CROSS_ENTROPY_GRAD_H_

#include <vector>
#include <set>
#include <cmath>
#include "src/ops/primitive_c.h"

namespace mindspore {
namespace lite {
// Gradient of BinaryCrossEntropy; "reduction" uses the same encoding as the
// forward op (0 = none, 1 = mean, 2 = sum).
class BinaryCrossEntropyGrad : public PrimitiveC {
 public:
#ifdef PRIMITIVE_WRITEABLE
  MS_DECLARE_PARENT(BinaryCrossEntropyGrad, PrimitiveC);

  BinaryCrossEntropyGrad() = default;

  explicit BinaryCrossEntropyGrad(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}

  // Translates the ANF string "reduction" attribute into the integer encoding.
  int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;

  int GetReduction() const;
#else
  BinaryCrossEntropyGrad() = default;

  int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;

  int GetReduction() const;
#endif
  int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
};
}  // namespace lite
}  // namespace mindspore
#endif  // LITE_SRC_OPS_BINARY_CROSS_ENTROPY_GRAD_H_

+ 59
- 0
mindspore/lite/src/ops/control_depend.cc View File

@@ -0,0 +1,59 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ops/control_depend.h"
#include <vector>
#include <memory>

namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
// Populates primitive_ as a ControlDepend primitive when importing from the
// ANF graph. ControlDepend has no attributes, so only the empty ControlDependT
// payload is allocated. Returns RET_OK on success, RET_ERROR on allocation
// failure or if a pre-existing primitive has a different type.
int ControlDepend::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
  if (this->primitive_ == nullptr) {
    this->primitive_ = new (std::nothrow) schema::PrimitiveT;
    if (this->primitive_ == nullptr) {
      MS_LOG(ERROR) << "new primitiveT failed";
      return RET_ERROR;
    }
    this->primitive_->value.type = schema::PrimitiveType_ControlDepend;
  }
  if (this->primitive_->value.type != schema::PrimitiveType_ControlDepend) {
    MS_LOG(ERROR) << "primitive_ type is error:" << this->primitive_->value.type;
    delete this->primitive_;
    return RET_ERROR;
  }
  if (this->primitive_->value.value == nullptr) {
    // The empty attribute table must still exist for serialization.
    auto attr = new (std::nothrow)(schema::ControlDependT);
    if (attr == nullptr) {
      MS_LOG(ERROR) << "attr is nullptr";
      delete this->primitive_;
      return RET_ERROR;
    }
    this->primitive_->value.value = attr;
  }
  return RET_OK;
}
#else
// Re-serializes a ControlDepend primitive into `fbb`. No attribute lookup is
// needed because the schema table is empty.
int ControlDepend::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
  MS_ASSERT(nullptr != primitive);
  MS_ASSERT(nullptr != fbb);
  auto val_offset = schema::CreateControlDepend(*fbb);
  auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_ControlDepend, val_offset.o);
  fbb->Finish(prim_offset);
  return RET_OK;
}
#endif
} // namespace lite
} // namespace mindspore

+ 40
- 0
mindspore/lite/src/ops/control_depend.h View File

@@ -0,0 +1,40 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef LITE_MINDSPORE_LITE_SRC_OPS_CONTROL_DEPEND_H_
#define LITE_MINDSPORE_LITE_SRC_OPS_CONTROL_DEPEND_H_

#include <vector>
#include "src/ops/primitive_c.h"

namespace mindspore {
namespace lite {
// ControlDepend primitive: carries no attributes and declares no InferShape
// override. Presumably expresses execution-order dependencies between nodes
// rather than producing tensor data — verify against the exporter/runtime.
class ControlDepend : public PrimitiveC {
 public:
#ifdef PRIMITIVE_WRITEABLE
  MS_DECLARE_PARENT(ControlDepend, PrimitiveC);
  ControlDepend() = default;
  explicit ControlDepend(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
  // Builds the (empty) ControlDependT payload from the ANF primitive.
  int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
#else
  ControlDepend() = default;
  // Re-serializes the primitive into a flatbuffer builder.
  int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
};
}  // namespace lite
}  // namespace mindspore

#endif  // LITE_MINDSPORE_LITE_SRC_OPS_CONTROL_DEPEND_H_

+ 37
- 3
mindspore/lite/src/ops/expand_dims.cc View File

@@ -27,6 +27,43 @@ int ExpandDims::GetDim() const { return this->primitive_->value.AsExpandDims()->

void ExpandDims::SetDim(int dim) { this->primitive_->value.AsExpandDims()->dim = dim; }

// Populates primitive_ as an ExpandDims primitive when importing from the ANF
// graph. The expansion dimension is not an attribute on `prim`; it is read
// from the second input node, which must be a constant ValueNode holding the
// axis. Returns RET_OK on success, RET_ERROR on allocation failure, a type
// mismatch, or a non-constant axis input.
int ExpandDims::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
  if (this->primitive_ == nullptr) {
    this->primitive_ = new (std::nothrow) schema::PrimitiveT;
    if (this->primitive_ == nullptr) {
      MS_LOG(ERROR) << "new primitiveT failed";
      return RET_ERROR;
    }
    this->primitive_->value.type = schema::PrimitiveType_ExpandDims;
  }
  if (this->primitive_->value.type != schema::PrimitiveType_ExpandDims) {
    MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
    delete this->primitive_;
    return RET_ERROR;
  }
  if (this->primitive_->value.value == nullptr) {
    auto attr = new (std::nothrow) schema::ExpandDimsT();
    if (attr == nullptr) {
      delete this->primitive_;
      MS_LOG(ERROR) << "new primitiveT value failed";
      return RET_ERROR;
    }
    // use axis instead of dim
    if (inputs[1]->isa<ValueNode>()) {
      auto axis_tensor = inputs[1]->cast<ValueNodePtr>();
      int axis = GetValue<int>(axis_tensor->value());
      attr->dim = axis;
    } else {
      MS_LOG(ERROR) << "input axis is not value node.";
      delete this->primitive_;
      delete attr;
      return RET_ERROR;
    }
    this->primitive_->value.value = attr;
  }
  return RET_OK;
}

#else
int ExpandDims::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
@@ -56,9 +93,6 @@ int ExpandDims::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *>
MS_ASSERT(input != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
if (inputs_.size() != kSingleNum) {
MS_LOG(ERROR) << "input size is invalid";
}
if (outputs_.size() != kSingleNum) {
MS_LOG(ERROR) << "output size is invalid";
}


+ 1
- 0
mindspore/lite/src/ops/expand_dims.h View File

@@ -31,6 +31,7 @@ class ExpandDims : public PrimitiveC {
ExpandDims() = default;
explicit ExpandDims(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetDim(int dim);
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;

#else
ExpandDims() = default;


+ 29
- 1
mindspore/lite/src/ops/gather.cc View File

@@ -31,7 +31,35 @@ int Gather::GetBatchDims() const { return this->primitive_->value.AsGather()->ba

void Gather::SetAxis(int axis) { this->primitive_->value.AsGather()->axis = axis; }
void Gather::SetBatchDims(int batch_dims) { this->primitive_->value.AsGather()->batchDims = batch_dims; }

int Gather::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
if (this->primitive_ == nullptr) {
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
if (this->primitive_ == nullptr) {
MS_LOG(ERROR) << "new primitive error";
return RET_ERROR;
}
this->primitive_->value.type = schema::PrimitiveType_Gather;
}
if (this->primitive_->value.type != schema::PrimitiveType_Gather) {
MS_LOG(ERROR) << "Gather primitive value type : " << schema::EnumNamePrimitiveType(primitive_->value.type)
<< "is not equal" << schema::EnumNamePrimitiveType(schema::PrimitiveType_Gather);
delete this->primitive_;
return RET_ERROR;
}
if (this->primitive_->value.value == nullptr) {
auto gather_attr = new (std::nothrow) schema::GatherT();
if (gather_attr == nullptr) {
MS_LOG(ERROR) << "new primitive value.value error";
delete this->primitive_;
delete gather_attr;
return RET_ERROR;
}
gather_attr->axis = GetValue<int>(prim.GetAttr("axis"));
gather_attr->batchDims = GetValue<int>(prim.GetAttr("batchDims"));
this->primitive_->value.value = gather_attr;
}
return RET_OK;
}
#else
int Gather::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);


+ 1
- 0
mindspore/lite/src/ops/gather.h View File

@@ -33,6 +33,7 @@ class Gather : public PrimitiveC {
explicit Gather(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetAxis(int axis);
void SetBatchDims(int batch_dims);
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
#else
Gather() = default;



+ 75
- 0
mindspore/lite/src/ops/oneslike.cc View File

@@ -0,0 +1,75 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "src/ops/oneslike.h"

namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
// Populates primitive_ as an OnesLike primitive when importing from the ANF
// graph. OnesLike carries no attributes, so only the empty OnesLikeT payload
// is allocated. Returns RET_OK on success, RET_ERROR on allocation failure
// or if a pre-existing primitive has a different type.
int OnesLike::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
  if (this->primitive_ == nullptr) {
    this->primitive_ = new (std::nothrow) schema::PrimitiveT;
    if (this->primitive_ == nullptr) {
      MS_LOG(ERROR) << "new primitive error";
      return RET_ERROR;
    }
    this->primitive_->value.type = schema::PrimitiveType_OnesLike;
  }
  if (this->primitive_->value.type != schema::PrimitiveType_OnesLike) {
    MS_LOG(ERROR) << "PrimitiveType_OnesLike primitive value type :  "
                  << schema::EnumNamePrimitiveType(primitive_->value.type) << "is not equal"
                  << schema::EnumNamePrimitiveType(schema::PrimitiveType_OnesLike);
    delete this->primitive_;
    return RET_ERROR;
  }
  if (this->primitive_->value.value == nullptr) {
    // The empty attribute table must still exist for serialization.
    this->primitive_->value.value = new (std::nothrow) schema::OnesLikeT();
    if (this->primitive_->value.value == nullptr) {
      MS_LOG(ERROR) << "new primitiveT value failed";
      delete this->primitive_;
      return RET_ERROR;
    }
  }
  return RET_OK;
}
#else
// Re-serializes an OnesLike primitive into `fbb`. The table has no fields;
// the value_as_OnesLike() check only validates the union's type.
int OnesLike::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
  MS_ASSERT(nullptr != primitive);
  MS_ASSERT(nullptr != fbb);
  auto attr = primitive->value_as_OnesLike();
  if (attr == nullptr) {
    MS_LOG(ERROR) << "value_as_OnesLike return nullptr";
    return RET_ERROR;
  }
  auto val_offset = schema::CreateOnesLike(*fbb);
  auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_OnesLike, val_offset.o);
  fbb->Finish(prim_offset);
  return RET_OK;
}
#endif
// Shape inference for OnesLike: the output is shaped exactly like the input,
// with the same format and data type.
int OnesLike::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
  Tensor *input = inputs_[0];
  Tensor *output = outputs_[0];
  output->set_shape(input->shape());
  output->SetFormat(input->GetFormat());
  output->set_data_type(input->data_type());
  return RET_OK;
}
} // namespace lite
} // namespace mindspore

+ 42
- 0
mindspore/lite/src/ops/oneslike.h View File

@@ -0,0 +1,42 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

// Fixed: the include guard now wraps the #include directives as well; they
// were previously processed before the guard, which defeats its purpose.
#ifndef LITE_SRC_OPS_ONESLIKE_H_
#define LITE_SRC_OPS_ONESLIKE_H_

#include <vector>
#include <set>
#include <cmath>
#include "src/ops/primitive_c.h"

namespace mindspore {
namespace lite {
// OnesLike primitive: output mirrors the input's shape/format/dtype (see
// InferShape); carries no attributes.
class OnesLike : public PrimitiveC {
 public:
#ifdef PRIMITIVE_WRITEABLE
  MS_DECLARE_PARENT(OnesLike, PrimitiveC);
  OnesLike() = default;
  explicit OnesLike(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
  // Builds the (empty) OnesLikeT payload from the ANF primitive.
  int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
#else
  OnesLike() = default;

  // Re-serializes the primitive into a flatbuffer builder.
  int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
  int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
};
}  // namespace lite
}  // namespace mindspore
#endif  // LITE_SRC_OPS_ONESLIKE_H_

+ 45
- 0
mindspore/lite/src/ops/power.cc View File

@@ -31,6 +31,51 @@ void Power::SetPower(float power) { this->primitive_->value.AsPower()->power = p
void Power::SetScale(float scale) { this->primitive_->value.AsPower()->scale = scale; }
void Power::SetShift(float shift) { this->primitive_->value.AsPower()->shift = shift; }

// Populates primitive_ as a Power primitive, copying the "scale", "power" and
// "shift" attributes from the ANF primitive. Missing attributes are tolerated
// with a warning and fall back to defaults (scale = 1.0, power = 1.0,
// shift = 0 — presumably matching the op's identity behavior; verify against
// the exporting frontend). Returns RET_OK on success, RET_ERROR on allocation
// failure or a type mismatch.
int Power::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
  if (this->primitive_ == nullptr) {
    this->primitive_ = new (std::nothrow) schema::PrimitiveT;
    if (this->primitive_ == nullptr) {
      MS_LOG(ERROR) << "new primitiveT failed";
      return RET_ERROR;
    }
    this->primitive_->value.type = schema::PrimitiveType_Power;
  }
  if (this->primitive_->value.type != schema::PrimitiveType_Power) {
    MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
    delete this->primitive_;
    return RET_ERROR;
  }
  if (this->primitive_->value.value == nullptr) {
    auto attr = new (std::nothrow) schema::PowerT();
    if (attr == nullptr) {
      MS_LOG(ERROR) << "new primitiveT value failed";
      delete this->primitive_;
      return RET_ERROR;
    }

    if (prim.GetAttr("scale") == nullptr) {
      MS_LOG(WARNING) << "get scale failed";
      attr->scale = 1.0f;
    } else {
      attr->scale = GetValue<float>(prim.GetAttr("scale"));
    }
    if (prim.GetAttr("power") == nullptr) {
      MS_LOG(WARNING) << "get power failed";
      attr->power = 1.0f;
    } else {
      attr->power = GetValue<float>(prim.GetAttr("power"));
    }
    if (prim.GetAttr("shift") == nullptr) {
      MS_LOG(WARNING) << "get shift failed";
      attr->shift = 0;
    } else {
      attr->shift = GetValue<float>(prim.GetAttr("shift"));
    }
    this->primitive_->value.value = attr;
  }
  return RET_OK;
}

#else

float Power::GetPower() const { return this->primitive_->value_as_Power()->power(); }


+ 1
- 0
mindspore/lite/src/ops/power.h View File

@@ -34,6 +34,7 @@ class Power : public PrimitiveC {
void SetPower(float power);
void SetScale(float scale);
void SetShift(float shift);
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
#else
Power() = default;



+ 39
- 0
mindspore/lite/src/ops/primitive_c.cc View File

@@ -145,6 +145,8 @@
#include "src/ops/identity.h"
#include "src/ops/instance_norm.h"
#include "src/ops/while.h"
#include "src/ops/oneslike.h"
#include "src/ops/unsorted_segment_sum.h"

#ifdef SUPPORT_TRAIN
#include "src/ops/neg_grad.h"
@@ -165,6 +167,10 @@
#include "src/ops/sgd.h"
#include "src/ops/adam.h"
#include "src/ops/assign.h"
#include "src/ops/control_depend.h"
#include "src/ops/assign_add.h"
#include "src/ops/binary_cross_entropy.h"
#include "src/ops/binary_cross_entropy_grad.h"
#endif

#endif
@@ -504,6 +510,18 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std:
return NewPrimitiveC<While>(prim, inputs, quantType);
} else if (op_type == "OneHot") {
return NewPrimitiveC<OneHot>(prim, inputs, quantType);
} else if (op_type == "GatherV2") {
return NewPrimitiveC<Gather>(prim, inputs, quantType);
} else if (op_type == "OnesLike") {
return NewPrimitiveC<OnesLike>(prim, inputs, quantType);
} else if (op_type == "Pow") {
return NewPrimitiveC<Power>(prim, inputs, quantType);
} else if (op_type == "Sub") {
return NewPrimitiveC<Sub>(prim, inputs, quantType);
} else if (op_type == "ExpandDims") {
return NewPrimitiveC<ExpandDims>(prim, inputs, quantType);
} else if (op_type == "UnsortedSegmentSum") {
return NewPrimitiveC<UnsortedSegmentSum>(prim, inputs, quantType);

#ifdef SUPPORT_TRAIN
} else if (op_type == "SoftmaxCrossEntropyWithLogits") {
@@ -514,6 +532,8 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std:
return NewPrimitiveC<ApplyMomentum>(prim, inputs, quantType);
} else if (op_type == "Depend") {
return NewPrimitiveC<Depend>(prim, inputs, quantType);
} else if (op_type == "ControlDepend") {
return NewPrimitiveC<ControlDepend>(prim, inputs, quantType);
} else if ((op_type == "ReluGrad" || op_type == "ReLU6Grad" || op_type == "SigmoidGrad" ||
op_type == "HSigmoidGrad" || op_type == "HSwishGrad")) {
return NewPrimitiveC<ActivationGrad>(prim, inputs, quantType);
@@ -539,6 +559,12 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std:
return NewPrimitiveC<Adam>(prim, inputs, quantType);
} else if (op_type == "Assign") {
return NewPrimitiveC<Assign>(prim, inputs, quantType);
} else if (op_type == "AssignAdd") {
return NewPrimitiveC<AssignAdd>(prim, inputs, quantType);
} else if (op_type == "BinaryCrossEntropy") {
return NewPrimitiveC<BinaryCrossEntropy>(prim, inputs, quantType);
} else if (op_type == "BinaryCrossEntropyGrad") {
return NewPrimitiveC<BinaryCrossEntropyGrad>(prim, inputs, quantType);
#else
} else if (op_type == "Conv2DBackpropInput") {
return NewPrimitiveC<DeConv2D>(prim, inputs, quantType);
@@ -830,6 +856,8 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) {
return new PowerGrad(primitive);
case schema::PrimitiveType_Depend:
return new Depend(primitive);
case schema::PrimitiveType_ControlDepend:
return new ControlDepend(primitive);
case schema::PrimitiveType_FlattenGrad:
return new FlattenGrad(primitive);
case schema::PrimitiveType_NegGrad:
@@ -842,6 +870,17 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) {
return new Adam(primitive);
case schema::PrimitiveType_Assign:
return new Assign(primitive);
case schema::PrimitiveType_AssignAdd:
return new AssignAdd(primitive);
case schema::PrimitiveType_OnesLike:
return new OnesLike(primitive);
case schema::PrimitiveType_UnsortedSegmentSum:
return new UnsortedSegmentSum(primitive);
case schema::PrimitiveType_BinaryCrossEntropyGrad:
return new BinaryCrossEntropyGrad(primitive);
case schema::PrimitiveType_BinaryCrossEntropy:
return new BinaryCrossEntropy(primitive);

#endif
default:
MS_LOG(ERROR) << "Unsupported primitive type in Create : " << schema::EnumNamePrimitiveType(op_type);


+ 3
- 0
mindspore/lite/src/ops/reduce.cc View File

@@ -86,6 +86,9 @@ int Reduce::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inp
MS_ASSERT(elem != nullptr);
attr->axes.emplace_back(elem->value());
}
} else {
int axes_item = GetValue<int>(value);
attr->axes.push_back(axes_item);
}
}
}


+ 3
- 0
mindspore/lite/src/ops/reshape.cc View File

@@ -62,6 +62,9 @@ int Reshape::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &in
MS_ASSERT(elem != nullptr);
attr->shape.emplace_back(static_cast<int>(elem->value()));
}
} else {
int dim = GetValue<int>(val);
attr->shape = {dim};
}
}
if (attr == nullptr) {


+ 7
- 5
mindspore/lite/src/ops/squeeze.cc View File

@@ -46,12 +46,14 @@ int Squeeze::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &in
MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR;
}
attr->axis = GetValue<std::vector<int>>(prim.GetAttr("axis"));
this->primitive_->value.value = attr;
if (this->primitive_->value.value == nullptr) {
MS_LOG(ERROR) << "primitive value is nullptr";
return RET_ERROR;
if (prim.GetAttr("axis") == nullptr) {
MS_LOG(WARNING) << "get axis failed";
attr->axis = {0};
} else {
int axis = GetValue<int>(prim.GetAttr("axis"));
attr->axis = {axis};
}
this->primitive_->value.value = attr;
}
return RET_OK;
}


+ 28
- 0
mindspore/lite/src/ops/sub.cc View File

@@ -29,6 +29,34 @@ void Sub::SetActivationType(int activation_type) {
this->primitive_->value.AsSub()->activationType = (schema::ActivationType)activation_type;
}

// Populates the writeable PrimitiveT for Sub from an ANF primitive.
// No attributes are read from `prim`; the fused activation defaults to
// NO_ACTIVATION. Returns RET_OK on success, RET_ERROR on allocation
// failure or op-type mismatch.
int Sub::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
  // Lazily create the writeable primitive and tag it as Sub.
  if (this->primitive_ == nullptr) {
    this->primitive_ = new (std::nothrow) schema::PrimitiveT;
    if (this->primitive_ == nullptr) {
      MS_LOG(ERROR) << "new primitiveT failed";
      return RET_ERROR;
    }
    this->primitive_->value.type = schema::PrimitiveType_Sub;
  }
  if (this->primitive_->value.type != schema::PrimitiveType_Sub) {
    MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
    delete this->primitive_;
    this->primitive_ = nullptr;  // avoid dangling pointer / double delete in dtor
    return RET_ERROR;
  }
  if (this->primitive_->value.value == nullptr) {
    auto attr = new (std::nothrow) schema::SubT();
    if (attr == nullptr) {
      MS_LOG(ERROR) << "new primitiveT value failed";
      delete this->primitive_;
      this->primitive_ = nullptr;  // avoid dangling pointer / double delete in dtor
      return RET_ERROR;
    }
    // TODO(review): confirm whether the source graph can carry a fused
    // activation that should be unpacked here instead of this default.
    attr->activationType = schema::ActivationType_NO_ACTIVATION;
    this->primitive_->value.value = attr;  // PrimitiveT union takes ownership
  }
  return RET_OK;
}

#else

int Sub::GetActivationType() const { return this->primitive_->value_as_Sub()->activationType(); }


+ 1
- 0
mindspore/lite/src/ops/sub.h View File

@@ -32,6 +32,7 @@ class Sub : public Arithmetic {
Sub() = default;
explicit Sub(schema::PrimitiveT *primitive) : Arithmetic(primitive) {}
void SetActivationType(int activation_type);
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;

#else
Sub() = default;


+ 9
- 0
mindspore/lite/src/ops/tile.cc View File

@@ -52,6 +52,12 @@ int Tile::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &input
MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR;
}
if (prim.GetAttr("dims") == nullptr) {
MS_LOG(WARNING) << "get dims failed";
attr->dims = {1};
} else {
attr->dims = GetValue<std::vector<int>>(prim.GetAttr("dims"));
}
if (inputs.size() == kAnfPopulaterTwo) {
auto inputNode = inputs[kAnfPopulaterOne];
MS_ASSERT(inputNode != nullptr);
@@ -68,6 +74,9 @@ int Tile::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &input
MS_ASSERT(elem != nullptr);
attr->multiples.emplace_back(elem->value());
}
} else {
int multiple = GetValue<int>(value);
attr->multiples = {multiple};
}
}
}


+ 100
- 0
mindspore/lite/src/ops/unsorted_segment_sum.cc View File

@@ -0,0 +1,100 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <memory>
#include "src/ops/unsorted_segment_sum.h"

namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE

// Returns the numSegments attribute from the writeable PrimitiveT union.
// NOTE(review): assumes primitive_ and its value were already populated
// (e.g. via UnPackAttr) — dereferences without null checks.
int UnsortedSegmentSum::GetNumSegments() const { return this->primitive_->value.AsUnsortedSegmentSum()->numSegments; }

int UnsortedSegmentSum::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
if (this->primitive_ == nullptr) {
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
if (this->primitive_ == nullptr) {
MS_LOG(ERROR) << "new primitive error";
return RET_ERROR;
}
this->primitive_->value.type = schema::PrimitiveType_UnsortedSegmentSum;
}
if (this->primitive_->value.type != schema::PrimitiveType_UnsortedSegmentSum) {
MS_LOG(ERROR) << "UnSortedSegmentSum primitive value type : "
<< schema::EnumNamePrimitiveType(primitive_->value.type) << "is not equal"
<< schema::EnumNamePrimitiveType(schema::PrimitiveType_UnsortedSegmentSum);
delete this->primitive_;
return RET_ERROR;
}
if (this->primitive_->value.value == nullptr) {
std::unique_ptr<schema::UnsortedSegmentSumT> attr = std::make_unique<schema::UnsortedSegmentSumT>();
if (inputs[2]->isa<ValueNode>()) {
ValuePtr value = inputs[2]->cast<ValueNodePtr>()->value();
attr->numSegments = GetValue<int>(value);
this->primitive_->value.value = attr.release();
}
}
return RET_OK;
}
#else
// Re-serializes an UnsortedSegmentSum primitive into the given flatbuffer
// builder. Returns RET_ERROR if the primitive does not carry an
// UnsortedSegmentSum attribute table.
int UnsortedSegmentSum::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
  MS_ASSERT(primitive != nullptr);
  MS_ASSERT(fbb != nullptr);
  const auto *src_attr = primitive->value_as_UnsortedSegmentSum();
  if (src_attr == nullptr) {
    MS_LOG(ERROR) << "value_as_UnsortedSegmentSum return nullptr";
    return RET_ERROR;
  }
  auto val_offset = schema::CreateUnsortedSegmentSum(*fbb, src_attr->numSegments());
  auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_UnsortedSegmentSum, val_offset.o);
  fbb->Finish(prim_offset);
  return RET_OK;
}

// Returns the numSegments attribute from the read-only flatbuffer primitive.
int UnsortedSegmentSum::GetNumSegments() const {
  return this->primitive_->value_as_UnsortedSegmentSum()->numSegments();
}
#endif
// Infers the output shape for UnsortedSegmentSum.
// Inputs: {data, segment_ids, num_segments}; output shape is
// [numSegments] + data_shape[rank(segment_ids):], with format and dtype
// copied from the data tensor. Returns RET_ERROR on invalid tensor counts,
// null tensors, or a negative numSegments attribute.
int UnsortedSegmentSum::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
  // check inputs and outputs
  if (inputs_.size() != 3) {
    MS_LOG(ERROR) << "invalid inputs numbers";
    return RET_ERROR;
  }
  if (outputs_.size() != 1) {
    MS_LOG(ERROR) << "invalid outputs numbers";
    return RET_ERROR;
  }
  Tensor *out = outputs_.front();
  Tensor *x = inputs_.front();
  Tensor *segment_id = inputs_.at(1);
  if (out == nullptr || x == nullptr || segment_id == nullptr) {
    MS_LOG(ERROR) << "null tensor in UnsortedSegmentSum InferShape";
    return RET_ERROR;
  }
  std::vector<int> x_shape = x->shape();
  std::vector<int> segment_id_shape = segment_id->shape();
  int num_segments = GetNumSegments();
  if (num_segments < 0) {
    MS_LOG(ERROR) << "invalid numSegments: " << num_segments;
    return RET_ERROR;
  }
  std::vector<int> output_shape;
  output_shape.push_back(num_segments);
  // Append the data dims that are not covered by the segment-id rank
  // (size_t loop avoids the original signed/unsigned comparison).
  for (size_t index = segment_id_shape.size(); index < x_shape.size(); index++) {
    output_shape.push_back(x_shape[index]);
  }
  out->set_shape(output_shape);
  out->SetFormat(x->GetFormat());
  out->set_data_type(x->data_type());
  return RET_OK;
}
} // namespace lite
} // namespace mindspore

+ 44
- 0
mindspore/lite/src/ops/unsorted_segment_sum.h View File

@@ -0,0 +1,44 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef LITE_SRC_OPS_UNSORTED_SEGMENT_SUM_H_
#define LITE_SRC_OPS_UNSORTED_SEGMENT_SUM_H_
// Include guard moved to the top of the header so it wraps the #include
// directives as well (the original guard began after them).

#include <vector>
#include <set>
#include <cmath>
#include "src/ops/primitive_c.h"

namespace mindspore {
namespace lite {
// Primitive wrapper for the UnsortedSegmentSum op: sums slices of the data
// tensor into numSegments output rows keyed by a segment-id tensor.
class UnsortedSegmentSum : public PrimitiveC {
 public:
#ifdef PRIMITIVE_WRITEABLE
  MS_DECLARE_PARENT(UnsortedSegmentSum, PrimitiveC);
  UnsortedSegmentSum() = default;
  explicit UnsortedSegmentSum(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
  // Fills the writeable primitive from an ANF node; numSegments comes from
  // the third input (a constant ValueNode).
  int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
  int GetNumSegments() const;
#else
  UnsortedSegmentSum() = default;

  // Re-serializes the read-only primitive into a flatbuffer builder.
  int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;

  int GetNumSegments() const;
#endif
  // Output shape: [numSegments] + data_shape[rank(segment_ids):].
  int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
};
}  // namespace lite
}  // namespace mindspore
#endif  // LITE_SRC_OPS_UNSORTED_SEGMENT_SUM_H_

+ 34
- 0
mindspore/lite/src/train/train_populate_parameter.cc View File

@@ -36,6 +36,9 @@
#include "src/ops/bn_grad.h"
#include "nnacl/fp32_grad/batch_norm.h"
#include "src/ops/adam.h"
#include "src/ops/oneslike.h"
#include "src/ops/binary_cross_entropy.h"
#include "src/ops/binary_cross_entropy_grad.h"

namespace mindspore::kernel {

@@ -76,6 +79,30 @@ OpParameter *PopulateApplyMomentumParameter(const mindspore::lite::PrimitiveC *p
return reinterpret_cast<OpParameter *>(p);
}

// Builds the op parameter for BinaryCrossEntropy: a single heap int32
// holding the reduction mode, returned as an OpParameter*.
// Returns nullptr on null primitive or allocation failure.
OpParameter *PopulateBCEParameter(const mindspore::lite::PrimitiveC *primitive) {
  if (primitive == nullptr) {
    // Null check added for consistency with PopulateAdamParameter.
    MS_LOG(ERROR) << "Primitive is nullptr when populating parameter for op.";
    return nullptr;
  }
  int32_t *reduction = reinterpret_cast<int32_t *>(malloc(sizeof(int32_t)));
  if (reduction == nullptr) {
    MS_LOG(ERROR) << "malloc reduction failed.";
    return nullptr;
  }
  auto param =
    reinterpret_cast<mindspore::lite::BinaryCrossEntropy *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
  *reduction = param->GetReduction();
  // NOTE(review): the returned buffer is only sizeof(int32_t) but is typed as
  // OpParameter* — any caller that writes OpParameter fields (e.g. type_)
  // will write out of bounds. A dedicated parameter struct embedding
  // OpParameter would be safer; confirm with the consuming kernel.
  return reinterpret_cast<OpParameter *>(reduction);
}

// Builds the op parameter for BinaryCrossEntropyGrad: a single heap int32
// holding the reduction mode, returned as an OpParameter*.
// Returns nullptr on null primitive or allocation failure.
OpParameter *PopulateBCEGradParameter(const mindspore::lite::PrimitiveC *primitive) {
  if (primitive == nullptr) {
    // Null check added for consistency with PopulateAdamParameter.
    MS_LOG(ERROR) << "Primitive is nullptr when populating parameter for op.";
    return nullptr;
  }
  int32_t *reduction = reinterpret_cast<int32_t *>(malloc(sizeof(int32_t)));
  if (reduction == nullptr) {
    MS_LOG(ERROR) << "malloc reduction failed.";
    return nullptr;
  }
  auto param =
    reinterpret_cast<mindspore::lite::BinaryCrossEntropyGrad *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
  *reduction = param->GetReduction();
  // NOTE(review): the returned buffer is only sizeof(int32_t) but is typed as
  // OpParameter* — any caller that writes OpParameter fields (e.g. type_)
  // will write out of bounds. A dedicated parameter struct embedding
  // OpParameter would be safer; confirm with the consuming kernel.
  return reinterpret_cast<OpParameter *>(reduction);
}

OpParameter *PopulateAdamParameter(const mindspore::lite::PrimitiveC *primitive) {
if (primitive == nullptr) {
MS_LOG(ERROR) << "Primitive is nullptr when populating parameter for op.";
@@ -396,6 +423,13 @@ void PopulateTrainParameters() {
lite::Registry BNGradParameterRegistry(schema::PrimitiveType_BNGrad, PopulateBNGradParameter);
lite::Registry AdamParameterRegistry(schema::PrimitiveType_Adam, PopulateAdamParameter);
lite::Registry AssignParameterRegistry(schema::PrimitiveType_Assign, DefaultPopulateParameter);
lite::Registry AssignAddParameterRegistry(schema::PrimitiveType_AssignAdd, DefaultPopulateParameter);
lite::Registry BinaryCrossEntropyParameterRegistry(schema::PrimitiveType_BinaryCrossEntropy, PopulateBCEParameter);
lite::Registry BinaryCrossEntropyGradParameterRegistry(schema::PrimitiveType_BinaryCrossEntropyGrad,
PopulateBCEGradParameter);
lite::Registry OnesLikeParameterRegistry(schema::PrimitiveType_OnesLike, DefaultPopulateParameter);
lite::Registry UnsortedSegmentSumParameterRegistry(schema::PrimitiveType_UnsortedSegmentSum,
DefaultPopulateParameter);
}

} // namespace mindspore::kernel

+ 26
- 2
mindspore/lite/tools/anf_exporter/anf_exporter.cc View File

@@ -69,7 +69,8 @@ void AnfExporter::RemoveIfDepend(const CNodePtr &cnode) {
continue;
}
auto dependNode = utils::cast<CNodePtr>(inputNode);
if (IsPrimitiveCNode(dependNode, schema::PrimitiveType_Depend)) {
if (IsPrimitiveCNode(dependNode, schema::PrimitiveType_Depend) ||
IsPrimitiveCNode(dependNode, schema::PrimitiveType_ControlDepend)) {
hasDepend = true;
for (size_t j = 1; j < dependNode->inputs().size(); ++j) {
AnfNodePtr dependInputNode = dependNode->input(j);
@@ -209,6 +210,7 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph, bool kee
if ((primitive_c->Type() == schema::PrimitiveType_TupleGetItem) ||
#ifdef SUPPORT_TRAIN
(primitive_c->Type() == schema::PrimitiveType_Depend) ||
(primitive_c->Type() == schema::PrimitiveType_ControlDepend) ||
#endif
(primitive_c->Type() == schema::PrimitiveType_MakeTuple)) {
continue;
@@ -402,7 +404,9 @@ int AnfExporter::ConvertInputValueNode(std::shared_ptr<AnfNode> input_anode,
paramTensor->dims = {1};
paramTensor->nodeType = schema::NodeType::NodeType_ValueNode;
auto data = value->cast<mindspore::Int32ImmPtr>();
paramTensor->data.emplace_back(data->value());
int real_data = GetValue<int32_t>(data);
paramTensor->data.resize(sizeof(int32_t));
memcpy(paramTensor->data.data(), &real_data, sizeof(int32_t));
node_id_map_[valueNode->fullname_with_scope()] = meta_graphT->allTensors.size();
output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size());
meta_graphT->allTensors.emplace_back(std::move(paramTensor));
@@ -418,6 +422,14 @@ int AnfExporter::ConvertInputValueNode(std::shared_ptr<AnfNode> input_anode,
node_id_map_[valueNode->fullname_with_scope()] = meta_graphT->allTensors.size();
output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size());
meta_graphT->allTensors.emplace_back(std::move(paramTensor));
} else if (value->isa<mindspore::Int>()) {
paramTensor->dataType = kNumberTypeInt32;
paramTensor->dims = {1};
paramTensor->nodeType = schema::NodeType_ValueNode;
paramTensor->data.emplace_back(kNumberTypeInt32);
node_id_map_[valueNode->fullname_with_scope()] = meta_graphT->allTensors.size();
output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size());
meta_graphT->allTensors.emplace_back(std::move(paramTensor));
} else if (value->isa<mindspore::ValueSequeue>()) {
#ifndef SUPPORT_TRAIN
MS_LOG(DEBUG) << "Value type is ValueSequence.";
@@ -456,6 +468,18 @@ int AnfExporter::ConvertInputValueNode(std::shared_ptr<AnfNode> input_anode,
MS_LOG(ERROR) << "Value type is ValueSequence not supported - " << valueAbstract->type_name() << ".";
}
#endif
} else if (value->isa<mindspore::BoolImm>()) {
auto valueAbstract = valueNode->abstract();
auto abstractScalar = utils::cast<abstract::AbstractScalarPtr>(valueAbstract);
auto typePtr = abstractScalar->GetTypeTrack();
paramTensor->dataType = typePtr->type_id();
paramTensor->dims = {1};
paramTensor->nodeType = schema::NodeType_ValueNode;
auto data = value->cast<mindspore::BoolImmPtr>();
paramTensor->data.emplace_back(data->value());
node_id_map_[valueNode->fullname_with_scope()] = meta_graphT->allTensors.size();
output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size());
meta_graphT->allTensors.emplace_back(std::move(paramTensor));
} else if (value->isa<Number>()) {
MS_LOG(INFO) << "Value is a number.";
return RET_OK;


Loading…
Cancel
Save