diff --git a/mindspore/lite/nnacl/random_standard_normal_parameter.h b/mindspore/lite/nnacl/random_standard_normal_parameter.h new file mode 100644 index 0000000000..7e2926bdee --- /dev/null +++ b/mindspore/lite/nnacl/random_standard_normal_parameter.h @@ -0,0 +1,27 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_NNACL_RNADOM_STANDARD_NORMAL_PARAMETER_H_ +#define MINDSPORE_LITE_NNACL_RNADOM_STANDARD_NORMAL_PARAMETER_H_ + +#include "nnacl/op_base.h" + +typedef struct RandomStandardNormalParam { + OpParameter op_parameter_; + int seed_; + int seed2_; +} RandomStandardNormalParam; + +#endif // MINDSPORE_LITE_NNACL_RNADOM_STANDARD_NORMAL_PARAMETER_H_ diff --git a/mindspore/lite/schema/model.fbs b/mindspore/lite/schema/model.fbs index 7d45162d36..719a10ffed 100644 --- a/mindspore/lite/schema/model.fbs +++ b/mindspore/lite/schema/model.fbs @@ -269,6 +269,7 @@ union PrimitiveType { NonZero, InvertPermutation, Size, + RandomStandardNormal, } enum QuantType: int { diff --git a/mindspore/lite/schema/ops.fbs b/mindspore/lite/schema/ops.fbs index 94612d6170..732d3bdae2 100644 --- a/mindspore/lite/schema/ops.fbs +++ b/mindspore/lite/schema/ops.fbs @@ -1249,4 +1249,9 @@ table InvertPermutation { } table Size { -} \ No newline at end of file +} + +table RandomStandardNormal { + seed : int; + seed2 : int; +} diff --git a/mindspore/lite/src/ops/populate/random_standard_normal_populate.cc b/mindspore/lite/src/ops/populate/random_standard_normal_populate.cc new file mode 100644 index 0000000000..89fddd46d5 --- /dev/null +++ b/mindspore/lite/src/ops/populate/random_standard_normal_populate.cc @@ -0,0 +1,42 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/random_standard_normal.h" +#include "src/ops/primitive_c.h" +#include "src/ops/populate/populate_register.h" +#include "nnacl/random_standard_normal_parameter.h" + +namespace mindspore { +namespace lite { +OpParameter *PopulateRandomStandardNormalParameter(const mindspore::lite::PrimitiveC *primitive) { + RandomStandardNormalParam *random_parameter = + reinterpret_cast(malloc(sizeof(RandomStandardNormalParam))); + if (random_parameter == nullptr) { + MS_LOG(ERROR) << "malloc RandomStandardNormal parameter failed."; + return nullptr; + } + memset(random_parameter, 0, sizeof(RandomStandardNormalParam)); + random_parameter->op_parameter_.type_ = primitive->Type(); + auto param = + reinterpret_cast(const_cast(primitive)); + random_parameter->seed_ = param->GetSeed(); + random_parameter->seed2_ = param->GetSeed2(); + return reinterpret_cast(random_parameter); +} +Registry RandomStandardNormalParameterRegistry(schema::PrimitiveType_RandomStandardNormal, + PopulateRandomStandardNormalParameter); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/ops/primitive_c.cc b/mindspore/lite/src/ops/primitive_c.cc index f04095a073..769b2eb7cc 100644 --- a/mindspore/lite/src/ops/primitive_c.cc +++ b/mindspore/lite/src/ops/primitive_c.cc @@ -165,6 +165,7 @@ #include "src/ops/gelu.h" #include "src/ops/gru.h" #include "src/ops/size.h" +#include "src/ops/random_standard_normal.h" #include "src/ops/invert_permutation.h" #ifdef SUPPORT_TRAIN @@ -1012,6 +1013,8 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) { return new (std::nothrow) Size(primitive); case schema::PrimitiveType_InvertPermutation: return new (std::nothrow) InvertPermutation(primitive); + case schema::PrimitiveType_RandomStandardNormal: + return new (std::nothrow) RandomStandardNormal(primitive); #ifdef SUPPORT_TRAIN case schema::PrimitiveType_ActivationGrad: return new (std::nothrow) ActivationGrad(primitive); diff --git a/mindspore/lite/src/ops/random_standard_normal.cc b/mindspore/lite/src/ops/random_standard_normal.cc new file mode 100644 index 0000000000..d17b1984ed --- /dev/null +++ b/mindspore/lite/src/ops/random_standard_normal.cc @@ -0,0 +1,101 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/random_standard_normal.h" + +#ifndef PRIMITIVE_WRITEABLE +#include "src/ops/ops_register.h" +#endif + +namespace mindspore { +namespace lite { +#ifdef PRIMITIVE_WRITEABLE +int RandomStandardNormal::GetSeed() const { return this->primitive_->value.AsRandomStandardNormal()->seed; } + +int RandomStandardNormal::GetSeed2() const { return this->primitive_->value.AsRandomStandardNormal()->seed2; } + +int RandomStandardNormal::UnPackAttr(const Primitive &prim, const std::vector &inputs) { + if (this->primitive_ == nullptr) { + this->primitive_ = new (std::nothrow) schema::PrimitiveT; + if (this->primitive_ == nullptr) { + MS_LOG(ERROR) << "new primitiveT failed"; + return RET_ERROR; + } + this->primitive_->value.type = schema::PrimitiveType_RandomStandardNormal; + } + if (this->primitive_->value.type != schema::PrimitiveType_RandomStandardNormal) { + MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; + return RET_ERROR; + } + if (this->primitive_->value.value == nullptr) { + auto attr = new (std::nothrow) schema::RandomStandardNormalT(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new primitiveT value failed"; + return RET_ERROR; + } + + this->primitive_->value.value = attr; + if (this->primitive_->value.value == nullptr) { + MS_LOG(ERROR) << "primitive value is nullptr"; + return RET_ERROR; + } + } + return RET_OK; +} +#else +int RandomStandardNormal::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_RandomStandardNormal(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_RandomStandardNormal return nullptr"; + return RET_ERROR; + } + auto val_offset = schema::CreateRandomStandardNormal(*fbb, attr->seed(), attr->seed2()); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_RandomStandardNormal, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} + +int RandomStandardNormal::GetSeed() const { return this->primitive_->value_as_RandomStandardNormal()->seed(); } + +int RandomStandardNormal::GetSeed2() const { return this->primitive_->value_as_RandomStandardNormal()->seed2(); } + +PrimitiveC *RandomStandardNormalCreator(const schema::Primitive *primitive) { + return PrimitiveC::NewPrimitiveC(primitive); +} +Registry RandomStandardNormalRegistry(schema::PrimitiveType_RandomStandardNormal, RandomStandardNormalCreator); +#endif + +int RandomStandardNormal::InferShape(std::vector inputs_, std::vector outputs_) { + if (!infer_flag()) { + return RET_INFER_INVALID; + } + auto input_data = static_cast(inputs_[0]->data_c()); + if (input_data == nullptr) { + return RET_INFER_INVALID; + } + auto input_num = inputs_[0]->ElementsNum(); + std::vector output_shape = {}; + for (int i = 0; i < input_num; i++) { + output_shape.push_back(input_data[i]); + } + outputs_[0]->set_shape(output_shape); + outputs_[0]->set_data_type(kNumberTypeFloat32); + return RET_OK; +} +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/ops/random_standard_normal.h b/mindspore/lite/src/ops/random_standard_normal.h new file mode 100644 index 0000000000..5cd60748aa --- /dev/null +++ b/mindspore/lite/src/ops/random_standard_normal.h @@ -0,0 +1,46 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LITE_MINDSPORE_LITE_C_OPS_RANDOM_STANDARD_NORMAL_H_ +#define LITE_MINDSPORE_LITE_C_OPS_RANDOM_STANDARD_NORMAL_H_ + +#include +#include +#include +#include +#include "src/ops/primitive_c.h" + +namespace mindspore { +namespace lite { +class RandomStandardNormal : public PrimitiveC { + public: + RandomStandardNormal() = default; + ~RandomStandardNormal() = default; +#ifdef PRIMITIVE_WRITEABLE + MS_DECLARE_PARENT(RandomStandardNormal, PrimitiveC); + explicit RandomStandardNormal(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} + int UnPackAttr(const Primitive &prim, const std::vector &inputs) override; +#else + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; +#endif + int InferShape(std::vector inputs_, std::vector outputs_) override; + int GetSeed() const; + int GetSeed2() const; +}; +} // namespace lite +} // namespace mindspore + +#endif // LITE_MINDSPORE_LITE_C_OPS_RANDOM_STANDARD_NORMAL_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/base/random_standard_normal.cc b/mindspore/lite/src/runtime/kernel/arm/base/random_standard_normal.cc new file mode 100644 index 0000000000..3e76c5bb9e --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/base/random_standard_normal.cc @@ -0,0 +1,56 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/base/random_standard_normal.h" +#include +#include "src/kernel_registry.h" +#include "include/errorcode.h" +#include "src/tensorlist.h" + +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_RandomStandardNormal; + +namespace mindspore::kernel { + +int RandomStandardNormalCPUKernel::Init() { return RET_OK; } + +int RandomStandardNormalCPUKernel::ReSize() { return RET_OK; } + +int RandomStandardNormalCPUKernel::Run() { + size_t random_seed = 0; + if (param_->seed2_ != 0) { + random_seed = static_cast(param_->seed2_); + } else if (param_->seed_ != 0) { + random_seed = static_cast(param_->seed_); + } else { + random_seed = static_cast(clock()); + } + std::default_random_engine engine{random_seed}; + std::normal_distribution nums(0, 1.0); + auto all_data_nums = out_tensors_[0]->ElementsNum(); + auto out_data = out_tensors_[0]->data_c(); + MS_ASSERT(out_data != nullptr); + auto output = reinterpret_cast(out_data); + for (int i = 0; i < all_data_nums; ++i) { + output[i] = nums(engine); + } + return RET_OK; +} + +REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_RandomStandardNormal, LiteKernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/base/random_standard_normal.h b/mindspore/lite/src/runtime/kernel/arm/base/random_standard_normal.h new file mode 100644 index 0000000000..f7a54caa6c --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/base/random_standard_normal.h @@ -0,0 +1,46 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_RANDOM_STANDARD_NORMAL_BASE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_RANDOM_STANDARD_NORMAL_BASE_H_ + +#include +#include "src/lite_kernel.h" +#include "nnacl/random_standard_normal_parameter.h" + +using mindspore::lite::InnerContext; + +namespace mindspore::kernel { +class RandomStandardNormalCPUKernel : public LiteKernel { + public: + RandomStandardNormalCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const InnerContext *ctx, + const mindspore::lite::PrimitiveC *primitive) + : LiteKernel(parameter, inputs, outputs, ctx, primitive) { + param_ = reinterpret_cast(parameter); + } + ~RandomStandardNormalCPUKernel() override = default; + + int Init() override; + int ReSize() override; + int Run() override; + + protected: + RandomStandardNormalParam *param_ = nullptr; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_RANDOM_STANDARD_NORMAL_BASE_H_ diff --git a/mindspore/lite/tools/converter/parser/tf/tf_random_standard_normal_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_random_standard_normal_parser.cc new file mode 100644 index 0000000000..0192dd95b9 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_random_standard_normal_parser.cc @@ -0,0 +1,69 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tools/converter/parser/tf/tf_random_standard_normal_parser.h" +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser_registry.h" + +namespace mindspore { +namespace lite { +STATUS TFRandomStandardNormalParser::Parse(const tensorflow::NodeDef &tf_op, + const std::map &tf_node_map, + PrimitiveC **primitiveC, std::vector *inputs, + int *output_size) { + MS_LOG(WARNING) << "TF RandomStandardNormalParser"; + if (primitiveC == nullptr || output_size == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_NULL_PTR; + } + + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "New PrimitiveT failed"; + return RET_NULL_PTR; + } + auto attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new attr failed"; + return RET_NULL_PTR; + } + tensorflow::AttrValue attr_value; + if (!TensorFlowUtils::FindAttrValue(tf_op, "seed", &attr_value)) { + MS_LOG(ERROR) << "The seed attr should be specified"; + return RET_ERROR; + } + attr->seed = attr_value.i(); + if (!TensorFlowUtils::FindAttrValue(tf_op, "seed2", &attr_value)) { + MS_LOG(ERROR) << "The seed2 attr should be specified"; + return RET_ERROR; + } + attr->seed2 = attr_value.i(); + primitive->value.type = schema::PrimitiveType_RandomStandardNormal; + primitive->value.value = attr.release(); + *primitiveC = PrimitiveC::Create(primitive.release()); + if (*primitiveC == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_ERROR; + } + *output_size = 1; + auto status = AddOpInput(tf_op, 0, inputs); + return status; +} +TFNodeRegistrar g_tfRandomStandardNormalParser("RandomStandardNormal", new TFRandomStandardNormalParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_random_standard_normal_parser.h b/mindspore/lite/tools/converter/parser/tf/tf_random_standard_normal_parser.h new file mode 100644 index 0000000000..0e0990e392 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_random_standard_normal_parser.h @@ -0,0 +1,36 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_RANDOM_STANDARD_NORMAL_PARSER_H_ +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_RANDOM_STANDARD_NORMAL_PARSER_H_ +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser.h" + +namespace mindspore { +namespace lite { +class TFRandomStandardNormalParser : public TFNodeParser { + public: + TFRandomStandardNormalParser() = default; + ~TFRandomStandardNormalParser() override = default; + + STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map &tf_node_map, + PrimitiveC **primitiveC, std::vector *inputs, int *output_size) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_RANDOM_STANDARD_NORMAL_PARSER_H_ diff --git a/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc b/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc index cf5c843e04..e166696b92 100644 --- a/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc +++ b/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc @@ -236,6 +236,9 @@ const AnfNodePtr ConstFoldPass::Process(const FuncGraphPtr &func_graph, const An } lite_primitive->InferShape(input_tensors, output_tensors); auto primitive = lite_primitive.get(); + if (primitive->Type() == schema::PrimitiveType_RandomStandardNormal) { + return nullptr; + } MS_ASSERT(primitive != nullptr); MS_ASSERT(primitive->Type() != nullptr); auto func_pointer =