From: @YeFeng_24 Reviewed-by: Signed-off-by:tags/v1.2.0-rc1
| @@ -0,0 +1,27 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_NNACL_RNADOM_STANDARD_NORMAL_PARAMETER_H_ | |||||
| #define MINDSPORE_LITE_NNACL_RNADOM_STANDARD_NORMAL_PARAMETER_H_ | |||||
| #include "nnacl/op_base.h" | |||||
| typedef struct RandomStandardNormalParam { | |||||
| OpParameter op_parameter_; | |||||
| int seed_; | |||||
| int seed2_; | |||||
| } RandomStandardNormalParam; | |||||
| #endif // MINDSPORE_LITE_NNACL_RNADOM_STANDARD_NORMAL_PARAMETER_H_ | |||||
| @@ -269,6 +269,7 @@ union PrimitiveType { | |||||
| NonZero, | NonZero, | ||||
| InvertPermutation, | InvertPermutation, | ||||
| Size, | Size, | ||||
| RandomStandardNormal, | |||||
| } | } | ||||
| enum QuantType: int { | enum QuantType: int { | ||||
| @@ -1249,4 +1249,9 @@ table InvertPermutation { | |||||
| } | } | ||||
| table Size { | table Size { | ||||
| } | |||||
| } | |||||
| table RandomStandardNormal { | |||||
| seed : int; | |||||
| seed2 : int; | |||||
| } | |||||
| @@ -0,0 +1,42 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "src/ops/random_standard_normal.h" | |||||
| #include "src/ops/primitive_c.h" | |||||
| #include "src/ops/populate/populate_register.h" | |||||
| #include "nnacl/random_standard_normal_parameter.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| OpParameter *PopulateRandomStandardNormalParameter(const mindspore::lite::PrimitiveC *primitive) { | |||||
| RandomStandardNormalParam *random_parameter = | |||||
| reinterpret_cast<RandomStandardNormalParam *>(malloc(sizeof(RandomStandardNormalParam))); | |||||
| if (random_parameter == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc RandomStandardNormal parameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(random_parameter, 0, sizeof(RandomStandardNormalParam)); | |||||
| random_parameter->op_parameter_.type_ = primitive->Type(); | |||||
| auto param = | |||||
| reinterpret_cast<mindspore::lite::RandomStandardNormal *>(const_cast<mindspore::lite::PrimitiveC *>(primitive)); | |||||
| random_parameter->seed_ = param->GetSeed(); | |||||
| random_parameter->seed2_ = param->GetSeed2(); | |||||
| return reinterpret_cast<OpParameter *>(random_parameter); | |||||
| } | |||||
| Registry RandomStandardNormalParameterRegistry(schema::PrimitiveType_RandomStandardNormal, | |||||
| PopulateRandomStandardNormalParameter); | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -165,6 +165,7 @@ | |||||
| #include "src/ops/gelu.h" | #include "src/ops/gelu.h" | ||||
| #include "src/ops/gru.h" | #include "src/ops/gru.h" | ||||
| #include "src/ops/size.h" | #include "src/ops/size.h" | ||||
| #include "src/ops/random_standard_normal.h" | |||||
| #include "src/ops/invert_permutation.h" | #include "src/ops/invert_permutation.h" | ||||
| #ifdef SUPPORT_TRAIN | #ifdef SUPPORT_TRAIN | ||||
| @@ -1012,6 +1013,8 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) { | |||||
| return new (std::nothrow) Size(primitive); | return new (std::nothrow) Size(primitive); | ||||
| case schema::PrimitiveType_InvertPermutation: | case schema::PrimitiveType_InvertPermutation: | ||||
| return new (std::nothrow) InvertPermutation(primitive); | return new (std::nothrow) InvertPermutation(primitive); | ||||
| case schema::PrimitiveType_RandomStandardNormal: | |||||
| return new (std::nothrow) RandomStandardNormal(primitive); | |||||
| #ifdef SUPPORT_TRAIN | #ifdef SUPPORT_TRAIN | ||||
| case schema::PrimitiveType_ActivationGrad: | case schema::PrimitiveType_ActivationGrad: | ||||
| return new (std::nothrow) ActivationGrad(primitive); | return new (std::nothrow) ActivationGrad(primitive); | ||||
| @@ -0,0 +1,101 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "src/ops/random_standard_normal.h" | |||||
| #ifndef PRIMITIVE_WRITEABLE | |||||
| #include "src/ops/ops_register.h" | |||||
| #endif | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| #ifdef PRIMITIVE_WRITEABLE | |||||
| int RandomStandardNormal::GetSeed() const { return this->primitive_->value.AsRandomStandardNormal()->seed; } | |||||
| int RandomStandardNormal::GetSeed2() const { return this->primitive_->value.AsRandomStandardNormal()->seed2; } | |||||
| int RandomStandardNormal::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | |||||
| if (this->primitive_ == nullptr) { | |||||
| this->primitive_ = new (std::nothrow) schema::PrimitiveT; | |||||
| if (this->primitive_ == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitiveT failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| this->primitive_->value.type = schema::PrimitiveType_RandomStandardNormal; | |||||
| } | |||||
| if (this->primitive_->value.type != schema::PrimitiveType_RandomStandardNormal) { | |||||
| MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (this->primitive_->value.value == nullptr) { | |||||
| auto attr = new (std::nothrow) schema::RandomStandardNormalT(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitiveT value failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| this->primitive_->value.value = attr; | |||||
| if (this->primitive_->value.value == nullptr) { | |||||
| MS_LOG(ERROR) << "primitive value is nullptr"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| #else | |||||
| int RandomStandardNormal::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | |||||
| MS_ASSERT(nullptr != primitive); | |||||
| MS_ASSERT(nullptr != fbb); | |||||
| auto attr = primitive->value_as_RandomStandardNormal(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "value_as_RandomStandardNormal return nullptr"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto val_offset = schema::CreateRandomStandardNormal(*fbb, attr->seed(), attr->seed2()); | |||||
| auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_RandomStandardNormal, val_offset.o); | |||||
| fbb->Finish(prim_offset); | |||||
| return RET_OK; | |||||
| } | |||||
| int RandomStandardNormal::GetSeed() const { return this->primitive_->value_as_RandomStandardNormal()->seed(); } | |||||
| int RandomStandardNormal::GetSeed2() const { return this->primitive_->value_as_RandomStandardNormal()->seed2(); } | |||||
| PrimitiveC *RandomStandardNormalCreator(const schema::Primitive *primitive) { | |||||
| return PrimitiveC::NewPrimitiveC<RandomStandardNormal>(primitive); | |||||
| } | |||||
| Registry RandomStandardNormalRegistry(schema::PrimitiveType_RandomStandardNormal, RandomStandardNormalCreator); | |||||
| #endif | |||||
| int RandomStandardNormal::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | |||||
| if (!infer_flag()) { | |||||
| return RET_INFER_INVALID; | |||||
| } | |||||
| auto input_data = static_cast<int32_t *>(inputs_[0]->data_c()); | |||||
| if (input_data == nullptr) { | |||||
| return RET_INFER_INVALID; | |||||
| } | |||||
| auto input_num = inputs_[0]->ElementsNum(); | |||||
| std::vector<int> output_shape = {}; | |||||
| for (int i = 0; i < input_num; i++) { | |||||
| output_shape.push_back(input_data[i]); | |||||
| } | |||||
| outputs_[0]->set_shape(output_shape); | |||||
| outputs_[0]->set_data_type(kNumberTypeFloat32); | |||||
| return RET_OK; | |||||
| } | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,46 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef LITE_MINDSPORE_LITE_C_OPS_RANDOM_STANDARD_NORMAL_H_ | |||||
| #define LITE_MINDSPORE_LITE_C_OPS_RANDOM_STANDARD_NORMAL_H_ | |||||
| #include <vector> | |||||
| #include <set> | |||||
| #include <cmath> | |||||
| #include <memory> | |||||
| #include "src/ops/primitive_c.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| class RandomStandardNormal : public PrimitiveC { | |||||
| public: | |||||
| RandomStandardNormal() = default; | |||||
| ~RandomStandardNormal() = default; | |||||
| #ifdef PRIMITIVE_WRITEABLE | |||||
| MS_DECLARE_PARENT(RandomStandardNormal, PrimitiveC); | |||||
| explicit RandomStandardNormal(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | |||||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | |||||
| #else | |||||
| int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | |||||
| #endif | |||||
| int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override; | |||||
| int GetSeed() const; | |||||
| int GetSeed2() const; | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // LITE_MINDSPORE_LITE_C_OPS_RANDOM_STANDARD_NORMAL_H_ | |||||
| @@ -0,0 +1,56 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "src/runtime/kernel/arm/base/random_standard_normal.h" | |||||
| #include <random> | |||||
| #include "src/kernel_registry.h" | |||||
| #include "include/errorcode.h" | |||||
| #include "src/tensorlist.h" | |||||
| using mindspore::lite::KernelRegistrar; | |||||
| using mindspore::lite::RET_ERROR; | |||||
| using mindspore::lite::RET_OK; | |||||
| using mindspore::schema::PrimitiveType_RandomStandardNormal; | |||||
| namespace mindspore::kernel { | |||||
| int RandomStandardNormalCPUKernel::Init() { return RET_OK; } | |||||
| int RandomStandardNormalCPUKernel::ReSize() { return RET_OK; } | |||||
| int RandomStandardNormalCPUKernel::Run() { | |||||
| size_t random_seed = 0; | |||||
| if (param_->seed2_ != 0) { | |||||
| random_seed = static_cast<size_t>(param_->seed2_); | |||||
| } else if (param_->seed_ != 0) { | |||||
| random_seed = static_cast<size_t>(param_->seed_); | |||||
| } else { | |||||
| random_seed = static_cast<size_t>(clock()); | |||||
| } | |||||
| std::default_random_engine engine{random_seed}; | |||||
| std::normal_distribution<double> nums(0, 1.0); | |||||
| auto all_data_nums = out_tensors_[0]->ElementsNum(); | |||||
| auto out_data = out_tensors_[0]->data_c(); | |||||
| MS_ASSERT(out_data != nullptr); | |||||
| auto output = reinterpret_cast<float *>(out_data); | |||||
| for (int i = 0; i < all_data_nums; ++i) { | |||||
| output[i] = nums(engine); | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_RandomStandardNormal, LiteKernelCreator<RandomStandardNormalCPUKernel>) | |||||
| } // namespace mindspore::kernel | |||||
| @@ -0,0 +1,46 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_RANDOM_STANDARD_NORMAL_BASE_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_RANDOM_STANDARD_NORMAL_BASE_H_ | |||||
| #include <vector> | |||||
| #include "src/lite_kernel.h" | |||||
| #include "nnacl/random_standard_normal_parameter.h" | |||||
| using mindspore::lite::InnerContext; | |||||
| namespace mindspore::kernel { | |||||
| class RandomStandardNormalCPUKernel : public LiteKernel { | |||||
| public: | |||||
| RandomStandardNormalCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||||
| const std::vector<lite::Tensor *> &outputs, const InnerContext *ctx, | |||||
| const mindspore::lite::PrimitiveC *primitive) | |||||
| : LiteKernel(parameter, inputs, outputs, ctx, primitive) { | |||||
| param_ = reinterpret_cast<RandomStandardNormalParam *>(parameter); | |||||
| } | |||||
| ~RandomStandardNormalCPUKernel() override = default; | |||||
| int Init() override; | |||||
| int ReSize() override; | |||||
| int Run() override; | |||||
| protected: | |||||
| RandomStandardNormalParam *param_ = nullptr; | |||||
| }; | |||||
| } // namespace mindspore::kernel | |||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_RANDOM_STANDARD_NORMAL_BASE_H_ | |||||
| @@ -0,0 +1,69 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "tools/converter/parser/tf/tf_random_standard_normal_parser.h" | |||||
| #include <string> | |||||
| #include <memory> | |||||
| #include <map> | |||||
| #include <vector> | |||||
| #include "tools/converter/parser/tf/tf_node_parser_registry.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| STATUS TFRandomStandardNormalParser::Parse(const tensorflow::NodeDef &tf_op, | |||||
| const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||||
| PrimitiveC **primitiveC, std::vector<std::string> *inputs, | |||||
| int *output_size) { | |||||
| MS_LOG(WARNING) << "TF RandomStandardNormalParser"; | |||||
| if (primitiveC == nullptr || output_size == nullptr) { | |||||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "New PrimitiveT failed"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| auto attr = std::make_unique<schema::RandomStandardNormalT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new attr failed"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| tensorflow::AttrValue attr_value; | |||||
| if (!TensorFlowUtils::FindAttrValue(tf_op, "seed", &attr_value)) { | |||||
| MS_LOG(ERROR) << "The seed attr should be specified"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| attr->seed = attr_value.i(); | |||||
| if (!TensorFlowUtils::FindAttrValue(tf_op, "seed2", &attr_value)) { | |||||
| MS_LOG(ERROR) << "The seed2 attr should be specified"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| attr->seed2 = attr_value.i(); | |||||
| primitive->value.type = schema::PrimitiveType_RandomStandardNormal; | |||||
| primitive->value.value = attr.release(); | |||||
| *primitiveC = PrimitiveC::Create(primitive.release()); | |||||
| if (*primitiveC == nullptr) { | |||||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| *output_size = 1; | |||||
| auto status = AddOpInput(tf_op, 0, inputs); | |||||
| return status; | |||||
| } | |||||
| TFNodeRegistrar g_tfRandomStandardNormalParser("RandomStandardNormal", new TFRandomStandardNormalParser()); | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,36 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_RANDOM_STANDARD_NORMAL_PARSER_H_ | |||||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_RANDOM_STANDARD_NORMAL_PARSER_H_ | |||||
| #include <string> | |||||
| #include <memory> | |||||
| #include <map> | |||||
| #include <vector> | |||||
| #include "tools/converter/parser/tf/tf_node_parser.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| class TFRandomStandardNormalParser : public TFNodeParser { | |||||
| public: | |||||
| TFRandomStandardNormalParser() = default; | |||||
| ~TFRandomStandardNormalParser() override = default; | |||||
| STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||||
| PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override; | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_RANDOM_STANDARD_NORMAL_PARSER_H_ | |||||
| @@ -236,6 +236,9 @@ const AnfNodePtr ConstFoldPass::Process(const FuncGraphPtr &func_graph, const An | |||||
| } | } | ||||
| lite_primitive->InferShape(input_tensors, output_tensors); | lite_primitive->InferShape(input_tensors, output_tensors); | ||||
| auto primitive = lite_primitive.get(); | auto primitive = lite_primitive.get(); | ||||
| if (primitive->Type() == schema::PrimitiveType_RandomStandardNormal) { | |||||
| return nullptr; | |||||
| } | |||||
| MS_ASSERT(primitive != nullptr); | MS_ASSERT(primitive != nullptr); | ||||
| MS_ASSERT(primitive->Type() != nullptr); | MS_ASSERT(primitive->Type() != nullptr); | ||||
| auto func_pointer = | auto func_pointer = | ||||