Browse Source

!11268 [MS_LITE] random op

From: @YeFeng_24
Reviewed-by: 
Signed-off-by:
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 4 years ago
parent
commit
a18faa3fc7
12 changed files with 436 additions and 1 deletions
  1. +27
    -0
      mindspore/lite/nnacl/random_standard_normal_parameter.h
  2. +1
    -0
      mindspore/lite/schema/model.fbs
  3. +6
    -1
      mindspore/lite/schema/ops.fbs
  4. +42
    -0
      mindspore/lite/src/ops/populate/random_standard_normal_populate.cc
  5. +3
    -0
      mindspore/lite/src/ops/primitive_c.cc
  6. +101
    -0
      mindspore/lite/src/ops/random_standard_normal.cc
  7. +46
    -0
      mindspore/lite/src/ops/random_standard_normal.h
  8. +56
    -0
      mindspore/lite/src/runtime/kernel/arm/base/random_standard_normal.cc
  9. +46
    -0
      mindspore/lite/src/runtime/kernel/arm/base/random_standard_normal.h
  10. +69
    -0
      mindspore/lite/tools/converter/parser/tf/tf_random_standard_normal_parser.cc
  11. +36
    -0
      mindspore/lite/tools/converter/parser/tf/tf_random_standard_normal_parser.h
  12. +3
    -0
      mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc

+ 27
- 0
mindspore/lite/nnacl/random_standard_normal_parameter.h View File

@@ -0,0 +1,27 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_NNACL_RNADOM_STANDARD_NORMAL_PARAMETER_H_
#define MINDSPORE_LITE_NNACL_RNADOM_STANDARD_NORMAL_PARAMETER_H_

#include "nnacl/op_base.h"

typedef struct RandomStandardNormalParam {
OpParameter op_parameter_;
int seed_;
int seed2_;
} RandomStandardNormalParam;

#endif // MINDSPORE_LITE_NNACL_RNADOM_STANDARD_NORMAL_PARAMETER_H_

+ 1
- 0
mindspore/lite/schema/model.fbs View File

@@ -269,6 +269,7 @@ union PrimitiveType {
NonZero,
InvertPermutation,
Size,
RandomStandardNormal,
}

enum QuantType: int {


+ 6
- 1
mindspore/lite/schema/ops.fbs View File

@@ -1249,4 +1249,9 @@ table InvertPermutation {
}

table Size {
}
}

table RandomStandardNormal {
seed : int;
seed2 : int;
}

+ 42
- 0
mindspore/lite/src/ops/populate/random_standard_normal_populate.cc View File

@@ -0,0 +1,42 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "src/ops/random_standard_normal.h"
#include "src/ops/primitive_c.h"
#include "src/ops/populate/populate_register.h"
#include "nnacl/random_standard_normal_parameter.h"

namespace mindspore {
namespace lite {
OpParameter *PopulateRandomStandardNormalParameter(const mindspore::lite::PrimitiveC *primitive) {
RandomStandardNormalParam *random_parameter =
reinterpret_cast<RandomStandardNormalParam *>(malloc(sizeof(RandomStandardNormalParam)));
if (random_parameter == nullptr) {
MS_LOG(ERROR) << "malloc RandomStandardNormal parameter failed.";
return nullptr;
}
memset(random_parameter, 0, sizeof(RandomStandardNormalParam));
random_parameter->op_parameter_.type_ = primitive->Type();
auto param =
reinterpret_cast<mindspore::lite::RandomStandardNormal *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
random_parameter->seed_ = param->GetSeed();
random_parameter->seed2_ = param->GetSeed2();
return reinterpret_cast<OpParameter *>(random_parameter);
}
Registry RandomStandardNormalParameterRegistry(schema::PrimitiveType_RandomStandardNormal,
PopulateRandomStandardNormalParameter);
} // namespace lite
} // namespace mindspore

+ 3
- 0
mindspore/lite/src/ops/primitive_c.cc View File

@@ -165,6 +165,7 @@
#include "src/ops/gelu.h"
#include "src/ops/gru.h"
#include "src/ops/size.h"
#include "src/ops/random_standard_normal.h"
#include "src/ops/invert_permutation.h"

#ifdef SUPPORT_TRAIN
@@ -1012,6 +1013,8 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) {
return new (std::nothrow) Size(primitive);
case schema::PrimitiveType_InvertPermutation:
return new (std::nothrow) InvertPermutation(primitive);
case schema::PrimitiveType_RandomStandardNormal:
return new (std::nothrow) RandomStandardNormal(primitive);
#ifdef SUPPORT_TRAIN
case schema::PrimitiveType_ActivationGrad:
return new (std::nothrow) ActivationGrad(primitive);


+ 101
- 0
mindspore/lite/src/ops/random_standard_normal.cc View File

@@ -0,0 +1,101 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "src/ops/random_standard_normal.h"

#ifndef PRIMITIVE_WRITEABLE
#include "src/ops/ops_register.h"
#endif

namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
int RandomStandardNormal::GetSeed() const { return this->primitive_->value.AsRandomStandardNormal()->seed; }

int RandomStandardNormal::GetSeed2() const { return this->primitive_->value.AsRandomStandardNormal()->seed2; }

int RandomStandardNormal::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
if (this->primitive_ == nullptr) {
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
if (this->primitive_ == nullptr) {
MS_LOG(ERROR) << "new primitiveT failed";
return RET_ERROR;
}
this->primitive_->value.type = schema::PrimitiveType_RandomStandardNormal;
}
if (this->primitive_->value.type != schema::PrimitiveType_RandomStandardNormal) {
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
return RET_ERROR;
}
if (this->primitive_->value.value == nullptr) {
auto attr = new (std::nothrow) schema::RandomStandardNormalT();
if (attr == nullptr) {
MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR;
}

this->primitive_->value.value = attr;
if (this->primitive_->value.value == nullptr) {
MS_LOG(ERROR) << "primitive value is nullptr";
return RET_ERROR;
}
}
return RET_OK;
}
#else
int RandomStandardNormal::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto attr = primitive->value_as_RandomStandardNormal();
if (attr == nullptr) {
MS_LOG(ERROR) << "value_as_RandomStandardNormal return nullptr";
return RET_ERROR;
}
auto val_offset = schema::CreateRandomStandardNormal(*fbb, attr->seed(), attr->seed2());
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_RandomStandardNormal, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}

int RandomStandardNormal::GetSeed() const { return this->primitive_->value_as_RandomStandardNormal()->seed(); }

int RandomStandardNormal::GetSeed2() const { return this->primitive_->value_as_RandomStandardNormal()->seed2(); }

PrimitiveC *RandomStandardNormalCreator(const schema::Primitive *primitive) {
return PrimitiveC::NewPrimitiveC<RandomStandardNormal>(primitive);
}
Registry RandomStandardNormalRegistry(schema::PrimitiveType_RandomStandardNormal, RandomStandardNormalCreator);
#endif

int RandomStandardNormal::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
if (!infer_flag()) {
return RET_INFER_INVALID;
}
auto input_data = static_cast<int32_t *>(inputs_[0]->data_c());
if (input_data == nullptr) {
return RET_INFER_INVALID;
}
auto input_num = inputs_[0]->ElementsNum();
std::vector<int> output_shape = {};
for (int i = 0; i < input_num; i++) {
output_shape.push_back(input_data[i]);
}
outputs_[0]->set_shape(output_shape);
outputs_[0]->set_data_type(kNumberTypeFloat32);
return RET_OK;
}
} // namespace lite
} // namespace mindspore

+ 46
- 0
mindspore/lite/src/ops/random_standard_normal.h View File

@@ -0,0 +1,46 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef LITE_MINDSPORE_LITE_C_OPS_RANDOM_STANDARD_NORMAL_H_
#define LITE_MINDSPORE_LITE_C_OPS_RANDOM_STANDARD_NORMAL_H_

#include <vector>
#include <set>
#include <cmath>
#include <memory>
#include "src/ops/primitive_c.h"

namespace mindspore {
namespace lite {
class RandomStandardNormal : public PrimitiveC {
public:
RandomStandardNormal() = default;
~RandomStandardNormal() = default;
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(RandomStandardNormal, PrimitiveC);
explicit RandomStandardNormal(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
#else
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
int GetSeed() const;
int GetSeed2() const;
};
} // namespace lite
} // namespace mindspore

#endif // LITE_MINDSPORE_LITE_C_OPS_RANDOM_STANDARD_NORMAL_H_

+ 56
- 0
mindspore/lite/src/runtime/kernel/arm/base/random_standard_normal.cc View File

@@ -0,0 +1,56 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "src/runtime/kernel/arm/base/random_standard_normal.h"
#include <random>
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "src/tensorlist.h"

using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_RandomStandardNormal;

namespace mindspore::kernel {

int RandomStandardNormalCPUKernel::Init() { return RET_OK; }

int RandomStandardNormalCPUKernel::ReSize() { return RET_OK; }

int RandomStandardNormalCPUKernel::Run() {
size_t random_seed = 0;
if (param_->seed2_ != 0) {
random_seed = static_cast<size_t>(param_->seed2_);
} else if (param_->seed_ != 0) {
random_seed = static_cast<size_t>(param_->seed_);
} else {
random_seed = static_cast<size_t>(clock());
}
std::default_random_engine engine{random_seed};
std::normal_distribution<double> nums(0, 1.0);
auto all_data_nums = out_tensors_[0]->ElementsNum();
auto out_data = out_tensors_[0]->data_c();
MS_ASSERT(out_data != nullptr);
auto output = reinterpret_cast<float *>(out_data);
for (int i = 0; i < all_data_nums; ++i) {
output[i] = nums(engine);
}
return RET_OK;
}

REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_RandomStandardNormal, LiteKernelCreator<RandomStandardNormalCPUKernel>)
} // namespace mindspore::kernel

+ 46
- 0
mindspore/lite/src/runtime/kernel/arm/base/random_standard_normal.h View File

@@ -0,0 +1,46 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_RANDOM_STANDARD_NORMAL_BASE_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_RANDOM_STANDARD_NORMAL_BASE_H_

#include <vector>
#include "src/lite_kernel.h"
#include "nnacl/random_standard_normal_parameter.h"

using mindspore::lite::InnerContext;

namespace mindspore::kernel {
class RandomStandardNormalCPUKernel : public LiteKernel {
public:
RandomStandardNormalCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const InnerContext *ctx,
const mindspore::lite::PrimitiveC *primitive)
: LiteKernel(parameter, inputs, outputs, ctx, primitive) {
param_ = reinterpret_cast<RandomStandardNormalParam *>(parameter);
}
~RandomStandardNormalCPUKernel() override = default;

int Init() override;
int ReSize() override;
int Run() override;

protected:
RandomStandardNormalParam *param_ = nullptr;
};
} // namespace mindspore::kernel

#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_RANDOM_STANDARD_NORMAL_BASE_H_

+ 69
- 0
mindspore/lite/tools/converter/parser/tf/tf_random_standard_normal_parser.cc View File

@@ -0,0 +1,69 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tf/tf_random_standard_normal_parser.h"
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser_registry.h"

namespace mindspore {
namespace lite {
STATUS TFRandomStandardNormalParser::Parse(const tensorflow::NodeDef &tf_op,
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
PrimitiveC **primitiveC, std::vector<std::string> *inputs,
int *output_size) {
MS_LOG(WARNING) << "TF RandomStandardNormalParser";
if (primitiveC == nullptr || output_size == nullptr) {
MS_LOG(ERROR) << "primitiveC is nullptr";
return RET_NULL_PTR;
}

auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "New PrimitiveT failed";
return RET_NULL_PTR;
}
auto attr = std::make_unique<schema::RandomStandardNormalT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new attr failed";
return RET_NULL_PTR;
}
tensorflow::AttrValue attr_value;
if (!TensorFlowUtils::FindAttrValue(tf_op, "seed", &attr_value)) {
MS_LOG(ERROR) << "The seed attr should be specified";
return RET_ERROR;
}
attr->seed = attr_value.i();
if (!TensorFlowUtils::FindAttrValue(tf_op, "seed2", &attr_value)) {
MS_LOG(ERROR) << "The seed2 attr should be specified";
return RET_ERROR;
}
attr->seed2 = attr_value.i();
primitive->value.type = schema::PrimitiveType_RandomStandardNormal;
primitive->value.value = attr.release();
*primitiveC = PrimitiveC::Create(primitive.release());
if (*primitiveC == nullptr) {
MS_LOG(ERROR) << "primitiveC is nullptr";
return RET_ERROR;
}
*output_size = 1;
auto status = AddOpInput(tf_op, 0, inputs);
return status;
}
TFNodeRegistrar g_tfRandomStandardNormalParser("RandomStandardNormal", new TFRandomStandardNormalParser());
} // namespace lite
} // namespace mindspore

+ 36
- 0
mindspore/lite/tools/converter/parser/tf/tf_random_standard_normal_parser.h View File

@@ -0,0 +1,36 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_RANDOM_STANDARD_NORMAL_PARSER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_RANDOM_STANDARD_NORMAL_PARSER_H_
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser.h"

namespace mindspore {
namespace lite {
class TFRandomStandardNormalParser : public TFNodeParser {
public:
TFRandomStandardNormalParser() = default;
~TFRandomStandardNormalParser() override = default;

STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_RANDOM_STANDARD_NORMAL_PARSER_H_

+ 3
- 0
mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc View File

@@ -236,6 +236,9 @@ const AnfNodePtr ConstFoldPass::Process(const FuncGraphPtr &func_graph, const An
}
lite_primitive->InferShape(input_tensors, output_tensors);
auto primitive = lite_primitive.get();
if (primitive->Type() == schema::PrimitiveType_RandomStandardNormal) {
return nullptr;
}
MS_ASSERT(primitive != nullptr);
MS_ASSERT(primitive->Type() != nullptr);
auto func_pointer =


Loading…
Cancel
Save