| @@ -14,8 +14,8 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef MINDSPORE_CORE_OPERATOR_OPS_H_ | |||||
| #define MINDSPORE_CORE_OPERATOR_OPS_H_ | |||||
| #ifndef MINDSPORE_CORE_BASE_CORE_OPS_H_ | |||||
| #define MINDSPORE_CORE_BASE_CORE_OPS_H_ | |||||
| #include <iostream> | #include <iostream> | ||||
| #include <string> | #include <string> | ||||
| @@ -182,6 +182,7 @@ inline const PrimitivePtr kPrimReverseV2 = std::make_shared<Primitive>("ReverseV | |||||
| inline const PrimitivePtr kPrimReverseSequence = std::make_shared<Primitive>("ReverseSequence"); | inline const PrimitivePtr kPrimReverseSequence = std::make_shared<Primitive>("ReverseSequence"); | ||||
| inline const PrimitivePtr kPrimRank = std::make_shared<Primitive>("Rank"); | inline const PrimitivePtr kPrimRank = std::make_shared<Primitive>("Rank"); | ||||
| inline const PrimitivePtr kPrimResizeBilinear = std::make_shared<Primitive>("ResizeBilinear"); | inline const PrimitivePtr kPrimResizeBilinear = std::make_shared<Primitive>("ResizeBilinear"); | ||||
| inline const PrimitivePtr kPrimResizeGrad = std::make_shared<Primitive>("ResizeGrad"); | |||||
| // NN | // NN | ||||
| inline const PrimitivePtr kPrimAdam = std::make_shared<Primitive>("Adam"); | inline const PrimitivePtr kPrimAdam = std::make_shared<Primitive>("Adam"); | ||||
| @@ -245,7 +246,6 @@ inline const PrimitivePtr kPrimDepthwiseConv2dNativeBackpropInput = | |||||
| std::make_shared<Primitive>("DepthwiseConv2dNativeBackpropInput"); | std::make_shared<Primitive>("DepthwiseConv2dNativeBackpropInput"); | ||||
| inline const PrimitivePtr kPrimDetectionPostProcess = std::make_shared<Primitive>("DetectionPostProcess"); | inline const PrimitivePtr kPrimDetectionPostProcess = std::make_shared<Primitive>("DetectionPostProcess"); | ||||
| inline const PrimitivePtr kPrimBiasAdd = std::make_shared<Primitive>("BiasAdd"); | inline const PrimitivePtr kPrimBiasAdd = std::make_shared<Primitive>("BiasAdd"); | ||||
| inline const PrimitivePtr kPrimBiasGrad = std::make_shared<Primitive>("BiasGrad"); | |||||
| inline const PrimitivePtr kPrimBiasAddGrad = std::make_shared<Primitive>("BiasAddGrad"); | inline const PrimitivePtr kPrimBiasAddGrad = std::make_shared<Primitive>("BiasAddGrad"); | ||||
| inline const PrimitivePtr kPrimBiasSubGrad = std::make_shared<Primitive>("BiasSubGrad"); | inline const PrimitivePtr kPrimBiasSubGrad = std::make_shared<Primitive>("BiasSubGrad"); | ||||
| inline const PrimitivePtr kPrimBinaryCrossEntropy = std::make_shared<Primitive>("BinaryCrossEntropy"); | inline const PrimitivePtr kPrimBinaryCrossEntropy = std::make_shared<Primitive>("BinaryCrossEntropy"); | ||||
| @@ -390,6 +390,7 @@ inline const PrimitivePtr kPrimRound = std::make_shared<Primitive>("Round"); | |||||
| inline const PrimitivePtr kPrimExp = std::make_shared<Primitive>("Exp"); | inline const PrimitivePtr kPrimExp = std::make_shared<Primitive>("Exp"); | ||||
| inline const PrimitivePtr kPrimLog = std::make_shared<Primitive>("Log"); | inline const PrimitivePtr kPrimLog = std::make_shared<Primitive>("Log"); | ||||
| inline const PrimitivePtr kPrimRsqrt = std::make_shared<Primitive>("Rsqrt"); | inline const PrimitivePtr kPrimRsqrt = std::make_shared<Primitive>("Rsqrt"); | ||||
| inline const PrimitivePtr kPrimRsqrtGrad = std::make_shared<Primitive>("RsqrtGrad"); | |||||
| inline const PrimitivePtr kPrimSplitV = std::make_shared<Primitive>("SplitV"); | inline const PrimitivePtr kPrimSplitV = std::make_shared<Primitive>("SplitV"); | ||||
| inline const PrimitivePtr kPrimLinSpace = std::make_shared<Primitive>("LinSpace"); | inline const PrimitivePtr kPrimLinSpace = std::make_shared<Primitive>("LinSpace"); | ||||
| inline const PrimitivePtr kPrimNonMaxSuppression = std::make_shared<Primitive>("NonMaxSuppression"); | inline const PrimitivePtr kPrimNonMaxSuppression = std::make_shared<Primitive>("NonMaxSuppression"); | ||||
| @@ -551,4 +552,4 @@ using DoSignaturePrimitivePtr = std::shared_ptr<DoSignaturePrimitive>; | |||||
| } // namespace prim | } // namespace prim | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CORE_OPERATOR_OPS_H_ | |||||
| #endif // MINDSPORE_CORE_BASE_CORE_OPS_H_ | |||||
| @@ -0,0 +1,43 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "ops/grad/layer_norm_grad.h" | |||||
| #include "ops/op_utils.h" | |||||
| #include "utils/check_convert_utils.h" | |||||
| namespace mindspore { | |||||
| namespace ops { | |||||
| // Records both axis attributes on this primitive by forwarding to the two setters. | |||||
| void LayerNormGrad::Init(const int64_t begin_norm_axis, const int64_t begin_params_axis) { | |||||
|   set_begin_norm_axis(begin_norm_axis); | |||||
|   set_begin_params_axis(begin_params_axis); | |||||
| } | |||||
| // Stores begin_norm_axis under the kBeginNormAxis attribute key. | |||||
| void LayerNormGrad::set_begin_norm_axis(const int64_t begin_norm_axis) { | |||||
|   AddAttr(kBeginNormAxis, MakeValue(begin_norm_axis)); | |||||
| } | |||||
| // Stores begin_params_axis under the kBeginParamsAxis attribute key. | |||||
| void LayerNormGrad::set_begin_params_axis(const int64_t begin_params_axis) { | |||||
|   AddAttr(kBeginParamsAxis, MakeValue(begin_params_axis)); | |||||
| } | |||||
| // Reads back the kBeginNormAxis attribute as an int64_t. | |||||
| int64_t LayerNormGrad::get_begin_norm_axis() const { | |||||
|   return GetValue<int64_t>(GetAttr(kBeginNormAxis)); | |||||
| } | |||||
| // Reads back the kBeginParamsAxis attribute as an int64_t. | |||||
| int64_t LayerNormGrad::get_begin_params_axis() const { | |||||
|   return GetValue<int64_t>(GetAttr(kBeginParamsAxis)); | |||||
| } | |||||
| REGISTER_PRIMITIVE_C(kNameLayerNormGrad, LayerNormGrad); | |||||
| } // namespace ops | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,43 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CORE_OPS_GRAD_LAYER_NORM_GRAD_H_ | |||||
| #define MINDSPORE_CORE_OPS_GRAD_LAYER_NORM_GRAD_H_ | |||||
| #include <string> | |||||
| #include "ops/primitive_c.h" | |||||
| #include "abstract/abstract_value.h" | |||||
| #include "utils/check_convert_utils.h" | |||||
| namespace mindspore { | |||||
| namespace ops { | |||||
| constexpr auto kNameLayerNormGrad = "LayerNormGrad"; | |||||
| /// Primitive wrapper for the "LayerNormGrad" operator; exposes its two axis attributes. | |||||
| class LayerNormGrad : public PrimitiveC { | |||||
| public: | |||||
|   /// Creates the primitive under the default name kNameLayerNormGrad. | |||||
|   LayerNormGrad() : PrimitiveC(kNameLayerNormGrad) {} | |||||
|   /// Creates the primitive under a caller-supplied name. | |||||
|   /// The name is taken by const reference to avoid an extra copy of the caller's string. | |||||
|   explicit LayerNormGrad(const std::string &k_name) : PrimitiveC(k_name) {} | |||||
|   ~LayerNormGrad() = default; | |||||
|   MS_DECLARE_PARENT(LayerNormGrad, PrimitiveC); | |||||
|   /// Sets both axis attributes; each defaults to 1 when omitted. | |||||
|   void Init(const int64_t begin_norm_axis = 1, const int64_t begin_params_axis = 1); | |||||
|   void set_begin_norm_axis(const int64_t begin_norm_axis); | |||||
|   void set_begin_params_axis(const int64_t begin_params_axis); | |||||
|   int64_t get_begin_norm_axis() const; | |||||
|   int64_t get_begin_params_axis() const; | |||||
| }; | |||||
| } // namespace ops | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CORE_OPS_GRAD_LAYER_NORM_GRAD_H_ | |||||
| @@ -0,0 +1,52 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "ops/grad/resize_grad.h" | |||||
| #include <map> | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include "ops/op_utils.h" | |||||
| #include "utils/check_convert_utils.h" | |||||
| #include "abstract/primitive_infer_map.h" | |||||
| namespace mindspore { | |||||
| namespace ops { | |||||
| // Records the resize method and the align_corners flag by forwarding to the two setters. | |||||
| void ResizeGrad::Init(const ResizeMethod method, const bool align_corners) { | |||||
|   set_method(method); | |||||
|   set_align_corners(align_corners); | |||||
| } | |||||
| // Stores the resize method under kMethod as an int64_t attribute value. | |||||
| void ResizeGrad::set_method(const ResizeMethod method) { | |||||
|   // Named cast instead of a C-style cast: the conversion is explicit and greppable. | |||||
|   const auto method_value = static_cast<int64_t>(method); | |||||
|   this->AddAttr(kMethod, MakeValue(method_value)); | |||||
| } | |||||
| // Stores the align_corners flag under the kAlignCorners attribute key. | |||||
| void ResizeGrad::set_align_corners(const bool align_corners) { | |||||
|   AddAttr(kAlignCorners, MakeValue(align_corners)); | |||||
| } | |||||
| // Reads the kMethod attribute (stored as int64_t by set_method) and converts it back to ResizeMethod. | |||||
| ResizeMethod ResizeGrad::get_method() const { | |||||
|   auto value_ptr = GetAttr(kMethod); | |||||
|   // Named cast instead of a function-style cast: the integer-to-ResizeMethod conversion is explicit. | |||||
|   return static_cast<ResizeMethod>(GetValue<int64_t>(value_ptr)); | |||||
| } | |||||
| // Reads back the kAlignCorners attribute as a bool. | |||||
| bool ResizeGrad::get_align_corners() const { | |||||
|   return GetValue<bool>(GetAttr(kAlignCorners)); | |||||
| } | |||||
| REGISTER_PRIMITIVE_C(kNameResizeGrad, ResizeGrad); | |||||
| } // namespace ops | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,46 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CORE_OPS_GRAD_RESIZE_GRAD_H_ | |||||
| #define MINDSPORE_CORE_OPS_GRAD_RESIZE_GRAD_H_ | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include "ops/primitive_c.h" | |||||
| #include "abstract/abstract_value.h" | |||||
| #include "utils/check_convert_utils.h" | |||||
| namespace mindspore { | |||||
| namespace ops { | |||||
| constexpr auto kNameResizeGrad = "ResizeGrad"; | |||||
| /// Primitive wrapper for the "ResizeGrad" operator; carries a resize method and an align_corners flag. | |||||
| class ResizeGrad : public PrimitiveC { | |||||
| public: | |||||
|   /// Creates the primitive under the name kNameResizeGrad. | |||||
|   ResizeGrad() : PrimitiveC(kNameResizeGrad) {} | |||||
|   ~ResizeGrad() = default; | |||||
|   MS_DECLARE_PARENT(ResizeGrad, PrimitiveC); | |||||
|   /// Sets both attributes in one call. | |||||
|   void Init(const ResizeMethod method, const bool align_corners); | |||||
|   // Attribute setters. | |||||
|   void set_method(const ResizeMethod method); | |||||
|   void set_align_corners(const bool align_corners); | |||||
|   // Attribute getters. | |||||
|   ResizeMethod get_method() const; | |||||
|   bool get_align_corners() const; | |||||
| }; | |||||
| AbstractBasePtr ResizeGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const std::vector<AbstractBasePtr> &input_args); | |||||
| using PrimResizeGradPtr = std::shared_ptr<ResizeGrad>; | |||||
| } // namespace ops | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CORE_OPS_GRAD_RESIZE_GRAD_H_ | |||||
| @@ -0,0 +1,26 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "ops/grad/rsqrt_grad.h" | |||||
| #include "ops/op_utils.h" | |||||
| #include "utils/check_convert_utils.h" | |||||
| #include "abstract/primitive_infer_map.h" | |||||
| namespace mindspore { | |||||
| namespace ops { | |||||
| REGISTER_PRIMITIVE_C(kNameRsqrtGrad, RsqrtGrad); | |||||
| } // namespace ops | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,36 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CORE_OPS_GRAD_RSQRT_GRAD_H_ | |||||
| #define MINDSPORE_CORE_OPS_GRAD_RSQRT_GRAD_H_ | |||||
| #include "ops/primitive_c.h" | |||||
| #include "abstract/abstract_value.h" | |||||
| #include "utils/check_convert_utils.h" | |||||
| namespace mindspore { | |||||
| namespace ops { | |||||
| constexpr auto kNameRsqrtGrad = "RsqrtGrad"; | |||||
| /// Primitive wrapper for the "RsqrtGrad" operator; it has no attributes of its own. | |||||
| class RsqrtGrad : public PrimitiveC { | |||||
| public: | |||||
|   /// Registers the input names {"out_backprop", "input"} and the output name {"output"}. | |||||
|   RsqrtGrad() : PrimitiveC(kNameRsqrtGrad) { | |||||
|     InitIOName({"out_backprop", "input"}, {"output"}); | |||||
|   } | |||||
|   ~RsqrtGrad() = default; | |||||
|   MS_DECLARE_PARENT(RsqrtGrad, PrimitiveC); | |||||
|   /// Intentionally empty: there is nothing to initialise. | |||||
|   void Init() {} | |||||
| }; | |||||
| } // namespace ops | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CORE_OPS_GRAD_RSQRT_GRAD_H_ | |||||
| @@ -0,0 +1,26 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "ops/grad/sqrt_grad.h" | |||||
| #include "ops/op_utils.h" | |||||
| #include "utils/check_convert_utils.h" | |||||
| #include "abstract/primitive_infer_map.h" | |||||
| namespace mindspore { | |||||
| namespace ops { | |||||
| REGISTER_PRIMITIVE_C(kNameSqrtGrad, SqrtGrad); | |||||
| } // namespace ops | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,36 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CORE_OPS_GRAD_SQRT_GRAD_H_ | |||||
| #define MINDSPORE_CORE_OPS_GRAD_SQRT_GRAD_H_ | |||||
| #include "ops/primitive_c.h" | |||||
| #include "abstract/abstract_value.h" | |||||
| #include "utils/check_convert_utils.h" | |||||
| namespace mindspore { | |||||
| namespace ops { | |||||
| constexpr auto kNameSqrtGrad = "SqrtGrad"; | |||||
| /// Primitive wrapper for the "SqrtGrad" operator; it has no attributes of its own. | |||||
| class SqrtGrad : public PrimitiveC { | |||||
| public: | |||||
|   /// Registers the input names {"out_backprop", "input"} and the output name {"output"}. | |||||
|   SqrtGrad() : PrimitiveC(kNameSqrtGrad) { | |||||
|     InitIOName({"out_backprop", "input"}, {"output"}); | |||||
|   } | |||||
|   ~SqrtGrad() = default; | |||||
|   MS_DECLARE_PARENT(SqrtGrad, PrimitiveC); | |||||
|   /// Intentionally empty: there is nothing to initialise. | |||||
|   void Init() {} | |||||
| }; | |||||
| } // namespace ops | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CORE_OPS_GRAD_SQRT_GRAD_H_ | |||||
| @@ -15,13 +15,20 @@ | |||||
| """densenet_train_export.""" | """densenet_train_export.""" | ||||
| import sys | import sys | ||||
| import os | |||||
| import numpy as np | import numpy as np | ||||
| from train_utils import SaveInOut, TrainWrap | from train_utils import SaveInOut, TrainWrap | ||||
| from official.cv.densenet121.src.network.densenet import DenseNet121 | |||||
| import mindspore.common.dtype as mstype | import mindspore.common.dtype as mstype | ||||
| from mindspore import context, Tensor, nn | from mindspore import context, Tensor, nn | ||||
| from mindspore.train.serialization import export | from mindspore.train.serialization import export | ||||
| sys.path.append(os.environ['CLOUD_MODEL_ZOO'] + 'official/cv/densenet121/') | |||||
| #pylint: disable=wrong-import-position | |||||
| from official.cv.densenet121.src.network.densenet import DenseNet121 | |||||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU", save_graphs=False) | context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU", save_graphs=False) | ||||
| n = DenseNet121(num_classes=10) | n = DenseNet121(num_classes=10) | ||||
| @@ -12,7 +12,7 @@ | |||||
| # See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================ | # ============================================================================ | ||||
| """mobilenetv2_train_export.""" | |||||
| """resnet_train_export""" | |||||
| import sys | import sys | ||||
| import numpy as np | import numpy as np | ||||
| @@ -0,0 +1,39 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """vgg_train_export.""" | |||||
| import sys | |||||
| import numpy as np | |||||
| from train_utils import SaveInOut, TrainWrap | |||||
| from official.cv.vgg16.src.vgg import vgg16 | |||||
| import mindspore.common.dtype as mstype | |||||
| from mindspore import context, Tensor, nn | |||||
| from mindspore.train.serialization import export | |||||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU", save_graphs=False) | |||||
| batch = 2 | |||||
| n = vgg16(num_classes=10) | |||||
| loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=False) | |||||
| optimizer = nn.Momentum(n.trainable_params(), 0.01, 0.9, use_nesterov=False) | |||||
| net = TrainWrap(n, loss_fn, optimizer) | |||||
| x = Tensor(np.random.randn(batch, 3, 224, 224), mstype.float32) | |||||
| label = Tensor(np.zeros([batch, 10]).astype(np.float32)) | |||||
| export(net, x, label, file_name="mindir/vgg_train", file_format='MINDIR') | |||||
| if len(sys.argv) > 1: | |||||
| SaveInOut(sys.argv[1] + "vgg", x, label, n, net) | |||||
| @@ -0,0 +1,42 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """xception_train_export""" | |||||
| import sys | |||||
| import numpy as np | |||||
| from train_utils import SaveInOut, TrainWrap | |||||
| from official.cv.xception.src.Xception import Xception | |||||
| import mindspore.common.dtype as mstype | |||||
| from mindspore import context, Tensor, nn | |||||
| from mindspore.train.serialization import export | |||||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU", save_graphs=False) | |||||
| n = Xception(num_classes=1000) | |||||
| n.dropout = nn.Dropout(keep_prob=1.0) | |||||
| loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=False) | |||||
| optimizer = nn.SGD(n.trainable_params(), learning_rate=0.01, momentum=0.9, dampening=0.0, weight_decay=0.0, | |||||
| nesterov=True, loss_scale=1.0) | |||||
| net = TrainWrap(n, loss_fn, optimizer) | |||||
| batch = 2 | |||||
| x = Tensor(np.random.randn(batch, 3, 299, 299), mstype.float32) | |||||
| label = Tensor(np.zeros([batch, 1000]).astype(np.float32)) | |||||
| export(net, x, label, file_name="mindir/xception_train", file_format='MINDIR') | |||||
| if len(sys.argv) > 1: | |||||
| SaveInOut(sys.argv[1] + "xception", x, label, n, net) | |||||
| @@ -8,5 +8,7 @@ effnet_tune | |||||
| resnet | resnet | ||||
| googlenet | googlenet | ||||
| nin | nin | ||||
| #shufflenetv2 | |||||
| #densenet | |||||
| densenet | |||||
| shufflenetv2 | |||||
| vgg noarm32 | |||||
| xception | |||||
| @@ -2,10 +2,9 @@ | |||||
| display_usage() | display_usage() | ||||
| { | { | ||||
| echo "Usage: prepare.sh [-d mindspore_docker] [-r release.tar.gz] [-i]" | |||||
| echo "Usage: prepare.sh [-d mindspore_docker] [-i]" | |||||
| echo "Options:" | echo "Options:" | ||||
| echo " -d docker where mindspore is installed. If no docker is provided script will use local python" | echo " -d docker where mindspore is installed. If no docker is provided script will use local python" | ||||
| echo " -r release tarball" | |||||
| echo " -i create input and output files" | echo " -i create input and output files" | ||||
| } | } | ||||
| @@ -20,9 +19,6 @@ checkopts() | |||||
| d) | d) | ||||
| DOCKER=$OPTARG | DOCKER=$OPTARG | ||||
| ;; | ;; | ||||
| r) | |||||
| TARBALL=$OPTARG | |||||
| ;; | |||||
| i) | i) | ||||
| TRAIN_IO="train_io/" | TRAIN_IO="train_io/" | ||||
| ;; | ;; | ||||
| @@ -55,16 +51,6 @@ echo ' ' > ${export_result_file} | |||||
| CLOUD_MODEL_ZOO=../../../../model_zoo/ | CLOUD_MODEL_ZOO=../../../../model_zoo/ | ||||
| checkopts "$@" | checkopts "$@" | ||||
| if [ "$TARBALL" == "" ]; then | |||||
| file=$(ls ../../../../output/mindspore-lite-*-train-linux-x64.tar.gz) | |||||
| if [ -f ${file} ]; then | |||||
| TARBALL=${file} | |||||
| else | |||||
| echo "release.tar.gz was not found" | |||||
| display_usage | |||||
| exit 1 | |||||
| fi | |||||
| fi | |||||
| if [ -z "${DOCKER}" ]; then | if [ -z "${DOCKER}" ]; then | ||||
| echo "MindSpore docker was not provided, attempting to run locally" | echo "MindSpore docker was not provided, attempting to run locally" | ||||
| @@ -76,13 +62,14 @@ if [ ! -z "${TRAIN_IO}" ]; then | |||||
| fi | fi | ||||
| while read line; do | while read line; do | ||||
| model_name=${line} | |||||
| # Split the list entry on spaces; the first field is the model name, later fields are per-model flags. | |||||
| # The separator variable is IFS (the original "LFS" was a typo that read ignores); the here-string is quoted so the line is passed verbatim. | |||||
| IFS=" " read -r -a line_array <<< "${line}" | |||||
| model_name=${line_array[0]} | |||||
| if [[ $model_name == \#* ]]; then | if [[ $model_name == \#* ]]; then | ||||
| continue | |||||
| continue | |||||
| fi | fi | ||||
| echo 'exporting' ${model_name} | echo 'exporting' ${model_name} | ||||
| if [ ! -z "${DOCKER}" ]; then | if [ ! -z "${DOCKER}" ]; then | ||||
| docker run -w $PWD --runtime=nvidia -v /home/$USER:/home/$USER --privileged=true ${DOCKER} /bin/bash -c "PYTHONPATH=${CLOUD_MODEL_ZOO} python models/${model_name}_train_export.py ${TRAIN_IO} && chmod 444 mindir/${model_name}_train.mindir" | |||||
| docker run -w $PWD --runtime=nvidia -v /home/$USER:/home/$USER --privileged=true ${DOCKER} /bin/bash -c "CLOUD_MODEL_ZOO=${CLOUD_MODEL_ZOO} PYTHONPATH=${CLOUD_MODEL_ZOO} python models/${model_name}_train_export.py ${TRAIN_IO} && chmod 444 mindir/${model_name}_train.mindir" | |||||
| else | else | ||||
| # CLOUD_MODEL_ZOO is a plain (non-exported) shell variable, so pass it explicitly as the docker branch does; the densenet export script reads it from os.environ. | |||||
| CLOUD_MODEL_ZOO=${CLOUD_MODEL_ZOO} PYTHONPATH=${CLOUD_MODEL_ZOO} python models/${model_name}_train_export.py ${TRAIN_IO} | CLOUD_MODEL_ZOO=${CLOUD_MODEL_ZOO} PYTHONPATH=${CLOUD_MODEL_ZOO} python models/${model_name}_train_export.py ${TRAIN_IO} | ||||
| fi | fi | ||||
| @@ -2,8 +2,12 @@ BASE_DIR=$(realpath ../../../../) | |||||
| APP:=bin/net_runner | APP:=bin/net_runner | ||||
| MSLIB:=mindspore-lite | MSLIB:=mindspore-lite | ||||
| LMDLIB:=-lminddata-lite -ljpeg | LMDLIB:=-lminddata-lite -ljpeg | ||||
| LHIAILIB:=-lhiai_ir_build -lhiai_ir -lhiai | |||||
| MSDIR:=$(realpath package-$(TARGET)/lib) | MSDIR:=$(realpath package-$(TARGET)/lib) | ||||
| ifneq ("$(wildcard $(MSDIR)/libhiai.so)","") | |||||
| LHIAILIB:=-lhiai_ir_build -lhiai_ir -lhiai | |||||
| else | |||||
| LHIAILIB:= | |||||
| endif | |||||
| SRC:=src/net_runner.cc | SRC:=src/net_runner.cc | ||||
| OBJ:=$(SRC:.cc=.o) | OBJ:=$(SRC:.cc=.o) | ||||
| @@ -96,7 +96,8 @@ void NetRunner::InitAndFigureInputs() { | |||||
| context.device_list_[0].device_type_ = mindspore::lite::DT_CPU; | context.device_list_[0].device_type_ = mindspore::lite::DT_CPU; | ||||
| context.thread_num_ = 2; | context.thread_num_ = 2; | ||||
| loop_ = mindspore::session::TrainLoop::CreateTrainLoop(ms_file_, &context); | |||||
| auto session = mindspore::session::TrainSession::CreateSession(ms_file_, &context); | |||||
| loop_ = mindspore::session::TrainLoop::CreateTrainLoop(session, &context); | |||||
| session_ = loop_->train_session(); | session_ = loop_->train_session(); | ||||
| MS_ASSERT(nullptr != session_); | MS_ASSERT(nullptr != session_); | ||||
| @@ -0,0 +1,5 @@ | |||||
| *.mindir | |||||
| *.ms | |||||
| msl | |||||
| package-* | |||||
| dataset | |||||
| @@ -40,11 +40,11 @@ class TrainLoop { | |||||
| public: | public: | ||||
| /// \brief Static method to create a TrainLoop object | /// \brief Static method to create a TrainLoop object | ||||
| /// | /// | ||||
| /// \param[in] filename Filename to read flatbuffer from | |||||
| /// \param[in] train_session Train session object as returned by the CreateSession/CreateTransferSession API | |||||
| /// \param[in] context Defines the context of the session to be created | /// \param[in] context Defines the context of the session to be created | ||||
| /// | /// | ||||
| /// \return Pointer of MindSpore Lite TrainLoop | /// \return Pointer of MindSpore Lite TrainLoop | ||||
| static TrainLoop *CreateTrainLoop(const std::string &model_filename, lite::Context *context, int batch_size = -1); | |||||
| static TrainLoop *CreateTrainLoop(session::TrainSession *train_session, lite::Context *context, int batch_size = -1); | |||||
| /// \brief Class destructor | /// \brief Class destructor | ||||
| virtual ~TrainLoop() = default; | virtual ~TrainLoop() = default; | ||||
| @@ -44,11 +44,40 @@ constexpr int RET_EXIT = 2; | |||||
| class TrainLoopCallBack { | class TrainLoopCallBack { | ||||
| public: | public: | ||||
| virtual ~TrainLoopCallBack() = default; | virtual ~TrainLoopCallBack() = default; | ||||
| /// \brief This method is called once before the network starts executing | |||||
| /// | |||||
| /// \param[in] cb_data info about current execution | |||||
| virtual void Begin(const TrainLoopCallBackData &cb_data) {} | virtual void Begin(const TrainLoopCallBackData &cb_data) {} | ||||
| /// \brief This method is called once following the network execution | |||||
| /// | |||||
| /// \param[in] cb_data info about current execution | |||||
| virtual void End(const TrainLoopCallBackData &cb_data) {} | virtual void End(const TrainLoopCallBackData &cb_data) {} | ||||
| /// \brief This method is called at the beginning of each epoch | |||||
| /// | |||||
| /// \param[in] cb_data info about current execution | |||||
| virtual void EpochBegin(const TrainLoopCallBackData &cb_data) {} | virtual void EpochBegin(const TrainLoopCallBackData &cb_data) {} | ||||
| /// \brief This method is called after the run of each epoch | |||||
| /// | |||||
| /// \param[in] cb_data info about current execution | |||||
| /// | |||||
| /// \return indication of whether to continue the train loop: | |||||
| /// RET_CONTINUE -- continue training | |||||
| /// RET_STOP_TRAINING -- stop training (e.g., due to achieved accuracy) | |||||
| /// RET_EXIT -- Exit training (due to error of some sort) | |||||
| virtual int EpochEnd(const TrainLoopCallBackData &cb_data) { return RET_CONTINUE; } | virtual int EpochEnd(const TrainLoopCallBackData &cb_data) { return RET_CONTINUE; } | ||||
| /// \brief This method is called at the beginning of each step | |||||
| /// | |||||
| /// \param[in] cb_data info about current execution | |||||
| virtual void StepBegin(const TrainLoopCallBackData &cb_data) {} | virtual void StepBegin(const TrainLoopCallBackData &cb_data) {} | ||||
| /// \brief This method is called after each step has run | |||||
| /// | |||||
| /// \param[in] cb_data info about current execution | |||||
| virtual void StepEnd(const TrainLoopCallBackData &cb_data) {} | virtual void StepEnd(const TrainLoopCallBackData &cb_data) {} | ||||
| }; | }; | ||||
| @@ -142,6 +142,14 @@ public class TrainSession { | |||||
| return this.setLearningRate(this.sessionPtr, learning_rate); | return this.setLearningRate(this.sessionPtr, learning_rate); | ||||
| } | } | ||||
| /** | |||||
|  * Forwards the virtual-batch settings to the native session. | |||||
|  * | |||||
|  * @param virtualBatchMultiplier multiplier passed through to the native call. | |||||
|  * @param learningRate learning rate passed through to the native call. | |||||
|  * @param momentum momentum passed through to the native call. | |||||
|  * @return the boolean result reported by the native call. | |||||
|  */ | |||||
| public boolean setupVirtualBatch(int virtualBatchMultiplier, float learningRate, float momentum) { | |||||
|     final boolean succeeded = setupVirtualBatch(this.sessionPtr, virtualBatchMultiplier, learningRate, momentum); | |||||
|     return succeeded; | |||||
| } | |||||
| /** | |||||
|  * Forwards only the multiplier to the native session; learning rate and momentum are both sent as -1.0f. | |||||
|  * NOTE(review): -1.0f presumably means "keep the current value" on the native side; confirm. | |||||
|  * | |||||
|  * @param virtualBatchMultiplier multiplier passed through to the native call. | |||||
|  * @return the boolean result reported by the native call. | |||||
|  */ | |||||
| public boolean setupVirtualBatch(int virtualBatchMultiplier) { | |||||
|     final boolean succeeded = setupVirtualBatch(this.sessionPtr, virtualBatchMultiplier, -1.0f, -1.0f); | |||||
|     return succeeded; | |||||
| } | |||||
| private native long createSession(String modelFilename, long msConfigPtr); | private native long createSession(String modelFilename, long msConfigPtr); | ||||
| private native void bindThread(long sessionPtr, boolean if_bind); | private native void bindThread(long sessionPtr, boolean if_bind); | ||||
| @@ -175,4 +183,6 @@ public class TrainSession { | |||||
| private native boolean isEval(long sessionPtr); | private native boolean isEval(long sessionPtr); | ||||
| private native boolean setLearningRate(long sessionPtr, float learning_rate); | private native boolean setLearningRate(long sessionPtr, float learning_rate); | ||||
| private native boolean setupVirtualBatch(long sessionPtr, int virtualBatchMultiplier, float learningRate, float momentum); | |||||
| } | } | ||||
| @@ -303,3 +303,18 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_TrainSession_setLe | |||||
| auto ret = train_session_ptr->SetLearningRate(learning_rate); | auto ret = train_session_ptr->SetLearningRate(learning_rate); | ||||
| return (jboolean)(ret == mindspore::lite::RET_OK); | return (jboolean)(ret == mindspore::lite::RET_OK); | ||||
| } | } | ||||
| extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_TrainSession_setupVirtualBatch(JNIEnv *env, jobject thiz, | |||||
| jlong session_ptr, | |||||
| jint virtualBatchMultiplier, | |||||
| jfloat learningRate, | |||||
| jfloat momentum) { | |||||
| auto *session_pointer = reinterpret_cast<void *>(session_ptr); | |||||
| if (session_pointer == nullptr) { | |||||
| MS_LOGE("Session pointer from java is nullptr"); | |||||
| return (jboolean) false; | |||||
| } | |||||
| auto *train_session_ptr = static_cast<mindspore::session::TrainSession *>(session_pointer); | |||||
| auto ret = train_session_ptr->SetupVirtualBatch(virtualBatchMultiplier, learningRate, momentum); | |||||
| return (jboolean)(ret == mindspore::lite::RET_OK); | |||||
| } | |||||
| @@ -244,6 +244,7 @@ set(LITE_KERNEL_SRC | |||||
| ${LITE_DIR}/nnacl/infer/hashtable_lookup_infer.c | ${LITE_DIR}/nnacl/infer/hashtable_lookup_infer.c | ||||
| ${LITE_DIR}/nnacl/infer/invert_permutation_infer.c | ${LITE_DIR}/nnacl/infer/invert_permutation_infer.c | ||||
| ${LITE_DIR}/nnacl/infer/layer_norm_infer.c | ${LITE_DIR}/nnacl/infer/layer_norm_infer.c | ||||
| ${LITE_DIR}/nnacl/infer/layer_norm_grad_infer.c | |||||
| ${LITE_DIR}/nnacl/infer/lin_space_infer.c | ${LITE_DIR}/nnacl/infer/lin_space_infer.c | ||||
| ${LITE_DIR}/nnacl/infer/lsh_projection_infer.c | ${LITE_DIR}/nnacl/infer/lsh_projection_infer.c | ||||
| ${LITE_DIR}/nnacl/infer/lstm_infer.c | ${LITE_DIR}/nnacl/infer/lstm_infer.c | ||||
| @@ -65,11 +65,10 @@ void LayerNormGammaAndBeta(float *dst, const float *src, const float *gamma_data | |||||
| } | } | ||||
| int LayerNorm(const float *src_data, const float *gamma_data, const float *beta_data, float *dst_data, | int LayerNorm(const float *src_data, const float *gamma_data, const float *beta_data, float *dst_data, | ||||
| LayerNormParameter *param, size_t task_id) { | |||||
| LayerNormParameter *param, float *out_mean, float *out_deno, size_t task_id) { | |||||
| if (src_data == NULL || dst_data == NULL || gamma_data == NULL || beta_data == NULL) { | if (src_data == NULL || dst_data == NULL || gamma_data == NULL || beta_data == NULL) { | ||||
| return NNACL_NULL_PTR; | return NNACL_NULL_PTR; | ||||
| } | } | ||||
| int step = UP_DIV(param->norm_outer_size_, param->op_parameter_.thread_num_); | int step = UP_DIV(param->norm_outer_size_, param->op_parameter_.thread_num_); | ||||
| int thread_end = MSMIN((task_id + 1) * step, param->norm_outer_size_); | int thread_end = MSMIN((task_id + 1) * step, param->norm_outer_size_); | ||||
| for (int i = task_id * step; i < thread_end; i++) { | for (int i = task_id * step; i < thread_end; i++) { | ||||
| @@ -79,7 +78,10 @@ int LayerNorm(const float *src_data, const float *gamma_data, const float *beta_ | |||||
| float square_mean = 0.0f; | float square_mean = 0.0f; | ||||
| LayerNormMeanAndSquare(src_norm, param->norm_inner_size_, &mean, &square_mean); | LayerNormMeanAndSquare(src_norm, param->norm_inner_size_, &mean, &square_mean); | ||||
| const float deno = 1 / sqrtf(square_mean - mean * mean + param->epsilon_); | const float deno = 1 / sqrtf(square_mean - mean * mean + param->epsilon_); | ||||
| if ((out_mean != NULL) && (out_deno != NULL)) { | |||||
| out_mean[i] = mean; | |||||
| out_deno[i] = deno; | |||||
| } | |||||
| if (param->norm_outer_size_ <= param->params_outer_size_) { | if (param->norm_outer_size_ <= param->params_outer_size_) { | ||||
| for (int x = 0; x < param->norm_inner_size_ / param->params_inner_size_; x++) { | for (int x = 0; x < param->norm_inner_size_ / param->params_inner_size_; x++) { | ||||
| const float *src_param = src_norm + x * param->params_inner_size_; | const float *src_param = src_norm + x * param->params_inner_size_; | ||||
| @@ -13,8 +13,8 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef MINDSPORE_LITE_NNACL_FP32_LAYER_NORM_H_ | |||||
| #define MINDSPORE_LITE_NNACL_FP32_LAYER_NORM_H_ | |||||
| #ifndef MINDSPORE_LITE_NNACL_FP32_LAYER_NORM_FP32_H_ | |||||
| #define MINDSPORE_LITE_NNACL_FP32_LAYER_NORM_FP32_H_ | |||||
| #include "nnacl/op_base.h" | #include "nnacl/op_base.h" | ||||
| #include "nnacl/layer_norm_parameter.h" | #include "nnacl/layer_norm_parameter.h" | ||||
| @@ -24,9 +24,9 @@ extern "C" { | |||||
| #endif | #endif | ||||
| int LayerNorm(const float *src_data, const float *gamma_data, const float *beta_data, float *dst_data, | int LayerNorm(const float *src_data, const float *gamma_data, const float *beta_data, float *dst_data, | ||||
| LayerNormParameter *param, size_t task_id); | |||||
| LayerNormParameter *param, float *out_mean, float *out_deno, size_t task_id); | |||||
| #ifdef __cplusplus | #ifdef __cplusplus | ||||
| } | } | ||||
| #endif | #endif | ||||
| #endif // MINDSPORE_LITE_NNACL_FP32_LAYER_NORM_H_ | |||||
| #endif // MINDSPORE_LITE_NNACL_FP32_LAYER_NORM_FP32_H_ | |||||
| @@ -95,3 +95,18 @@ int HSigmoidGrad(float *src0, float *src1, size_t length, float *dst) { | |||||
| } | } | ||||
| return NNACL_OK; | return NNACL_OK; | ||||
| } | } | ||||
| int EluGrad(float *src0, float *src1, size_t length, float *dst, float alpha) { | |||||
| for (size_t i = 0; i < length; ++i) { | |||||
| dst[i] = (src1[i] > 0.0f ? src0[i] : alpha * expm1(src1[i]) * src0[i]); | |||||
| } | |||||
| return NNACL_OK; | |||||
| } | |||||
| int GeluGrad(float *src0, float *src1, size_t length, float *dst) { | |||||
| for (size_t i = 0; i < length; ++i) { | |||||
| dst[i] = src0[i] * ((0.5 * (1.0 + erf(src1[i] / 1.4142135623730951))) + | |||||
| (src1[i] * exp(-0.5 * src1[i] * src1[i]) / 2.5066282746)); | |||||
| } | |||||
| return NNACL_OK; | |||||
| } | |||||
| @@ -37,6 +37,8 @@ int SigmoidGrad(float *src0, float *src1, size_t length, float *dst); | |||||
| int TanhGrad(float *src0, float *src1, size_t length, float *dst); | int TanhGrad(float *src0, float *src1, size_t length, float *dst); | ||||
| int HSwishGrad(float *src0, float *src1, size_t length, float *dst); | int HSwishGrad(float *src0, float *src1, size_t length, float *dst); | ||||
| int HSigmoidGrad(float *src0, float *src1, size_t length, float *dst); | int HSigmoidGrad(float *src0, float *src1, size_t length, float *dst); | ||||
| int EluGrad(float *src0, float *src1, size_t length, float *dst, float alpha); | |||||
| int GeluGrad(float *src0, float *src1, size_t length, float *dst); | |||||
| #ifdef __cplusplus | #ifdef __cplusplus | ||||
| } | } | ||||
| @@ -16,6 +16,7 @@ | |||||
| #include "nnacl/fp32_grad/arithmetic_grad.h" | #include "nnacl/fp32_grad/arithmetic_grad.h" | ||||
| #include <string.h> | #include <string.h> | ||||
| #include <math.h> | |||||
| #include "nnacl/fp32_grad/utils.h" | #include "nnacl/fp32_grad/utils.h" | ||||
| #include "nnacl/errorcode.h" | #include "nnacl/errorcode.h" | ||||
| @@ -137,3 +138,17 @@ void MinimumByAxes(const float *input0, const float *input1, const float *dy, co | |||||
| } while (NextIndex(num_dims, dy_dims, input_iter)); | } while (NextIndex(num_dims, dy_dims, input_iter)); | ||||
| } | } | ||||
| } | } | ||||
| int ElementSqrtGrad(const float *in1, const float *in2, float *out, const int element_size) { | |||||
| for (int i = 0; i < element_size; i++) { | |||||
| out[i] = 0.5f * in2[i] / in1[i]; | |||||
| } | |||||
| return NNACL_OK; | |||||
| } | |||||
| int ElementRsqrtGrad(const float *in1, const float *in2, float *out, const int element_size) { | |||||
| for (int i = 0; i < element_size; i++) { | |||||
| out[i] = -0.5f * in2[i] * in1[i] * in1[1] * in1[i]; | |||||
| } | |||||
| return NNACL_OK; | |||||
| } | |||||
| @@ -28,6 +28,9 @@ void MaximumByAxes(const float *input0, const float *input1, const float *dy, co | |||||
| const int *input1_dims, const int *dy_dims, float *output0, float *output1, int num_dims); | const int *input1_dims, const int *dy_dims, float *output0, float *output1, int num_dims); | ||||
| void MinimumByAxes(const float *input0, const float *input1, const float *dy, const int *input0_dims, | void MinimumByAxes(const float *input0, const float *input1, const float *dy, const int *input0_dims, | ||||
| const int *input1_dims, const int *dy_dims, float *output0, float *output1, int num_dims); | const int *input1_dims, const int *dy_dims, float *output0, float *output1, int num_dims); | ||||
| int ElementSqrtGrad(const float *in1, const float *in2, float *out, const int element_size); | |||||
| int ElementRsqrtGrad(const float *in1, const float *in2, float *out, const int element_size); | |||||
| #ifdef __cplusplus | #ifdef __cplusplus | ||||
| } | } | ||||
| #endif | #endif | ||||
| @@ -0,0 +1,57 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "nnacl/fp32_grad/layernorm_grad.h" | |||||
| #include <stddef.h> | |||||
| #include <math.h> | |||||
// Backward pass of layer normalization.
//
// x, dy      : flat input and incoming-gradient arrays; rows of block_size elements
// var        : one value per normalized row. Despite the name it is used as the
//              reciprocal standard deviation, 1 / sqrt(variance + eps), i.e. variance^(-0.5)
// mean       : one mean per normalized row
// gamma      : scale parameters, indexed by (flat index % param_num)
// param_num  : number of gamma/beta parameters (length of dg and db)
// param_size : number of elements reduced into each parameter gradient
// block_num  : number of normalized rows
// block_size : number of elements per normalized row
// dx         : output, gradient w.r.t. x (block_num * block_size elements)
// dg, db     : outputs, gradients w.r.t. gamma and beta (param_num elements each)
//
// Nothing is written when param_num or block_size is not positive, since both
// are used as divisors.
void LayerNormGrad(const float *x, const float *dy, const float *var, const float *mean, const float *gamma,
                   int param_num, int param_size, int block_num, int block_size, float *dx, float *dg, float *db) {
  if (param_num <= 0 || block_size <= 0) {
    return;
  }
  const float *rstd = var;  // reciprocal standard deviation per row
  const size_t param_stride = (size_t)param_num;
  const size_t row_len = (size_t)block_size;
  const size_t param_total = (param_size > 0) ? (size_t)param_size * param_stride : 0;

  // Parameter gradients: reduce over every element that shares parameter p.
  for (size_t p = 0; p < param_stride; ++p) {
    float dgamma = 0.0f;
    float dbeta = 0.0f;
    for (size_t j = p; j < param_total; j += param_stride) {
      const size_t row = j / row_len;
      dgamma += dy[j] * rstd[row] * (x[j] - mean[row]);
      dbeta += dy[j];
    }
    dg[p] = dgamma;
    db[p] = dbeta;
  }

  // Input gradient: one normalized row at a time.
  for (int b = 0; b < block_num; ++b) {
    const size_t begin = (size_t)b * row_len;
    const size_t end = begin + row_len;
    const float row_mean = mean[b];
    const float row_rstd = rstd[b];
    float sum1 = 0.0f;  // accumulates dL/d(variance)
    float sum2 = 0.0f;  // accumulates sum(dy * gamma), feeds dL/d(mean)
    float sum3 = 0.0f;  // accumulates sum(-2 * (x - mean))
    for (size_t j = begin; j < end; ++j) {
      const float dxm = x[j] - row_mean;
      const float dyg = dy[j] * gamma[j % param_stride];
      sum1 += -0.5f * dyg * dxm * row_rstd * row_rstd * row_rstd;
      sum2 += dyg;  // previously never accumulated, which dropped the mean term from dx
      sum3 += -2.0f * dxm;
    }
    for (size_t j = begin; j < end; ++j) {
      const float dx1 = dy[j] * gamma[j % param_stride] * row_rstd;
      const float dx2 = sum1 * 2.0f / block_size * (x[j] - row_mean);
      const float dx3 = (-1.0f * row_rstd * sum2 + (1.0f / block_size) * sum1 * sum3) * (1.0f / block_size);
      dx[j] = dx1 + dx2 + dx3;
    }
  }
}
| @@ -0,0 +1,29 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_NNACL_FP32_GRAD_LAYERNORM_GRAD_H_ | |||||
| #define MINDSPORE_LITE_NNACL_FP32_GRAD_LAYERNORM_GRAD_H_ | |||||
| #ifdef __cplusplus | |||||
| extern "C" { | |||||
| #endif | |||||
| void LayerNormGrad(const float *x, const float *dy, const float *var, const float *mean, const float *gamma, | |||||
| int param_num, int param_size, int block_num, int block_size, float *dx, float *dg, float *db); | |||||
| #ifdef __cplusplus | |||||
| } | |||||
| #endif | |||||
| #endif // MINDSPORE_LITE_NNACL_FP32_GRAD_LAYERNORM_GRAD_H_ | |||||
| @@ -0,0 +1,27 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_NNACL_FP32_GRAD_LAYERNORMGRAD_PARAMETER_H_ | |||||
| #define MINDSPORE_LITE_NNACL_FP32_GRAD_LAYERNORMGRAD_PARAMETER_H_ | |||||
| #include "nnacl/op_base.h" | |||||
| typedef struct LayerNormGradParameter { /* Parameter record for the LayerNormGrad operator. */ | |||||
| OpParameter op_parameter_; /* common operator header; first member of the struct */ | |||||
| int begin_norm_axis_; /* presumably the first axis that is normalized over - TODO confirm against the kernel */ | |||||
| int begin_params_axis_; /* presumably the first axis covered by gamma/beta - TODO confirm against the kernel */ | |||||
| } LayerNormGradParameter; | |||||
| #endif // MINDSPORE_LITE_NNACL_FP32_GRAD_LAYERNORMGRAD_PARAMETER_H_ | |||||
| @@ -0,0 +1,84 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "nnacl/fp32_grad/resize_grad.h" | |||||
| #include <math.h> | |||||
| void ResizeNearestNeighborGrad(float *in_addr, float *out_addr, int batch_size, int channel, | |||||
| ResizeGradParameter *param) { | |||||
| bool align_corners = param->align_corners_; | |||||
| size_t in_hw_size = param->in_width_ * param->in_height_; | |||||
| size_t out_hw_size = param->out_width_ * param->out_height_; | |||||
| for (int32_t b = 0; b < batch_size; ++b) { | |||||
| for (size_t i = 0; i < in_hw_size; ++i) { | |||||
| size_t in_y = i / param->in_width_; | |||||
| size_t in_x = i % param->in_width_; | |||||
| for (int32_t c = 0; c < channel; ++c) { | |||||
| size_t out_y = MSMIN( | |||||
| (align_corners) ? (size_t)roundf(in_y * param->height_scale_) : (size_t)floorf(in_y * param->height_scale_), | |||||
| param->out_height_ - 1); | |||||
| size_t out_x = MSMIN( | |||||
| (align_corners) ? (size_t)roundf(in_x * param->width_scale_) : (size_t)floorf(in_x * param->width_scale_), | |||||
| param->out_width_ - 1); | |||||
| size_t out_offset = out_y * (param->out_width_ * channel) + (out_x * channel) + c; | |||||
| size_t in_offset = in_y * (param->in_width_ * channel) + (in_x * channel) + c; | |||||
| out_addr[out_offset] += in_addr[in_offset]; | |||||
| } | |||||
| } | |||||
| out_addr += out_hw_size * channel; | |||||
| in_addr += in_hw_size * channel; | |||||
| } | |||||
| } | |||||
| void ResizeBiLinearGrad(float *in_addr, float *out_addr, int batch_size, int channel, ResizeGradParameter *param) { | |||||
| size_t in_hw_size = param->in_width_ * param->in_height_; | |||||
| size_t out_hw_size = param->out_width_ * param->out_height_; | |||||
| for (int32_t b = 0; b < batch_size; ++b) { | |||||
| for (size_t i = 0; i < in_hw_size; ++i) { | |||||
| size_t h = i / param->in_width_; | |||||
| size_t w = i % param->in_width_; | |||||
| for (int32_t c = 0; c < channel; ++c) { | |||||
| float in_y = (float)h * param->height_scale_; | |||||
| size_t top_y_index = MSMAX((size_t)(floorf(in_y)), (size_t)(0)); | |||||
| size_t bottom_y_index = MSMIN((size_t)(ceilf(in_y)), param->out_height_ - 1); | |||||
| float y_lerp = in_y - floorf(in_y); | |||||
| float inverse_y_lerp = 1.0 - y_lerp; | |||||
| float in_x = (float)w * param->width_scale_; | |||||
| size_t left_x_index = MSMAX((size_t)(floorf(in_x)), (size_t)(0)); | |||||
| size_t right_x_index = MSMIN((size_t)(ceilf(in_x)), param->out_width_ - 1); | |||||
| float x_lerp = in_x - floorf(in_x); | |||||
| float inverse_x_lerp = 1.0 - x_lerp; | |||||
| size_t in_offset = h * (param->in_width_ * channel) + (w * channel) + c; | |||||
| size_t out_offset_top_y_left_x = top_y_index * (param->out_width_ * channel) + (left_x_index * channel) + c; | |||||
| size_t out_offset_top_y_right_x = top_y_index * (param->out_width_ * channel) + (right_x_index * channel) + c; | |||||
| size_t out_offset_bottom_y_left_x = | |||||
| bottom_y_index * (param->out_width_ * channel) + (left_x_index * channel) + c; | |||||
| size_t out_offset_bottom_y_right_x = | |||||
| bottom_y_index * (param->out_width_ * channel) + (right_x_index * channel) + c; | |||||
| out_addr[out_offset_top_y_left_x] += in_addr[in_offset] * (float)(inverse_y_lerp * inverse_x_lerp); | |||||
| out_addr[out_offset_top_y_right_x] += in_addr[in_offset] * (float)(inverse_y_lerp * x_lerp); | |||||
| out_addr[out_offset_bottom_y_left_x] += in_addr[in_offset] * (float)(y_lerp * inverse_x_lerp); | |||||
| out_addr[out_offset_bottom_y_right_x] += in_addr[in_offset] * (float)(y_lerp * x_lerp); | |||||
| } | |||||
| } | |||||
| out_addr += out_hw_size * channel; | |||||
| in_addr += in_hw_size * channel; | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,44 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_NNACL_FP32_GRAD_RESIZE_GRAD_H_ | |||||
| #define MINDSPORE_LITE_NNACL_FP32_GRAD_RESIZE_GRAD_H_ | |||||
| #include "nnacl/op_base.h" | |||||
| #ifdef __cplusplus | |||||
| extern "C" { | |||||
| #endif | |||||
| typedef struct ResizeGradParameter { /* Parameters read by ResizeNearestNeighborGrad and ResizeBiLinearGrad. */ | |||||
| OpParameter op_parameter_; /* common operator header; first member of the struct */ | |||||
| bool align_corners_; /* nearest-neighbour kernel: true -> roundf, false -> floorf when mapping coordinates */ | |||||
| int method; /* not read by the two kernels; presumably selects nearest vs. bilinear - TODO confirm */ | |||||
| size_t in_height_; /* rows of the in_addr plane the kernels iterate over */ | |||||
| size_t in_width_; /* columns of the in_addr plane the kernels iterate over */ | |||||
| size_t out_height_; /* rows of the out_addr plane; mapped row indices are clamped to out_height_ - 1 */ | |||||
| size_t out_width_; /* columns of the out_addr plane; mapped column indices are clamped to out_width_ - 1 */ | |||||
| float height_scale_; /* multiplies an in_addr row index to give the out_addr row coordinate */ | |||||
| float width_scale_; /* multiplies an in_addr column index to give the out_addr column coordinate */ | |||||
| } ResizeGradParameter; | |||||
| void ResizeNearestNeighborGrad(float *in_addr, float *out_addr, int batch_size, int channel, | |||||
| ResizeGradParameter *param); | |||||
| void ResizeBiLinearGrad(float *in_addr, float *out_addr, int batch_size, int channel, ResizeGradParameter *param); | |||||
| #ifdef __cplusplus | |||||
| } | |||||
| #endif | |||||
| #endif // MINDSPORE_LITE_NNACL_FP32_GRAD_RESIZE_GRAD_H_ | |||||
| @@ -0,0 +1,33 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "nnacl/fp32_grad/unsorted_segment_sum.h" | |||||
| #include "nnacl/errorcode.h" | |||||
| int UnsortedSegmentSum(const float *input, int unit_num, int input_dim1, const int *indices, float *output, | |||||
| int output_dim0, int output_dim1) { | |||||
| for (int i = 0; i < unit_num; ++i) { | |||||
| int j = i / input_dim1; | |||||
| int k = i % input_dim1; | |||||
| int index = indices[j]; | |||||
| if (index < 0 || index >= output_dim0) { | |||||
| continue; | |||||
| } | |||||
| int output_index = index * output_dim1 + k; | |||||
| output[output_index] += input[i]; | |||||
| } | |||||
| return NNACL_OK; | |||||
| } | |||||
| @@ -0,0 +1,29 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_NNACL_FP32_GRAD_UNSORTED_SEGMENT_SUM_H_ | |||||
| #define MINDSPORE_LITE_NNACL_FP32_GRAD_UNSORTED_SEGMENT_SUM_H_ | |||||
| #ifdef __cplusplus | |||||
| extern "C" { | |||||
| #endif | |||||
| int UnsortedSegmentSum(const float *input, int unit_num, int input_dim1, const int *indices, float *output, | |||||
| int output_dim0, int output_dim1); | |||||
| #ifdef __cplusplus | |||||
| } | |||||
| #endif | |||||
| #endif // MINDSPORE_LITE_NNACL_FP32_GRAD_UNSORTED_SEGMENT_SUM_H_ | |||||
| @@ -15,7 +15,7 @@ | |||||
| */ | */ | ||||
| #include "nnacl/infer/add_sub_grad_infer.h" | #include "nnacl/infer/add_sub_grad_infer.h" | ||||
| #include "nnacl/infer/arithmetic_grad_infer.h" | |||||
| #include "nnacl/arithmetic.h" | |||||
| int AddSubGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, | int AddSubGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, | ||||
| OpParameter *parameter) { | OpParameter *parameter) { | ||||
| @@ -32,35 +32,29 @@ int AddSubGradInferShape(const TensorC *const *inputs, size_t inputs_size, Tenso | |||||
| TensorC *dx1 = outputs[0]; | TensorC *dx1 = outputs[0]; | ||||
| TensorC *dx2 = outputs[1]; | TensorC *dx2 = outputs[1]; | ||||
| ArithmeticGradParameter *param = (ArithmeticGradParameter *)parameter; | |||||
| if (!parameter->infer_flag_) { | |||||
| return NNACL_INFER_INVALID; | |||||
| } | |||||
| int in_shape0[MAX_SHAPE_SIZE]; | |||||
| size_t in_shape0_size = 0; | |||||
| ShapeSet(in_shape0, &in_shape0_size, x1->shape_, x1->shape_size_); | |||||
| int in_shape1[MAX_SHAPE_SIZE]; | |||||
| size_t in_shape1_size = 0; | |||||
| ShapeSet(in_shape1, &in_shape1_size, x2->shape_, x2->shape_size_); | |||||
| int outShape[MAX_SHAPE_SIZE]; | |||||
| size_t outShape_size = 0; | |||||
| ShapeSet(outShape, &outShape_size, dy->shape_, dy->shape_size_); | |||||
| ArithmeticParameter *param = (ArithmeticParameter *)parameter; | |||||
| param->ndim_ = outShape_size; | |||||
| param->x1_shape_size_ = param->ndim_; | |||||
| param->x2_shape_size_ = param->ndim_; | |||||
| param->dy_shape_size_ = param->ndim_; | |||||
| int fill_dim_num0 = outShape_size - in_shape0_size; | |||||
| int fill_dim_num1 = outShape_size - in_shape1_size; | |||||
| param->ndim_ = dy->shape_size_; | |||||
| param->in_elements_num0_ = param->ndim_; | |||||
| param->in_elements_num1_ = param->ndim_; | |||||
| param->out_elements_num_ = param->ndim_; | |||||
| int fillDimNum0 = dy->shape_size_ - x1->shape_size_; | |||||
| int fillDimNum1 = dy->shape_size_ - x2->shape_size_; | |||||
| int j0 = 0; | int j0 = 0; | ||||
| int j1 = 0; | int j1 = 0; | ||||
| for (unsigned int i = 0; i < outShape_size; i++) { | |||||
| param->x1_shape_[i] = (i < fill_dim_num0) ? 1 : in_shape0[j0++]; | |||||
| param->x2_shape_[i] = (i < fill_dim_num1) ? 1 : in_shape1[j1++]; | |||||
| param->dy_shape_[i] = outShape[i]; | |||||
| for (unsigned int i = 0; i < dy->shape_size_; i++) { | |||||
| param->in_shape0_[i] = (i < fillDimNum0) ? 1 : x1->shape_[j0++]; | |||||
| param->in_shape1_[i] = (i < fillDimNum1) ? 1 : x2->shape_[j1++]; | |||||
| param->out_shape_[i] = dy->shape_[i]; | |||||
| } | } | ||||
| SetShapeTensor(dx1, x1); | SetShapeTensor(dx1, x1); | ||||
| SetShapeTensor(dx2, x2); | SetShapeTensor(dx2, x2); | ||||
| dx1->data_type_ = dy->data_type_; | |||||
| dx2->data_type_ = dy->data_type_; | |||||
| SetDataTypeFormat(dx1, dy); | |||||
| SetDataTypeFormat(dx2, dy); | |||||
| return NNACL_OK; | return NNACL_OK; | ||||
| } | } | ||||
| @@ -15,6 +15,7 @@ | |||||
| */ | */ | ||||
| #include "nnacl/infer/arithmetic_grad_infer.h" | #include "nnacl/infer/arithmetic_grad_infer.h" | ||||
| #include "nnacl/arithmetic.h" | |||||
| /* | /* | ||||
| * the Arithmetic Grad op include AddGrad, SubGrad, MulGrad, DivGrad, MaximumGrad, MinimumGrad | * the Arithmetic Grad op include AddGrad, SubGrad, MulGrad, DivGrad, MaximumGrad, MinimumGrad | ||||
| @@ -38,8 +39,6 @@ int ArithmeticGradInferShape(const TensorC *const *inputs, size_t inputs_size, T | |||||
| TensorC *dx1 = outputs[0]; | TensorC *dx1 = outputs[0]; | ||||
| TensorC *dx2 = outputs[1]; | TensorC *dx2 = outputs[1]; | ||||
| ArithmeticGradParameter *param = (ArithmeticGradParameter *)parameter; | |||||
| int in_shape0[MAX_SHAPE_SIZE]; | int in_shape0[MAX_SHAPE_SIZE]; | ||||
| size_t in_shape0_size = 0; | size_t in_shape0_size = 0; | ||||
| ShapeSet(in_shape0, &in_shape0_size, x1->shape_, x1->shape_size_); | ShapeSet(in_shape0, &in_shape0_size, x1->shape_, x1->shape_size_); | ||||
| @@ -50,45 +49,47 @@ int ArithmeticGradInferShape(const TensorC *const *inputs, size_t inputs_size, T | |||||
| size_t out_shape_size = 0; | size_t out_shape_size = 0; | ||||
| ShapeSet(out_shape, &out_shape_size, dy->shape_, dy->shape_size_); | ShapeSet(out_shape, &out_shape_size, dy->shape_, dy->shape_size_); | ||||
| ArithmeticParameter *param = (ArithmeticParameter *)parameter; | |||||
| if (GetElementNum(dx1) < GetElementNum(dx2)) { | if (GetElementNum(dx1) < GetElementNum(dx2)) { | ||||
| param->ndim_ = in_shape1_size; | param->ndim_ = in_shape1_size; | ||||
| param->x1_shape_size_ = param->ndim_; | |||||
| param->x2_shape_size_ = param->ndim_; | |||||
| param->dy_shape_size_ = param->ndim_; | |||||
| param->in_elements_num0_ = param->ndim_; | |||||
| param->in_elements_num1_ = param->ndim_; | |||||
| param->out_elements_num_ = param->ndim_; | |||||
| int fill_dim_num = in_shape1_size - in_shape0_size; // This will not work for batch! | int fill_dim_num = in_shape1_size - in_shape0_size; // This will not work for batch! | ||||
| int j = 0; | int j = 0; | ||||
| for (unsigned int i = 0; i < in_shape1_size; i++) { | for (unsigned int i = 0; i < in_shape1_size; i++) { | ||||
| if (i < fill_dim_num) { | if (i < fill_dim_num) { | ||||
| param->x2_shape_[i] = 1; | |||||
| param->in_shape1_[i] = 1; | |||||
| } else { | } else { | ||||
| param->x2_shape_[i] = in_shape0[j++]; | |||||
| param->in_shape1_[i] = in_shape0[j++]; | |||||
| } | } | ||||
| param->x1_shape_[i] = in_shape1[i]; | |||||
| param->dy_shape_[i] = out_shape[i]; | |||||
| param->in_shape0_[i] = in_shape1[i]; | |||||
| param->out_shape_[i] = out_shape[i]; | |||||
| } | } | ||||
| } else if (GetElementNum(dx2) < GetElementNum(dx1)) { | } else if (GetElementNum(dx2) < GetElementNum(dx1)) { | ||||
| param->ndim_ = in_shape0_size; | param->ndim_ = in_shape0_size; | ||||
| param->x1_shape_size_ = param->ndim_; | |||||
| param->x2_shape_size_ = param->ndim_; | |||||
| param->dy_shape_size_ = param->ndim_; | |||||
| param->in_elements_num0_ = param->ndim_; | |||||
| param->in_elements_num1_ = param->ndim_; | |||||
| param->out_elements_num_ = param->ndim_; | |||||
| param->broadcasting_ = true; | param->broadcasting_ = true; | ||||
| int j = 0; | int j = 0; | ||||
| int fill_dim_num = in_shape0_size - in_shape1_size; | int fill_dim_num = in_shape0_size - in_shape1_size; | ||||
| for (unsigned int i = 0; i < in_shape0_size; i++) { | for (unsigned int i = 0; i < in_shape0_size; i++) { | ||||
| if (i < fill_dim_num) { | if (i < fill_dim_num) { | ||||
| param->x2_shape_[i] = 1; | |||||
| param->in_shape1_[i] = 1; | |||||
| } else { | } else { | ||||
| param->x2_shape_[i] = in_shape1[j++]; | |||||
| param->in_shape1_[i] = in_shape1[j++]; | |||||
| } | } | ||||
| param->x1_shape_[i] = in_shape0[i]; | |||||
| param->dy_shape_[i] = out_shape[i]; | |||||
| param->in_shape0_[i] = in_shape0[i]; | |||||
| param->out_shape_[i] = out_shape[i]; | |||||
| } | } | ||||
| } else { | } else { | ||||
| param->broadcasting_ = false; | param->broadcasting_ = false; | ||||
| for (unsigned int i = 0; i < in_shape0_size; i++) { | for (unsigned int i = 0; i < in_shape0_size; i++) { | ||||
| param->x2_shape_[i] = in_shape1[i]; | |||||
| param->x1_shape_[i] = in_shape0[i]; | |||||
| param->dy_shape_[i] = out_shape[i]; | |||||
| param->in_shape1_[i] = in_shape1[i]; | |||||
| param->in_shape0_[i] = in_shape0[i]; | |||||
| param->out_shape_[i] = out_shape[i]; | |||||
| } | } | ||||
| } | } | ||||
| @@ -13,8 +13,8 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef MINDSPORE_LITE_NNACL_ARITHMETIC_GRAD_INFER_H | |||||
| #define MINDSPORE_LITE_NNACL_ARITHMETIC_GRAD_INFER_H | |||||
| #ifndef MINDSPORE_LITE_NNACL_INFER_ARITHMETIC_GRAD_INFER_H_ | |||||
| #define MINDSPORE_LITE_NNACL_INFER_ARITHMETIC_GRAD_INFER_H_ | |||||
| #include "nnacl/infer/common_infer.h" | #include "nnacl/infer/common_infer.h" | ||||
| @@ -22,24 +22,10 @@ | |||||
| extern "C" { | extern "C" { | ||||
| #endif | #endif | ||||
| typedef struct ArithmeticGradParameter { | |||||
| OpParameter op_parameter_; | |||||
| int type_; | |||||
| bool broadcasting_; // default false | |||||
| int ndim_; | |||||
| // std::vector<int> dy_shape_; | |||||
| int dy_shape_[MAX_SHAPE_SIZE]; | |||||
| size_t dy_shape_size_; | |||||
| int x1_shape_[MAX_SHAPE_SIZE]; | |||||
| size_t x1_shape_size_; | |||||
| int x2_shape_[MAX_SHAPE_SIZE]; | |||||
| size_t x2_shape_size_; | |||||
| } ArithmeticGradParameter; | |||||
| int ArithmeticGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, | int ArithmeticGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, | ||||
| OpParameter *parameter); | OpParameter *parameter); | ||||
| #ifdef __cplusplus | #ifdef __cplusplus | ||||
| } | } | ||||
| #endif | #endif | ||||
| #endif // MINDSPORE_LITE_NNACL_ARITHMETIC_GRAD_INFER_H | |||||
| #endif // MINDSPORE_LITE_NNACL_INFER_ARITHMETIC_GRAD_INFER_H_ | |||||
| @@ -19,7 +19,7 @@ | |||||
| int FlattenGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, | int FlattenGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, | ||||
| OpParameter *parameter) { | OpParameter *parameter) { | ||||
| #ifdef Debug | #ifdef Debug | ||||
| int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); | |||||
| int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 2, 1); | |||||
| if (check_ret != NNACL_OK) { | if (check_ret != NNACL_OK) { | ||||
| return check_ret; | return check_ret; | ||||
| } | } | ||||
| @@ -33,13 +33,7 @@ int FlattenGradInferShape(const TensorC *const *inputs, size_t inputs_size, Tens | |||||
| return NNACL_INFER_INVALID; | return NNACL_INFER_INVALID; | ||||
| } | } | ||||
| int output_shape[2]; | |||||
| size_t output_shape_size = 2; | |||||
| output_shape[0] = input->shape_[0]; | |||||
| output_shape[1] = 1; | |||||
| for (size_t i = 1; i < input->shape_size_; i++) { | |||||
| output_shape[1] *= input->shape_[i]; | |||||
| } | |||||
| SetShapeArray(output, output_shape, output_shape_size); | |||||
| int output_shape_size = inputs[1]->shape_[0]; | |||||
| SetShapeArray(output, (int *)(inputs[1]->data_), output_shape_size); | |||||
| return NNACL_OK; | return NNACL_OK; | ||||
| } | } | ||||
| @@ -0,0 +1,48 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "nnacl/infer/layer_norm_grad_infer.h" | |||||
| #include "nnacl/infer/common_infer.h" | |||||
| #include "nnacl/fp32_grad/layernormgrad_parameter.h" | |||||
| int LayerNormGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, | |||||
| OpParameter *parameter) { | |||||
| int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 5, 3); | |||||
| if (check_ret != NNACL_OK) { | |||||
| return check_ret; | |||||
| } | |||||
| LayerNormGradParameter *param = (LayerNormGradParameter *)parameter; | |||||
| const TensorC *input_x = inputs[0]; | |||||
| TensorC *output_dx = outputs[0]; | |||||
| TensorC *output_dg = outputs[1]; | |||||
| TensorC *output_db = outputs[2]; | |||||
| SetDataTypeFormat(output_dx, input_x); | |||||
| SetDataTypeFormat(output_dg, input_x); | |||||
| SetDataTypeFormat(output_db, input_x); | |||||
| SetShapeTensor(output_dx, input_x); | |||||
| int begin_params_axis = param->begin_params_axis_; | |||||
| if (param->begin_params_axis_ < 0) { | |||||
| begin_params_axis += input_x->shape_size_; | |||||
| } | |||||
| int size = 0; | |||||
| for (int i = begin_params_axis; i < input_x->shape_size_; i++) { | |||||
| output_dg->shape_[size] = input_x->shape_[i]; | |||||
| output_db->shape_[size] = input_x->shape_[i]; | |||||
| size++; | |||||
| } | |||||
| output_db->shape_size_ = size; | |||||
| output_dg->shape_size_ = size; | |||||
| return NNACL_OK; | |||||
| } | |||||
| @@ -0,0 +1,31 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_NNACL_INFER_LAYER_NORM_GRAD_INFER_H_ | |||||
| #define MINDSPORE_LITE_NNACL_INFER_LAYER_NORM_GRAD_INFER_H_ | |||||
| #include "nnacl/infer/common_infer.h" | |||||
| #ifdef __cplusplus | |||||
| extern "C" { | |||||
| #endif | |||||
| int LayerNormGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, | |||||
| OpParameter *parameter); | |||||
| #ifdef __cplusplus | |||||
| } | |||||
| #endif | |||||
| #endif // MINDSPORE_LITE_NNACL_INFER_LAYER_NORM_GRAD_INFER_H_ | |||||
| @@ -19,7 +19,10 @@ | |||||
| int LayerNormInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, | int LayerNormInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, | ||||
| OpParameter *parameter) { | OpParameter *parameter) { | ||||
| #ifdef Debug | #ifdef Debug | ||||
| int check_ret = CheckAugmentNullSizeInputTwo(inputs, inputs_size, outputs, outputs_size, parameter, 1, 3, 1); | |||||
| if ((inputs_size != 1 && inputs_size != 3) || (outputs_size != 1 && outputs_size != 3)) { | |||||
| return NNACL_INPUT_TENSOR_ERROR; | |||||
| } | |||||
| int check_ret = CheckAugmentNull(inputs, inputs_size, outputs, outputs_size, parameter); | |||||
| if (check_ret != NNACL_OK) { | if (check_ret != NNACL_OK) { | ||||
| return check_ret; | return check_ret; | ||||
| } | } | ||||
| @@ -28,11 +31,27 @@ int LayerNormInferShape(const TensorC *const *inputs, size_t inputs_size, Tensor | |||||
| const TensorC *input = inputs[0]; | const TensorC *input = inputs[0]; | ||||
| TensorC *output = outputs[0]; | TensorC *output = outputs[0]; | ||||
| SetDataTypeFormat(output, input); | SetDataTypeFormat(output, input); | ||||
| LayerNormParameter *param = (LayerNormParameter *)parameter; | LayerNormParameter *param = (LayerNormParameter *)parameter; | ||||
| if (!param->op_parameter_.infer_flag_) { | if (!param->op_parameter_.infer_flag_) { | ||||
| return NNACL_INFER_INVALID; | return NNACL_INFER_INVALID; | ||||
| } | } | ||||
| SetShapeTensor(output, input); | SetShapeTensor(output, input); | ||||
| // take care of other outputs | |||||
| if (outputs_size == 3) { | |||||
| TensorC *output_mean = outputs[1]; | |||||
| TensorC *output_var = outputs[2]; | |||||
| SetDataTypeFormat(output_mean, input); | |||||
| SetDataTypeFormat(output_var, input); | |||||
| int size = 0; | |||||
| for (int i = param->begin_norm_axis_; i < input->shape_size_; i++) { | |||||
| output_mean->shape_[size] = input->shape_[i]; | |||||
| output_var->shape_[size] = input->shape_[i]; | |||||
| size++; | |||||
| } | |||||
| output_mean->shape_size_ = size; | |||||
| output_var->shape_size_ = size; | |||||
| } | |||||
| return NNACL_OK; | return NNACL_OK; | ||||
| } | } | ||||
| @@ -15,6 +15,7 @@ | |||||
| */ | */ | ||||
| #include "nnacl/infer/maximum_grad_infer.h" | #include "nnacl/infer/maximum_grad_infer.h" | ||||
| #include "nnacl/arithmetic.h" | |||||
| int MaximumGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, | int MaximumGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, | ||||
| OpParameter *parameter) { | OpParameter *parameter) { | ||||
| @@ -35,19 +36,20 @@ int MaximumGradInferShape(const TensorC *const *inputs, size_t inputs_size, Tens | |||||
| return NNACL_INFER_INVALID; | return NNACL_INFER_INVALID; | ||||
| } | } | ||||
| MaximumGradParameter *param = (MaximumGradParameter *)parameter; | |||||
| ArithmeticParameter *param = (ArithmeticParameter *)parameter; | |||||
| param->ndim_ = dy->shape_size_; | param->ndim_ = dy->shape_size_; | ||||
| param->x1_shape_size_ = param->ndim_; | |||||
| param->x2_shape_size_ = param->ndim_; | |||||
| param->dy_shape_size_ = param->ndim_; | |||||
| param->in_elements_num0_ = param->ndim_; | |||||
| param->in_elements_num1_ = param->ndim_; | |||||
| param->out_elements_num_ = param->ndim_; | |||||
| int fillDimNum0 = dy->shape_size_ - x1->shape_size_; | int fillDimNum0 = dy->shape_size_ - x1->shape_size_; | ||||
| int fillDimNum1 = dy->shape_size_ - x2->shape_size_; | int fillDimNum1 = dy->shape_size_ - x2->shape_size_; | ||||
| int j0 = 0; | int j0 = 0; | ||||
| int j1 = 0; | int j1 = 0; | ||||
| for (unsigned int i = 0; i < dy->shape_size_; i++) { | for (unsigned int i = 0; i < dy->shape_size_; i++) { | ||||
| param->x1_shape_[i] = (i < fillDimNum0) ? 1 : x1->shape_[j0++]; | |||||
| param->x2_shape_[i] = (i < fillDimNum1) ? 1 : x2->shape_[j1++]; | |||||
| param->dy_shape_[i] = dy->shape_[i]; | |||||
| param->in_shape0_[i] = (i < fillDimNum0) ? 1 : x1->shape_[j0++]; | |||||
| param->in_shape1_[i] = (i < fillDimNum1) ? 1 : x2->shape_[j1++]; | |||||
| param->out_shape_[i] = dy->shape_[i]; | |||||
| } | } | ||||
| SetShapeTensor(dx1, x1); | SetShapeTensor(dx1, x1); | ||||
| @@ -13,8 +13,8 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef MINDSPORE_LITE_NNACL_MAXIMUM_GRAD_INFER_H | |||||
| #define MINDSPORE_LITE_NNACL_MAXIMUM_GRAD_INFER_H | |||||
| #ifndef MINDSPORE_LITE_NNACL_INFER_MAXIMUM_GRAD_INFER_H_ | |||||
| #define MINDSPORE_LITE_NNACL_INFER_MAXIMUM_GRAD_INFER_H_ | |||||
| #include "nnacl/infer/common_infer.h" | #include "nnacl/infer/common_infer.h" | ||||
| @@ -22,21 +22,10 @@ | |||||
| extern "C" { | extern "C" { | ||||
| #endif | #endif | ||||
| typedef struct MaximumGradParameter { | |||||
| OpParameter op_parameter_; | |||||
| int ndim_; | |||||
| int x1_shape_[MAX_SHAPE_SIZE]; | |||||
| size_t x1_shape_size_; | |||||
| int x2_shape_[MAX_SHAPE_SIZE]; | |||||
| size_t x2_shape_size_; | |||||
| int dy_shape_[MAX_SHAPE_SIZE]; | |||||
| size_t dy_shape_size_; | |||||
| } MaximumGradParameter; | |||||
| int MaximumGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, | int MaximumGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, | ||||
| OpParameter *parameter); | OpParameter *parameter); | ||||
| #ifdef __cplusplus | #ifdef __cplusplus | ||||
| } | } | ||||
| #endif | #endif | ||||
| #endif // MINDSPORE_LITE_NNACL_MAXIMUM_GRAD_INFER_H | |||||
| #endif // MINDSPORE_LITE_NNACL_INFER_MAXIMUM_GRAD_INFER_H_ | |||||
| @@ -201,6 +201,10 @@ union PrimitiveType { | |||||
| LinSpace, | LinSpace, | ||||
| UniformReal, | UniformReal, | ||||
| AbsGrad, | AbsGrad, | ||||
| RsqrtGrad, | |||||
| SqrtGrad, | |||||
| LayerNormGrad, | |||||
| ResizeGrad, | |||||
| } | } | ||||
| table Abs { | table Abs { | ||||
| @@ -1066,3 +1070,20 @@ table UniformReal { | |||||
| table AbsGrad { | table AbsGrad { | ||||
| } | } | ||||
| table RsqrtGrad { | |||||
| } | |||||
| table SqrtGrad { | |||||
| } | |||||
| table LayerNormGrad { | |||||
| begin_norm_axis: long; | |||||
| begin_params_axis: long; | |||||
| } | |||||
| table ResizeGrad { | |||||
| method: ResizeMethod; | |||||
| align_corners: bool; | |||||
| } | |||||
| @@ -200,6 +200,10 @@ OP_TYPE(IsFinite) | |||||
| OP_TYPE(LinSpace) | OP_TYPE(LinSpace) | ||||
| OP_TYPE(UniformReal) | OP_TYPE(UniformReal) | ||||
| OP_TYPE(AbsGrad) | OP_TYPE(AbsGrad) | ||||
| OP_TYPE(RsqrtGrad) | |||||
| OP_TYPE(SqrtGrad) | |||||
| OP_TYPE(LayerNormGrad) | |||||
| OP_TYPE(ResizeGrad) | |||||
| OP_TYPE_DEF_END(PrimitiveType) | OP_TYPE_DEF_END(PrimitiveType) | ||||
| OP_SCHEMA_DEF(Abs) | OP_SCHEMA_DEF(Abs) | ||||
| @@ -1065,3 +1069,19 @@ OP_SCHEMA_DEF_END(UniformReal) | |||||
| OP_SCHEMA_DEF(AbsGrad) | OP_SCHEMA_DEF(AbsGrad) | ||||
| OP_SCHEMA_DEF_END(AbsGrad) | OP_SCHEMA_DEF_END(AbsGrad) | ||||
| OP_SCHEMA_DEF(RsqrtGrad) | |||||
| OP_SCHEMA_DEF_END(RsqrtGrad) | |||||
| OP_SCHEMA_DEF(SqrtGrad) | |||||
| OP_SCHEMA_DEF_END(SqrtGrad) | |||||
| OP_SCHEMA_DEF(LayerNormGrad) | |||||
| OP_ATTR(begin_norm_axis, long) | |||||
| OP_ATTR(begin_params_axis, long) | |||||
| OP_SCHEMA_DEF_END(LayerNormGrad) | |||||
| OP_SCHEMA_DEF(ResizeGrad) | |||||
| OP_ATTR_ENUM(method, ResizeMethod) | |||||
| OP_ATTR(align_corners, bool) | |||||
| OP_SCHEMA_DEF_END(ResizeGrad) | |||||
| @@ -188,6 +188,7 @@ | |||||
| #include "ops/grad/dropout_grad.h" | #include "ops/grad/dropout_grad.h" | ||||
| #include "ops/grad/flatten_grad.h" | #include "ops/grad/flatten_grad.h" | ||||
| #include "ops/grad/group_conv2d_grad_input.h" | #include "ops/grad/group_conv2d_grad_input.h" | ||||
| #include "ops/grad/layer_norm_grad.h" | |||||
| #include "ops/grad/log_grad.h" | #include "ops/grad/log_grad.h" | ||||
| #include "ops/grad/max_pool_grad.h" | #include "ops/grad/max_pool_grad.h" | ||||
| #include "ops/grad/maximum_grad.h" | #include "ops/grad/maximum_grad.h" | ||||
| @@ -196,8 +197,11 @@ | |||||
| #include "ops/grad/neg_grad.h" | #include "ops/grad/neg_grad.h" | ||||
| #include "ops/grad/pooling_grad.h" | #include "ops/grad/pooling_grad.h" | ||||
| #include "ops/grad/power_grad.h" | #include "ops/grad/power_grad.h" | ||||
| #include "ops/grad/resize_grad.h" | |||||
| #include "ops/grad/rsqrt_grad.h" | |||||
| #include "ops/grad/sigmoid_cross_entropy_with_logits_grad.h" | #include "ops/grad/sigmoid_cross_entropy_with_logits_grad.h" | ||||
| #include "ops/grad/smooth_l1_loss_grad.h" | #include "ops/grad/smooth_l1_loss_grad.h" | ||||
| #include "ops/grad/sqrt_grad.h" | |||||
| #include "ops/grad/sub_grad.h" | #include "ops/grad/sub_grad.h" | ||||
| #include "ops/fusion/activation.h" | #include "ops/fusion/activation.h" | ||||
| #include "ops/fusion/add_fusion.h" | #include "ops/fusion/add_fusion.h" | ||||
| @@ -449,5 +453,9 @@ FUNC_MSOP2SCHEMAOP_DECLARE(IsFinite); | |||||
| FUNC_MSOP2SCHEMAOP_DECLARE(LinSpace); | FUNC_MSOP2SCHEMAOP_DECLARE(LinSpace); | ||||
| FUNC_MSOP2SCHEMAOP_DECLARE(UniformReal); | FUNC_MSOP2SCHEMAOP_DECLARE(UniformReal); | ||||
| FUNC_MSOP2SCHEMAOP_DECLARE(AbsGrad); | FUNC_MSOP2SCHEMAOP_DECLARE(AbsGrad); | ||||
| FUNC_MSOP2SCHEMAOP_DECLARE(RsqrtGrad); | |||||
| FUNC_MSOP2SCHEMAOP_DECLARE(SqrtGrad); | |||||
| FUNC_MSOP2SCHEMAOP_DECLARE(LayerNormGrad); | |||||
| FUNC_MSOP2SCHEMAOP_DECLARE(ResizeGrad); | |||||
| #endif | #endif | ||||
| #endif // MINDSPORE_LITE_SRC_OPS_OPS_FUNC_DECLARE_H_ | #endif // MINDSPORE_LITE_SRC_OPS_OPS_FUNC_DECLARE_H_ | ||||
| @@ -48,6 +48,10 @@ schema::PrimitiveT *AbsPrimitiveCreator(const AnfNodePtr &node) { | |||||
| auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::Abs>>(node); | auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::Abs>>(node); | ||||
| return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; | return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; | ||||
| } | } | ||||
| schema::PrimitiveT *AbsGradPrimitiveCreator(const AnfNodePtr &node) { | |||||
| auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::AbsGrad>>(node); | |||||
| return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; | |||||
| } | |||||
| schema::PrimitiveT *ActivationPrimitiveCreator(const AnfNodePtr &node) { | schema::PrimitiveT *ActivationPrimitiveCreator(const AnfNodePtr &node) { | ||||
| auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::Activation>>(node); | auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::Activation>>(node); | ||||
| return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; | return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; | ||||
| @@ -336,6 +340,10 @@ schema::PrimitiveT *LayerNormFusionPrimitiveCreator(const AnfNodePtr &node) { | |||||
| auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::LayerNormFusion>>(node); | auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::LayerNormFusion>>(node); | ||||
| return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; | return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; | ||||
| } | } | ||||
| schema::PrimitiveT *LayerNormGradPrimitiveCreator(const AnfNodePtr &node) { | |||||
| auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::LayerNormGrad>>(node); | |||||
| return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; | |||||
| } | |||||
| schema::PrimitiveT *LeakyReluPrimitiveCreator(const AnfNodePtr &node) { | schema::PrimitiveT *LeakyReluPrimitiveCreator(const AnfNodePtr &node) { | ||||
| auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::LeakyRelu>>(node); | auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::LeakyRelu>>(node); | ||||
| return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; | return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; | ||||
| @@ -516,6 +524,10 @@ schema::PrimitiveT *ResizePrimitiveCreator(const AnfNodePtr &node) { | |||||
| auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::Resize>>(node); | auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::Resize>>(node); | ||||
| return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; | return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; | ||||
| } | } | ||||
| schema::PrimitiveT *ResizeGradPrimitiveCreator(const AnfNodePtr &node) { | |||||
| auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::ResizeGrad>>(node); | |||||
| return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; | |||||
| } | |||||
| schema::PrimitiveT *ReverseV2PrimitiveCreator(const AnfNodePtr &node) { | schema::PrimitiveT *ReverseV2PrimitiveCreator(const AnfNodePtr &node) { | ||||
| auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::ReverseV2>>(node); | auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::ReverseV2>>(node); | ||||
| return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; | return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; | ||||
| @@ -540,6 +552,10 @@ schema::PrimitiveT *RsqrtPrimitiveCreator(const AnfNodePtr &node) { | |||||
| auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::Rsqrt>>(node); | auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::Rsqrt>>(node); | ||||
| return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; | return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; | ||||
| } | } | ||||
| schema::PrimitiveT *RsqrtGradPrimitiveCreator(const AnfNodePtr &node) { | |||||
| auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::RsqrtGrad>>(node); | |||||
| return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; | |||||
| } | |||||
| schema::PrimitiveT *ScaleFusionPrimitiveCreator(const AnfNodePtr &node) { | schema::PrimitiveT *ScaleFusionPrimitiveCreator(const AnfNodePtr &node) { | ||||
| auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::ScaleFusion>>(node); | auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::ScaleFusion>>(node); | ||||
| return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; | return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; | ||||
| @@ -628,6 +644,10 @@ schema::PrimitiveT *SqrtPrimitiveCreator(const AnfNodePtr &node) { | |||||
| auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::Sqrt>>(node); | auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::Sqrt>>(node); | ||||
| return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; | return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; | ||||
| } | } | ||||
| schema::PrimitiveT *SqrtGradPrimitiveCreator(const AnfNodePtr &node) { | |||||
| auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::SqrtGrad>>(node); | |||||
| return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; | |||||
| } | |||||
| schema::PrimitiveT *SquarePrimitiveCreator(const AnfNodePtr &node) { | schema::PrimitiveT *SquarePrimitiveCreator(const AnfNodePtr &node) { | ||||
| auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::Square>>(node); | auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::Square>>(node); | ||||
| return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; | return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; | ||||
| @@ -648,6 +668,12 @@ schema::PrimitiveT *StridedSlicePrimitiveCreator(const AnfNodePtr &node) { | |||||
| auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::StridedSlice>>(node); | auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::StridedSlice>>(node); | ||||
| return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; | return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; | ||||
| } | } | ||||
| schema::PrimitiveT *StridedSliceGradPrimitiveCreator(const AnfNodePtr &node) { | |||||
| auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::StridedSliceGrad>>(node); | |||||
| return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; | |||||
| } | |||||
| schema::PrimitiveT *SubFusionPrimitiveCreator(const AnfNodePtr &node) { | schema::PrimitiveT *SubFusionPrimitiveCreator(const AnfNodePtr &node) { | ||||
| auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::SubFusion>>(node); | auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::SubFusion>>(node); | ||||
| return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; | return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; | ||||
| @@ -718,6 +744,7 @@ schema::PrimitiveT *ZerosLikePrimitiveCreator(const AnfNodePtr &node) { | |||||
| } | } | ||||
| RegistryMSOps g_absPrimitiveCreatorRegistry("Abs", AbsPrimitiveCreator); | RegistryMSOps g_absPrimitiveCreatorRegistry("Abs", AbsPrimitiveCreator); | ||||
| RegistryMSOps g_absGradPrimitiveCreatorRegistry("AbsGrad", AbsGradPrimitiveCreator); | |||||
| RegistryMSOps g_activationPrimitiveCreatorRegistry("Activation", ActivationPrimitiveCreator); | RegistryMSOps g_activationPrimitiveCreatorRegistry("Activation", ActivationPrimitiveCreator); | ||||
| RegistryMSOps g_activationGradPrimitiveCreatorRegistry("ActivationGrad", ActivationGradPrimitiveCreator); | RegistryMSOps g_activationGradPrimitiveCreatorRegistry("ActivationGrad", ActivationGradPrimitiveCreator); | ||||
| RegistryMSOps g_reluGradPrimitiveCreatorRegistry("ReluGrad", ActivationGradPrimitiveCreator); // ? | RegistryMSOps g_reluGradPrimitiveCreatorRegistry("ReluGrad", ActivationGradPrimitiveCreator); // ? | ||||
| @@ -741,6 +768,8 @@ RegistryMSOps g_audioSpectrogramPrimitiveCreatorRegistry("AudioSpectrogram", Aud | |||||
| RegistryMSOps g_avgPoolPrimitiveCreatorRegistry("AvgPool", AvgPoolFusionPrimitiveCreator); | RegistryMSOps g_avgPoolPrimitiveCreatorRegistry("AvgPool", AvgPoolFusionPrimitiveCreator); | ||||
| RegistryMSOps g_avgPoolFusionPrimitiveCreatorRegistry("AvgPoolFusion", AvgPoolFusionPrimitiveCreator); | RegistryMSOps g_avgPoolFusionPrimitiveCreatorRegistry("AvgPoolFusion", AvgPoolFusionPrimitiveCreator); | ||||
| RegistryMSOps g_avgPoolGradPrimitiveCreatorRegistry("AvgPoolGrad", AvgPoolGradPrimitiveCreator); | RegistryMSOps g_avgPoolGradPrimitiveCreatorRegistry("AvgPoolGrad", AvgPoolGradPrimitiveCreator); | ||||
| RegistryMSOps g_avgPoolGradGpuPrimitiveCreatorRegistry("AvgPoolGradGpu", AvgPoolGradPrimitiveCreator); | |||||
| RegistryMSOps g_avgPoolGradCpuPrimitiveCreatorRegistry("AvgPoolGradCpu", AvgPoolGradPrimitiveCreator); | |||||
| RegistryMSOps g_batchNormPrimitiveCreatorRegistry("BatchNorm", BatchNormPrimitiveCreator); | RegistryMSOps g_batchNormPrimitiveCreatorRegistry("BatchNorm", BatchNormPrimitiveCreator); | ||||
| RegistryMSOps g_batchToSpacePrimitiveCreatorRegistry("BatchToSpace", BatchToSpacePrimitiveCreator); | RegistryMSOps g_batchToSpacePrimitiveCreatorRegistry("BatchToSpace", BatchToSpacePrimitiveCreator); | ||||
| RegistryMSOps g_batchToSpaceNDPrimitiveCreatorRegistry("BatchToSpaceND", BatchToSpaceNDPrimitiveCreator); | RegistryMSOps g_batchToSpaceNDPrimitiveCreatorRegistry("BatchToSpaceND", BatchToSpaceNDPrimitiveCreator); | ||||
| @@ -782,6 +811,7 @@ RegistryMSOps g_dropoutPrimitiveCreatorRegistry("Dropout", DropoutPrimitiveCreat | |||||
| RegistryMSOps g_dropoutGradPrimitiveCreatorRegistry("DropoutGrad", DropoutGradPrimitiveCreator); | RegistryMSOps g_dropoutGradPrimitiveCreatorRegistry("DropoutGrad", DropoutGradPrimitiveCreator); | ||||
| RegistryMSOps g_eltwisePrimitiveCreatorRegistry("Eltwise", EltwisePrimitiveCreator); | RegistryMSOps g_eltwisePrimitiveCreatorRegistry("Eltwise", EltwisePrimitiveCreator); | ||||
| RegistryMSOps g_eluPrimitiveCreatorRegistry("Elu", EluPrimitiveCreator); | RegistryMSOps g_eluPrimitiveCreatorRegistry("Elu", EluPrimitiveCreator); | ||||
| RegistryMSOps g_eluGradPrimitiveCreatorRegistry("EluGrad", ActivationGradPrimitiveCreator); | |||||
| RegistryMSOps g_equalPrimitiveCreatorRegistry("Equal", EqualPrimitiveCreator); | RegistryMSOps g_equalPrimitiveCreatorRegistry("Equal", EqualPrimitiveCreator); | ||||
| RegistryMSOps g_embeddingLookupFusionPrimitiveCreatorRegistry("EmbeddingLookupFusion", | RegistryMSOps g_embeddingLookupFusionPrimitiveCreatorRegistry("EmbeddingLookupFusion", | ||||
| EmbeddingLookupFusionPrimitiveCreator); | EmbeddingLookupFusionPrimitiveCreator); | ||||
| @@ -800,6 +830,7 @@ RegistryMSOps g_fullConnectionPrimitiveCreatorRegistry("FullConnection", FullCon | |||||
| RegistryMSOps g_fusedBatchNormPrimitiveCreatorRegistry("FusedBatchNorm", FusedBatchNormPrimitiveCreator); | RegistryMSOps g_fusedBatchNormPrimitiveCreatorRegistry("FusedBatchNorm", FusedBatchNormPrimitiveCreator); | ||||
| RegistryMSOps g_gatherPrimitiveCreatorRegistry("Gather", GatherPrimitiveCreator); | RegistryMSOps g_gatherPrimitiveCreatorRegistry("Gather", GatherPrimitiveCreator); | ||||
| RegistryMSOps g_gatherNdPrimitiveCreatorRegistry("GatherNd", GatherNdPrimitiveCreator); | RegistryMSOps g_gatherNdPrimitiveCreatorRegistry("GatherNd", GatherNdPrimitiveCreator); | ||||
| RegistryMSOps g_geluGradPrimitiveCreatorRegistry("GeluGrad", ActivationGradPrimitiveCreator); | |||||
| RegistryMSOps g_greaterPrimitiveCreatorRegistry("Greater", GreaterPrimitiveCreator); | RegistryMSOps g_greaterPrimitiveCreatorRegistry("Greater", GreaterPrimitiveCreator); | ||||
| RegistryMSOps g_greaterEqualPrimitiveCreatorRegistry("GreaterEqual", GreaterEqualPrimitiveCreator); | RegistryMSOps g_greaterEqualPrimitiveCreatorRegistry("GreaterEqual", GreaterEqualPrimitiveCreator); | ||||
| RegistryMSOps g_gRUPrimitiveCreatorRegistry("GRU", GRUPrimitiveCreator); | RegistryMSOps g_gRUPrimitiveCreatorRegistry("GRU", GRUPrimitiveCreator); | ||||
| @@ -808,6 +839,7 @@ RegistryMSOps g_instanceNormPrimitiveCreatorRegistry("InstanceNorm", InstanceNor | |||||
| RegistryMSOps g_invertPermutationPrimitiveCreatorRegistry("InvertPermutation", InvertPermutationPrimitiveCreator); | RegistryMSOps g_invertPermutationPrimitiveCreatorRegistry("InvertPermutation", InvertPermutationPrimitiveCreator); | ||||
| RegistryMSOps g_layerNormPrimitiveCreatorRegistry("LayerNorm", LayerNormFusionPrimitiveCreator); | RegistryMSOps g_layerNormPrimitiveCreatorRegistry("LayerNorm", LayerNormFusionPrimitiveCreator); | ||||
| RegistryMSOps g_layerNormFusionPrimitiveCreatorRegistry("LayerNormFusion", LayerNormFusionPrimitiveCreator); | RegistryMSOps g_layerNormFusionPrimitiveCreatorRegistry("LayerNormFusion", LayerNormFusionPrimitiveCreator); | ||||
| RegistryMSOps g_layerNormGradPrimitiveCreatorRegistry("LayerNormGrad", LayerNormGradPrimitiveCreator); | |||||
| RegistryMSOps g_leakyReluPrimitiveCreatorRegistry("LeakyRelu", LeakyReluPrimitiveCreator); | RegistryMSOps g_leakyReluPrimitiveCreatorRegistry("LeakyRelu", LeakyReluPrimitiveCreator); | ||||
| RegistryMSOps g_lessPrimitiveCreatorRegistry("Less", LessPrimitiveCreator); | RegistryMSOps g_lessPrimitiveCreatorRegistry("Less", LessPrimitiveCreator); | ||||
| RegistryMSOps g_lessEqualPrimitiveCreatorRegistry("LessEqual", LessEqualPrimitiveCreator); | RegistryMSOps g_lessEqualPrimitiveCreatorRegistry("LessEqual", LessEqualPrimitiveCreator); | ||||
| @@ -857,12 +889,14 @@ RegistryMSOps g_reducePrimitiveCreatorRegistry("Reduce", ReduceFusionPrimitiveCr | |||||
| RegistryMSOps g_reduceFusionPrimitiveCreatorRegistry("ReduceFusion", ReduceFusionPrimitiveCreator); | RegistryMSOps g_reduceFusionPrimitiveCreatorRegistry("ReduceFusion", ReduceFusionPrimitiveCreator); | ||||
| RegistryMSOps g_reshapePrimitiveCreatorRegistry("Reshape", ReshapePrimitiveCreator); | RegistryMSOps g_reshapePrimitiveCreatorRegistry("Reshape", ReshapePrimitiveCreator); | ||||
| RegistryMSOps g_resizePrimitiveCreatorRegistry("Resize", ResizePrimitiveCreator); | RegistryMSOps g_resizePrimitiveCreatorRegistry("Resize", ResizePrimitiveCreator); | ||||
| RegistryMSOps g_resizeGradPrimitiveCreatorRegistry("ResizeGrad", ResizeGradPrimitiveCreator); | |||||
| RegistryMSOps g_reverseV2PrimitiveCreatorRegistry("ReverseV2", ReverseV2PrimitiveCreator); | RegistryMSOps g_reverseV2PrimitiveCreatorRegistry("ReverseV2", ReverseV2PrimitiveCreator); | ||||
| RegistryMSOps g_reverseSequencePrimitiveCreatorRegistry("ReverseSequence", ReverseSequencePrimitiveCreator); | RegistryMSOps g_reverseSequencePrimitiveCreatorRegistry("ReverseSequence", ReverseSequencePrimitiveCreator); | ||||
| RegistryMSOps g_rfftPrimitiveCreatorRegistry("Rfft", RfftPrimitiveCreator); | RegistryMSOps g_rfftPrimitiveCreatorRegistry("Rfft", RfftPrimitiveCreator); | ||||
| RegistryMSOps g_rOIPoolingPrimitiveCreatorRegistry("ROIPooling", ROIPoolingPrimitiveCreator); | RegistryMSOps g_rOIPoolingPrimitiveCreatorRegistry("ROIPooling", ROIPoolingPrimitiveCreator); | ||||
| RegistryMSOps g_roundPrimitiveCreatorRegistry("Round", RoundPrimitiveCreator); | RegistryMSOps g_roundPrimitiveCreatorRegistry("Round", RoundPrimitiveCreator); | ||||
| RegistryMSOps g_rsqrtPrimitiveCreatorRegistry("Rsqrt", RsqrtPrimitiveCreator); | RegistryMSOps g_rsqrtPrimitiveCreatorRegistry("Rsqrt", RsqrtPrimitiveCreator); | ||||
| RegistryMSOps g_rsqrtGradPrimitiveCreatorRegistry("RsqrtGrad", RsqrtGradPrimitiveCreator); | |||||
| RegistryMSOps g_quantDTypeCastPrimitiveCreatorRegistry("QuantDTypeCast", QuantDTypeCastPrimitiveCreator); | RegistryMSOps g_quantDTypeCastPrimitiveCreatorRegistry("QuantDTypeCast", QuantDTypeCastPrimitiveCreator); | ||||
| RegistryMSOps g_scalePrimitiveCreatorRegistry("Scale", ScaleFusionPrimitiveCreator); | RegistryMSOps g_scalePrimitiveCreatorRegistry("Scale", ScaleFusionPrimitiveCreator); | ||||
| RegistryMSOps g_scaleFusionPrimitiveCreatorRegistry("ScaleFusion", ScaleFusionPrimitiveCreator); | RegistryMSOps g_scaleFusionPrimitiveCreatorRegistry("ScaleFusion", ScaleFusionPrimitiveCreator); | ||||
| @@ -891,11 +925,13 @@ RegistryMSOps g_sparseSoftmaxCrossEntropyWithLogitsPrimitiveCreatorRegistry( | |||||
| RegistryMSOps g_sparseToDensePrimitiveCreatorRegistry("SparseToDense", SparseToDensePrimitiveCreator); | RegistryMSOps g_sparseToDensePrimitiveCreatorRegistry("SparseToDense", SparseToDensePrimitiveCreator); | ||||
| RegistryMSOps g_splitPrimitiveCreatorRegistry("Split", SplitPrimitiveCreator); | RegistryMSOps g_splitPrimitiveCreatorRegistry("Split", SplitPrimitiveCreator); | ||||
| RegistryMSOps g_sqrtPrimitiveCreatorRegistry("Sqrt", SqrtPrimitiveCreator); | RegistryMSOps g_sqrtPrimitiveCreatorRegistry("Sqrt", SqrtPrimitiveCreator); | ||||
| RegistryMSOps g_sqrtGradPrimitiveCreatorRegistry("SqrtGrad", SqrtGradPrimitiveCreator); | |||||
| RegistryMSOps g_squeezePrimitiveCreatorRegistry("Squeeze", SqueezePrimitiveCreator); | RegistryMSOps g_squeezePrimitiveCreatorRegistry("Squeeze", SqueezePrimitiveCreator); | ||||
| RegistryMSOps g_squarePrimitiveCreatorRegistry("Square", SquarePrimitiveCreator); | RegistryMSOps g_squarePrimitiveCreatorRegistry("Square", SquarePrimitiveCreator); | ||||
| RegistryMSOps g_squaredDifferencePrimitiveCreatorRegistry("SquaredDifference", SquaredDifferencePrimitiveCreator); | RegistryMSOps g_squaredDifferencePrimitiveCreatorRegistry("SquaredDifference", SquaredDifferencePrimitiveCreator); | ||||
| RegistryMSOps g_stackPrimitiveCreatorRegistry("Stack", StackPrimitiveCreator); | RegistryMSOps g_stackPrimitiveCreatorRegistry("Stack", StackPrimitiveCreator); | ||||
| RegistryMSOps g_stridedSlicePrimitiveCreatorRegistry("StridedSlice", StridedSlicePrimitiveCreator); | RegistryMSOps g_stridedSlicePrimitiveCreatorRegistry("StridedSlice", StridedSlicePrimitiveCreator); | ||||
| RegistryMSOps g_stridedSliceGradPrimitiveCreatorRegistry("StridedSliceGrad", StridedSliceGradPrimitiveCreator); | |||||
| RegistryMSOps g_subPrimitiveCreatorRegistry("Sub", SubFusionPrimitiveCreator); | RegistryMSOps g_subPrimitiveCreatorRegistry("Sub", SubFusionPrimitiveCreator); | ||||
| RegistryMSOps g_subFusionPrimitiveCreatorRegistry("SubFusion", SubFusionPrimitiveCreator); | RegistryMSOps g_subFusionPrimitiveCreatorRegistry("SubFusion", SubFusionPrimitiveCreator); | ||||
| RegistryMSOps g_subGradPrimitiveCreatorRegistry("SubGrad", SubGradPrimitiveCreator); | RegistryMSOps g_subGradPrimitiveCreatorRegistry("SubGrad", SubGradPrimitiveCreator); | ||||
| @@ -0,0 +1,40 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "nnacl/fp32_grad/layernormgrad_parameter.h" | |||||
| #include "src/ops/populate/populate_register.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| // Builds the nnacl LayerNormGradParameter for a LayerNormGrad schema primitive. | |||||
| // | |||||
| // prim: type-erased pointer that is interpreted as a const schema::Primitive. | |||||
| // Returns a malloc-allocated, zero-initialised LayerNormGradParameter viewed as an | |||||
| // OpParameter, or nullptr when prim is null, the primitive does not carry a | |||||
| // LayerNormGrad value, or the allocation fails. | |||||
| // NOTE(review): the returned block is presumably released with free() by the caller - confirm. | |||||
| OpParameter *PopulateLayerNormGradParameter(const void *prim) { | |||||
|   if (prim == nullptr) { | |||||
|     MS_LOG(ERROR) << "primitive is nullptr."; | |||||
|     return nullptr; | |||||
|   } | |||||
|   auto *primitive = static_cast<const schema::Primitive *>(prim); | |||||
|   // Validate before allocating so that the early returns below cannot leak memory. | |||||
|   // FlatBuffers union accessors yield nullptr when the stored value has another type. | |||||
|   auto param = primitive->value_as_LayerNormGrad(); | |||||
|   if (param == nullptr) { | |||||
|     MS_LOG(ERROR) << "primitive value is not LayerNormGrad."; | |||||
|     return nullptr; | |||||
|   } | |||||
|   auto layer_norm_grad_parameter = reinterpret_cast<LayerNormGradParameter *>(malloc(sizeof(LayerNormGradParameter))); | |||||
|   if (layer_norm_grad_parameter == nullptr) { | |||||
|     MS_LOG(ERROR) << "malloc LayerNormGradParameter failed."; | |||||
|     return nullptr; | |||||
|   } | |||||
|   // Zero every field first; only the fields assigned below are populated here. | |||||
|   memset(layer_norm_grad_parameter, 0, sizeof(LayerNormGradParameter)); | |||||
|   layer_norm_grad_parameter->op_parameter_.type_ = primitive->value_type(); | |||||
|   layer_norm_grad_parameter->begin_norm_axis_ = param->begin_norm_axis(); | |||||
|   layer_norm_grad_parameter->begin_params_axis_ = param->begin_params_axis(); | |||||
|   return reinterpret_cast<OpParameter *>(layer_norm_grad_parameter); | |||||
| } | |||||
| // File-scope registration object tying schema::PrimitiveType_LayerNormGrad to | |||||
| // PopulateLayerNormGradParameter for SCHEMA_CUR. | |||||
| // NOTE(review): presumably the Registry constructor records the callback in the | |||||
| // populate table declared in populate_register.h - confirm there. | |||||
| Registry g_layerNormGradParameterRegistry(schema::PrimitiveType_LayerNormGrad, PopulateLayerNormGradParameter, | |||||
| SCHEMA_CUR); | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -63,6 +63,7 @@ | |||||
| #include "nnacl/infer/group_conv2d_grad_input_infer.h" | #include "nnacl/infer/group_conv2d_grad_input_infer.h" | ||||
| #include "nnacl/infer/hashtable_lookup_infer.h" | #include "nnacl/infer/hashtable_lookup_infer.h" | ||||
| #include "nnacl/infer/layer_norm_infer.h" | #include "nnacl/infer/layer_norm_infer.h" | ||||
| #include "nnacl/infer/layer_norm_grad_infer.h" | |||||
| #include "nnacl/infer/lsh_projection_infer.h" | #include "nnacl/infer/lsh_projection_infer.h" | ||||
| #include "nnacl/infer/lstm_infer.h" | #include "nnacl/infer/lstm_infer.h" | ||||
| #include "nnacl/infer/matmul_infer.h" | #include "nnacl/infer/matmul_infer.h" | ||||
| @@ -214,9 +215,9 @@ static RegistryInferShape g_Deconv2dInferShape(mindspore::schema::PrimitiveType_ | |||||
| static RegistryInferShape g_SquaredDifferenceInferShape(mindspore::schema::PrimitiveType_SquaredDifference, | static RegistryInferShape g_SquaredDifferenceInferShape(mindspore::schema::PrimitiveType_SquaredDifference, | ||||
| ArithmeticInferShape); | ArithmeticInferShape); | ||||
| static RegistryInferShape g_AddInferShape(mindspore::schema::PrimitiveType_AddFusion, ArithmeticInferShape); | static RegistryInferShape g_AddInferShape(mindspore::schema::PrimitiveType_AddFusion, ArithmeticInferShape); | ||||
| static RegistryInferShape g_AddSubInferShape(mindspore::schema::PrimitiveType_AddGrad, AddSubGradInferShape); | |||||
| static RegistryInferShape g_AddSubInferShape(mindspore::schema::PrimitiveType_AddGrad, MaximumGradInferShape); | |||||
| static RegistryInferShape g_SubInferShape(mindspore::schema::PrimitiveType_SubFusion, ArithmeticInferShape); | static RegistryInferShape g_SubInferShape(mindspore::schema::PrimitiveType_SubFusion, ArithmeticInferShape); | ||||
| static RegistryInferShape g_SubGradInferShape(mindspore::schema::PrimitiveType_SubGrad, AddSubGradInferShape); | |||||
| static RegistryInferShape g_SubGradInferShape(mindspore::schema::PrimitiveType_SubGrad, MaximumGradInferShape); | |||||
| static RegistryInferShape g_DivInferShape(mindspore::schema::PrimitiveType_DivFusion, ArithmeticInferShape); | static RegistryInferShape g_DivInferShape(mindspore::schema::PrimitiveType_DivFusion, ArithmeticInferShape); | ||||
| static RegistryInferShape g_DivGradInferShape(mindspore::schema::PrimitiveType_DivGrad, ArithmeticGradInferShape); | static RegistryInferShape g_DivGradInferShape(mindspore::schema::PrimitiveType_DivGrad, ArithmeticGradInferShape); | ||||
| static RegistryInferShape g_MulInferShape(mindspore::schema::PrimitiveType_MulFusion, ArithmeticInferShape); | static RegistryInferShape g_MulInferShape(mindspore::schema::PrimitiveType_MulFusion, ArithmeticInferShape); | ||||
| @@ -275,6 +276,8 @@ static RegistryInferShape g_QuantDtypeCastInferShape(mindspore::schema::Primitiv | |||||
| static RegistryInferShape g_MfccInferShape(mindspore::schema::PrimitiveType_Mfcc, MfccInferShape); | static RegistryInferShape g_MfccInferShape(mindspore::schema::PrimitiveType_Mfcc, MfccInferShape); | ||||
| static RegistryInferShape g_AssignAddInferShape(mindspore::schema::PrimitiveType_AssignAdd, AssignAddInferShape); | static RegistryInferShape g_AssignAddInferShape(mindspore::schema::PrimitiveType_AssignAdd, AssignAddInferShape); | ||||
| static RegistryInferShape g_LayerNormInferShape(mindspore::schema::PrimitiveType_LayerNormFusion, LayerNormInferShape); | static RegistryInferShape g_LayerNormInferShape(mindspore::schema::PrimitiveType_LayerNormFusion, LayerNormInferShape); | ||||
| static RegistryInferShape g_LayerNormGradInferShape(mindspore::schema::PrimitiveType_LayerNormGrad, | |||||
| LayerNormGradInferShape); | |||||
| static RegistryInferShape g_UnsortedSegmentSumInferShape(mindspore::schema::PrimitiveType_UnsortedSegmentSum, | static RegistryInferShape g_UnsortedSegmentSumInferShape(mindspore::schema::PrimitiveType_UnsortedSegmentSum, | ||||
| UnsortedSegmentSumInferShape); | UnsortedSegmentSumInferShape); | ||||
| static RegistryInferShape g_AddnInferShape(mindspore::schema::PrimitiveType_AddN, AddnInferShape); | static RegistryInferShape g_AddnInferShape(mindspore::schema::PrimitiveType_AddN, AddnInferShape); | ||||
| @@ -316,6 +319,7 @@ static RegistryInferShape g_ReverseSequenceInferShape(mindspore::schema::Primiti | |||||
| CommonInferShape); | CommonInferShape); | ||||
| static RegistryInferShape g_ZerosLikeInferShape(mindspore::schema::PrimitiveType_ZerosLike, CommonInferShape); | static RegistryInferShape g_ZerosLikeInferShape(mindspore::schema::PrimitiveType_ZerosLike, CommonInferShape); | ||||
| static RegistryInferShape g_AbsGradInferShape(mindspore::schema::PrimitiveType_AbsGrad, CommonInferShape); | |||||
| static RegistryInferShape g_AbsInferShape(mindspore::schema::PrimitiveType_Abs, CommonInferShape); | static RegistryInferShape g_AbsInferShape(mindspore::schema::PrimitiveType_Abs, CommonInferShape); | ||||
| static RegistryInferShape g_ActivationGradInferShape(mindspore::schema::PrimitiveType_ActivationGrad, CommonInferShape); | static RegistryInferShape g_ActivationGradInferShape(mindspore::schema::PrimitiveType_ActivationGrad, CommonInferShape); | ||||
| static RegistryInferShape g_ActivationInferShape(mindspore::schema::PrimitiveType_Activation, CommonInferShape); | static RegistryInferShape g_ActivationInferShape(mindspore::schema::PrimitiveType_Activation, CommonInferShape); | ||||
| @@ -345,8 +349,10 @@ static RegistryInferShape g_PowerGradInferShape(mindspore::schema::PrimitiveType | |||||
| static RegistryInferShape g_PReLUInferShape(mindspore::schema::PrimitiveType_PReLUFusion, CommonInferShape); | static RegistryInferShape g_PReLUInferShape(mindspore::schema::PrimitiveType_PReLUFusion, CommonInferShape); | ||||
| static RegistryInferShape g_ReverseInferShape(mindspore::schema::PrimitiveType_ReverseV2, CommonInferShape); | static RegistryInferShape g_ReverseInferShape(mindspore::schema::PrimitiveType_ReverseV2, CommonInferShape); | ||||
| static RegistryInferShape g_RoundInferShape(mindspore::schema::PrimitiveType_Round, CommonInferShape); | static RegistryInferShape g_RoundInferShape(mindspore::schema::PrimitiveType_Round, CommonInferShape); | ||||
| static RegistryInferShape g_RsqrtGradInferShape(mindspore::schema::PrimitiveType_RsqrtGrad, CommonInferShape); | |||||
| static RegistryInferShape g_RsqrtInferShape(mindspore::schema::PrimitiveType_Rsqrt, CommonInferShape); | static RegistryInferShape g_RsqrtInferShape(mindspore::schema::PrimitiveType_Rsqrt, CommonInferShape); | ||||
| static RegistryInferShape g_ScaleInferShape(mindspore::schema::PrimitiveType_ScaleFusion, CommonInferShape); | static RegistryInferShape g_ScaleInferShape(mindspore::schema::PrimitiveType_ScaleFusion, CommonInferShape); | ||||
| static RegistryInferShape g_SqrtGradInferShape(mindspore::schema::PrimitiveType_SqrtGrad, CommonInferShape); | |||||
| static RegistryInferShape g_SqrtInferShape(mindspore::schema::PrimitiveType_Sqrt, CommonInferShape); | static RegistryInferShape g_SqrtInferShape(mindspore::schema::PrimitiveType_Sqrt, CommonInferShape); | ||||
| static RegistryInferShape g_SquareInferShape(mindspore::schema::PrimitiveType_Square, CommonInferShape); | static RegistryInferShape g_SquareInferShape(mindspore::schema::PrimitiveType_Square, CommonInferShape); | ||||
| @@ -426,7 +432,6 @@ static RegistryInferShape g_StridedSliceGradInferShape(mindspore::schema::Primit | |||||
| static RegistryInferShape g_IsFiniteInferShape(mindspore::schema::PrimitiveType_IsFinite, CommonInferShape); | static RegistryInferShape g_IsFiniteInferShape(mindspore::schema::PrimitiveType_IsFinite, CommonInferShape); | ||||
| static RegistryInferShape g_LinSpaceInferShape(mindspore::schema::PrimitiveType_LinSpace, LinSpaceInferShape); | static RegistryInferShape g_LinSpaceInferShape(mindspore::schema::PrimitiveType_LinSpace, LinSpaceInferShape); | ||||
| static RegistryInferShape g_UniformRealInferShape(mindspore::schema::PrimitiveType_UniformReal, UniformRealInferShape); | static RegistryInferShape g_UniformRealInferShape(mindspore::schema::PrimitiveType_UniformReal, UniformRealInferShape); | ||||
| static RegistryInferShape g_AbsGradInferShape(mindspore::schema::PrimitiveType_AbsGrad, CommonInferShape); | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -13,8 +13,8 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ARITHMETIC_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ARITHMETIC_H_ | |||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ARITHMETIC_FP32_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ARITHMETIC_FP32_H_ | |||||
| #include <vector> | #include <vector> | ||||
| #include "src/lite_kernel.h" | #include "src/lite_kernel.h" | ||||
| @@ -122,4 +122,4 @@ class ArithmeticCPUKernel : public LiteKernel { | |||||
| }; | }; | ||||
| int ArithmeticsRun(void *cdata, int task_id); | int ArithmeticsRun(void *cdata, int task_id); | ||||
| } // namespace mindspore::kernel | } // namespace mindspore::kernel | ||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ARITHMETIC_H_ | |||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ARITHMETIC_FP32_H_ | |||||
| @@ -79,7 +79,6 @@ int ConvolutionWinogradCPUKernel::InitWeightBias() { | |||||
| // init bias | // init bias | ||||
| size_t new_bias_size = UP_ROUND(out_channel, C4NUM) * sizeof(float); | size_t new_bias_size = UP_ROUND(out_channel, C4NUM) * sizeof(float); | ||||
| bias_data_ = malloc(new_bias_size); | |||||
| if (bias_data_ == nullptr) { | if (bias_data_ == nullptr) { | ||||
| bias_data_ = reinterpret_cast<float *>(malloc(new_bias_size)); | bias_data_ = reinterpret_cast<float *>(malloc(new_bias_size)); | ||||
| if (bias_data_ == nullptr) { | if (bias_data_ == nullptr) { | ||||
| @@ -61,7 +61,7 @@ int LayerNormCPUKernel::ReSize() { | |||||
| } | } | ||||
| int LayerNormCPUKernel::DoLayerNorm(int thread_id) { | int LayerNormCPUKernel::DoLayerNorm(int thread_id) { | ||||
| int ret = LayerNorm(src_data_, gamma_data_, beta_data_, dst_data_, param_, thread_id); | |||||
| int ret = LayerNorm(src_data_, gamma_data_, beta_data_, dst_data_, param_, mean_data_, var_data_, thread_id); | |||||
| if (ret != RET_OK) { | if (ret != RET_OK) { | ||||
| MS_LOG(ERROR) << "DoLayerNorm error error_code[" << ret << "]"; | MS_LOG(ERROR) << "DoLayerNorm error error_code[" << ret << "]"; | ||||
| return ret; | return ret; | ||||
| @@ -80,17 +80,17 @@ int LayerNormRun(void *cdata, int task_id) { | |||||
| } | } | ||||
| int LayerNormCPUKernel::Run() { | int LayerNormCPUKernel::Run() { | ||||
| int ret = RET_OK; | |||||
| src_data_ = reinterpret_cast<float *>(in_tensors_.at(0)->data_c()); | src_data_ = reinterpret_cast<float *>(in_tensors_.at(0)->data_c()); | ||||
| gamma_data_ = reinterpret_cast<float *>(in_tensors_.at(1)->data_c()); | gamma_data_ = reinterpret_cast<float *>(in_tensors_.at(1)->data_c()); | ||||
| beta_data_ = reinterpret_cast<float *>(in_tensors_.at(2)->data_c()); | beta_data_ = reinterpret_cast<float *>(in_tensors_.at(2)->data_c()); | ||||
| dst_data_ = reinterpret_cast<float *>(out_tensors_.at(0)->data_c()); | dst_data_ = reinterpret_cast<float *>(out_tensors_.at(0)->data_c()); | ||||
| auto ret = ParallelLaunch(this->context_->thread_pool_, LayerNormRun, this, op_parameter_->thread_num_); | |||||
| if (ret != RET_OK) { | |||||
| MS_LOG(ERROR) << "LayerNormRun error error_code[" << ret << "]"; | |||||
| return ret; | |||||
| if (out_tensors_.size() >= 3) { | |||||
| mean_data_ = reinterpret_cast<float *>(out_tensors_.at(1)->data_c()); | |||||
| var_data_ = reinterpret_cast<float *>(out_tensors_.at(2)->data_c()); | |||||
| } | } | ||||
| return RET_OK; | |||||
| ret = ParallelLaunch(this->context_->thread_pool_, LayerNormRun, this, op_parameter_->thread_num_); | |||||
| return ret; | |||||
| } | } | ||||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_LayerNormFusion, LiteKernelCreator<LayerNormCPUKernel>) | REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_LayerNormFusion, LiteKernelCreator<LayerNormCPUKernel>) | ||||
| @@ -13,8 +13,8 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_LAYER_NORM_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_LAYER_NORM_H_ | |||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_LAYER_NORM_FP32_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_LAYER_NORM_FP32_H_ | |||||
| #include <vector> | #include <vector> | ||||
| #include "src/lite_kernel.h" | #include "src/lite_kernel.h" | ||||
| #include "include/context.h" | #include "include/context.h" | ||||
| @@ -43,7 +43,9 @@ class LayerNormCPUKernel : public LiteKernel { | |||||
| float *dst_data_ = nullptr; | float *dst_data_ = nullptr; | ||||
| float *gamma_data_ = nullptr; | float *gamma_data_ = nullptr; | ||||
| float *beta_data_ = nullptr; | float *beta_data_ = nullptr; | ||||
| float *mean_data_ = nullptr; | |||||
| float *var_data_ = nullptr; | |||||
| }; | }; | ||||
| } // namespace mindspore::kernel | } // namespace mindspore::kernel | ||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_LAYER_NORM_H_ | |||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_LAYER_NORM_FP32_H_ | |||||
| @@ -25,6 +25,8 @@ using mindspore::kernel::KERNEL_ARCH::kCPU; | |||||
| using mindspore::lite::KernelRegistrar; | using mindspore::lite::KernelRegistrar; | ||||
| using mindspore::lite::RET_ERROR; | using mindspore::lite::RET_ERROR; | ||||
| using mindspore::lite::RET_OK; | using mindspore::lite::RET_OK; | ||||
| using mindspore::schema::ActivationType_ELU; | |||||
| using mindspore::schema::ActivationType_GELU; | |||||
| using mindspore::schema::ActivationType_HSWISH; | using mindspore::schema::ActivationType_HSWISH; | ||||
| using mindspore::schema::ActivationType_LEAKY_RELU; | using mindspore::schema::ActivationType_LEAKY_RELU; | ||||
| using mindspore::schema::ActivationType_RELU; | using mindspore::schema::ActivationType_RELU; | ||||
| @@ -69,6 +71,10 @@ int ActivationGradCPUKernel::DoActivation(int task_id) { | |||||
| error_code = HSwishGrad(yt_addr + start, input_addr + start, count, output_addr + start); | error_code = HSwishGrad(yt_addr + start, input_addr + start, count, output_addr + start); | ||||
| } else if (param_act_grad_->type_ == schema::ActivationType_HSIGMOID) { | } else if (param_act_grad_->type_ == schema::ActivationType_HSIGMOID) { | ||||
| error_code = HSigmoidGrad(yt_addr + start, input_addr + start, count, output_addr + start); | error_code = HSigmoidGrad(yt_addr + start, input_addr + start, count, output_addr + start); | ||||
| } else if (param_act_grad_->type_ == schema::ActivationType_ELU) { | |||||
| error_code = EluGrad(yt_addr + start, input_addr + start, count, output_addr + start, param_act_grad_->alpha_); | |||||
| } else if (param_act_grad_->type_ == schema::ActivationType_GELU) { | |||||
| error_code = GeluGrad(yt_addr + start, input_addr + start, count, output_addr + start); | |||||
| } else { | } else { | ||||
| MS_LOG(ERROR) << "Activation type error"; | MS_LOG(ERROR) << "Activation type error"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| @@ -99,27 +105,5 @@ int ActivationGradCPUKernel::Run() { | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| kernel::LiteKernel *CpuActivationGradFp32KernelCreator(const std::vector<lite::Tensor *> &inputs, | |||||
| const std::vector<lite::Tensor *> &outputs, | |||||
| OpParameter *opParameter, const lite::InnerContext *ctx, | |||||
| const kernel::KernelKey &desc) { | |||||
| MS_ASSERT(opParameter != nullptr); | |||||
| MS_ASSERT(desc.type == schema::PrimitiveType_ActivationGrad); | |||||
| auto *kernel = new (std::nothrow) ActivationGradCPUKernel(opParameter, inputs, outputs, ctx); | |||||
| if (kernel == nullptr) { | |||||
| MS_LOG(ERROR) << "new ActivationGradCPUKernel fail!"; | |||||
| free(opParameter); | |||||
| return nullptr; | |||||
| } | |||||
| auto ret = kernel->Init(); | |||||
| if (ret != RET_OK) { | |||||
| MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " | |||||
| << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | |||||
| delete kernel; | |||||
| return nullptr; | |||||
| } | |||||
| return kernel; | |||||
| } | |||||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_ActivationGrad, CpuActivationGradFp32KernelCreator) | |||||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_ActivationGrad, LiteKernelCreator<ActivationGradCPUKernel>) | |||||
| } // namespace mindspore::kernel | } // namespace mindspore::kernel | ||||
| @@ -73,13 +73,13 @@ int AdamCPUKernel::Execute(int task_id) { | |||||
| auto beta2 = reinterpret_cast<float *>(in_tensors_.at(7)->MutableData())[0]; | auto beta2 = reinterpret_cast<float *>(in_tensors_.at(7)->MutableData())[0]; | ||||
| auto eps = reinterpret_cast<float *>(in_tensors_.at(8)->MutableData())[0]; | auto eps = reinterpret_cast<float *>(in_tensors_.at(8)->MutableData())[0]; | ||||
| auto gradient = reinterpret_cast<float *>(in_tensors_.at(9)->MutableData()); | auto gradient = reinterpret_cast<float *>(in_tensors_.at(9)->MutableData()); | ||||
| size_t length = in_tensors_.at(0)->ElementsNum(); | |||||
| int length = in_tensors_.at(0)->ElementsNum(); | |||||
| size_t stride = UP_DIV(length, thread_count_); | |||||
| size_t count = MSMIN(stride, length - stride * task_id); | |||||
| int stride = UP_DIV(length, thread_count_); | |||||
| int count = MSMIN(stride, length - stride * task_id); | |||||
| size_t start = stride * task_id; | |||||
| size_t end = start + count; | |||||
| int start = stride * task_id; | |||||
| int end = start + count; | |||||
| return DoAdam(m, v, gradient, weight, beta1, beta2, beta1_power, beta2_power, eps, learning_rate, | return DoAdam(m, v, gradient, weight, beta1, beta2, beta1_power, beta2_power, eps, learning_rate, | ||||
| adam_param_->use_nesterov_, start, end); | adam_param_->use_nesterov_, start, end); | ||||
| @@ -52,12 +52,12 @@ int ApplyMomentumCPUKernel::Execute(int task_id) { | |||||
| float learning_rate = lr_; | float learning_rate = lr_; | ||||
| auto gradient = reinterpret_cast<float *>(in_tensors_.at(3)->MutableData()); | auto gradient = reinterpret_cast<float *>(in_tensors_.at(3)->MutableData()); | ||||
| float moment = reinterpret_cast<float *>(in_tensors_.at(4)->MutableData())[0]; | float moment = reinterpret_cast<float *>(in_tensors_.at(4)->MutableData())[0]; | ||||
| size_t length = in_tensors_.at(0)->ElementsNum(); | |||||
| int length = in_tensors_.at(0)->ElementsNum(); | |||||
| size_t stride = UP_DIV(length, thread_count_); | |||||
| size_t count = MSMIN(stride, length - stride * task_id); | |||||
| size_t start = stride * task_id; | |||||
| size_t end = start + count; | |||||
| int stride = UP_DIV(length, thread_count_); | |||||
| int count = MSMIN(stride, length - stride * task_id); | |||||
| int start = stride * task_id; | |||||
| int end = start + count; | |||||
| DoApplyMomentum(weight, accumulate, learning_rate, gradient, moment, apply_momentum_param_->use_nesterov_, start, | DoApplyMomentum(weight, accumulate, learning_rate, gradient, moment, apply_momentum_param_->use_nesterov_, start, | ||||
| end); | end); | ||||
| @@ -28,6 +28,8 @@ using mindspore::lite::RET_ERROR; | |||||
| using mindspore::lite::RET_OK; | using mindspore::lite::RET_OK; | ||||
| using mindspore::schema::PrimitiveType_AbsGrad; | using mindspore::schema::PrimitiveType_AbsGrad; | ||||
| using mindspore::schema::PrimitiveType_LogGrad; | using mindspore::schema::PrimitiveType_LogGrad; | ||||
| using mindspore::schema::PrimitiveType_RsqrtGrad; | |||||
| using mindspore::schema::PrimitiveType_SqrtGrad; | |||||
| namespace mindspore::kernel { | namespace mindspore::kernel { | ||||
| namespace { | namespace { | ||||
| @@ -47,6 +49,12 @@ int ArithmeticSelfGradCPUKernel::Init() { | |||||
| case PrimitiveType_AbsGrad: | case PrimitiveType_AbsGrad: | ||||
| self_grad_operation_ = ElementAbsGrad; | self_grad_operation_ = ElementAbsGrad; | ||||
| break; | break; | ||||
| case PrimitiveType_SqrtGrad: | |||||
| self_grad_operation_ = ElementSqrtGrad; | |||||
| break; | |||||
| case PrimitiveType_RsqrtGrad: | |||||
| self_grad_operation_ = ElementRsqrtGrad; | |||||
| break; | |||||
| default: | default: | ||||
| MS_LOG(ERROR) << "Unsupported type: " << type; | MS_LOG(ERROR) << "Unsupported type: " << type; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| @@ -58,11 +66,11 @@ int ArithmeticSelfGradCPUKernel::DoArithmeticSelfGrad(int task_id) { | |||||
| auto dy = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData()); | auto dy = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData()); | ||||
| auto in_x = reinterpret_cast<float *>(in_tensors_.at(1)->MutableData()); | auto in_x = reinterpret_cast<float *>(in_tensors_.at(1)->MutableData()); | ||||
| auto dx = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData()); | auto dx = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData()); | ||||
| size_t length = in_tensors_.at(0)->ElementsNum(); | |||||
| int length = in_tensors_.at(0)->ElementsNum(); | |||||
| size_t stride = UP_DIV(length, thread_count_); | |||||
| size_t count = MSMIN(stride, length - stride * task_id); | |||||
| size_t start = stride * task_id; | |||||
| int stride = UP_DIV(length, thread_count_); | |||||
| int count = MSMIN(stride, length - stride * task_id); | |||||
| int start = stride * task_id; | |||||
| (*self_grad_operation_)(dy + start, in_x + start, dx + start, count); | (*self_grad_operation_)(dy + start, in_x + start, dx + start, count); | ||||
| return RET_OK; | return RET_OK; | ||||
| @@ -107,4 +115,6 @@ kernel::LiteKernel *CpuArithmeticSelfGradFp32KernelCreator(const std::vector<lit | |||||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_LogGrad, CpuArithmeticSelfGradFp32KernelCreator) | REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_LogGrad, CpuArithmeticSelfGradFp32KernelCreator) | ||||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_AbsGrad, CpuArithmeticSelfGradFp32KernelCreator) | REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_AbsGrad, CpuArithmeticSelfGradFp32KernelCreator) | ||||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_SqrtGrad, CpuArithmeticSelfGradFp32KernelCreator) | |||||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_RsqrtGrad, CpuArithmeticSelfGradFp32KernelCreator) | |||||
| } // namespace mindspore::kernel | } // namespace mindspore::kernel | ||||
| @@ -34,12 +34,12 @@ int AssignCPUKernel::ReSize() { return RET_OK; } | |||||
| int AssignCPUKernel::Execute(int task_id) { | int AssignCPUKernel::Execute(int task_id) { | ||||
| auto x = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData()); | auto x = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData()); | ||||
| auto y = reinterpret_cast<float *>(in_tensors_.at(1)->MutableData()); | auto y = reinterpret_cast<float *>(in_tensors_.at(1)->MutableData()); | ||||
| size_t length = in_tensors_.at(0)->ElementsNum(); | |||||
| int length = in_tensors_.at(0)->ElementsNum(); | |||||
| size_t stride = UP_DIV(length, thread_count_); | |||||
| size_t count = MSMIN(stride, length - stride * task_id); | |||||
| int stride = UP_DIV(length, thread_count_); | |||||
| int count = MSMIN(stride, length - stride * task_id); | |||||
| size_t start = stride * task_id; | |||||
| int start = stride * task_id; | |||||
| memcpy(&(x[start]), &(y[start]), count * sizeof(float)); | memcpy(&(x[start]), &(y[start]), count * sizeof(float)); | ||||
| return RET_OK; | return RET_OK; | ||||
| @@ -62,14 +62,13 @@ int DropoutGradCPUKernel::Execute(int task_id) { | |||||
| auto mask_ptr = reinterpret_cast<float *>(in_tensors_.at(1)->MutableData()); | auto mask_ptr = reinterpret_cast<float *>(in_tensors_.at(1)->MutableData()); | ||||
| auto output_ptr = reinterpret_cast<float *>(out_tensors_.at(kOutputIndex)->MutableData()); | auto output_ptr = reinterpret_cast<float *>(out_tensors_.at(kOutputIndex)->MutableData()); | ||||
| auto length = in_tensors_.at(kInputIndex)->ElementsNum(); | auto length = in_tensors_.at(kInputIndex)->ElementsNum(); | ||||
| int stride = UP_DIV(length, thread_count_); | int stride = UP_DIV(length, thread_count_); | ||||
| int count = MSMIN(stride, length - stride * task_id); | int count = MSMIN(stride, length - stride * task_id); | ||||
| size_t start = stride * task_id; | |||||
| DropoutGrad(&(yt_ptr[start]), &(mask_ptr[start]), &(output_ptr[start]), count, scale_); | |||||
| if (count > 0) { | |||||
| int start = stride * task_id; | |||||
| DropoutGrad(&(yt_ptr[start]), &(mask_ptr[start]), &(output_ptr[start]), count, scale_); | |||||
| } | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -0,0 +1,45 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_ELU_GRAD_FP32_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_ELU_GRAD_FP32_H_ | |||||
| #include <vector> | |||||
| #include "src/lite_kernel.h" | |||||
| #include "nnacl/fp32/elu_fp32.h" | |||||
| namespace mindspore { | |||||
| namespace kernel { | |||||
| // CPU kernel declaration for the ELU gradient op; it forwards all construction | |||||
| // arguments unchanged to LiteKernel. Method bodies live outside this header. | |||||
| class EluGradCPUKernel : public LiteKernel { | |||||
|  public: | |||||
|   // Passes the parameter block, tensor lists and context straight to the base class. | |||||
|   EluGradCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||||
|                    const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx) | |||||
|       : LiteKernel(parameter, inputs, outputs, ctx) {} | |||||
|   ~EluGradCPUKernel() = default; | |||||
|  | |||||
|   // Per-task worker taking a task index. | |||||
|   // NOTE(review): presumably invoked once per task by a parallel launcher - confirm in the .cc file. | |||||
|   int DoExcute(int task_id); | |||||
|  | |||||
|   // LiteKernel overrides. | |||||
|   int Run() override; | |||||
|   int ReSize() override; | |||||
|   int Init() override; | |||||
|  | |||||
|  private: | |||||
|   // currently MS supports only alpha = 1.0 | |||||
|   float alpha_{1.0f}; | |||||
| }; | |||||
| } // namespace kernel | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_ELU_GRAD_FP32_H_ | |||||
| @@ -0,0 +1,109 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "src/runtime/kernel/arm/fp32_grad/layernorm_grad.h" | |||||
| #include <vector> | |||||
| #include "schema/model_generated.h" | |||||
| #include "src/kernel_registry.h" | |||||
| #include "nnacl/fp32_grad/layernorm_grad.h" | |||||
| #include "nnacl/fp32_grad/layernormgrad_parameter.h" | |||||
| #include "include/errorcode.h" | |||||
| #include "src/runtime/runtime_api.h" | |||||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | |||||
| using mindspore::lite::KernelRegistrar; | |||||
| using mindspore::lite::RET_ERROR; | |||||
| using mindspore::lite::RET_OK; | |||||
| using mindspore::schema::PrimitiveType_LayerNormGrad; | |||||
| namespace mindspore::kernel { | |||||
| int LayerNormGradCPUKernel::ReSize() { return RET_OK; } | |||||
| int LayerNormGradCPUKernel::Init() { | |||||
| auto lngrad_param = reinterpret_cast<LayerNormGradParameter *>(op_parameter_); | |||||
| auto *input_x = in_tensors_.at(0); | |||||
| std::vector<int> x_shape = input_x->shape(); | |||||
| int begin_norm_axis = lngrad_param->begin_norm_axis_; | |||||
| if (begin_norm_axis < 0) { | |||||
| begin_norm_axis += x_shape.size(); | |||||
| } | |||||
| auto begin_params_axis = lngrad_param->begin_params_axis_; | |||||
| if (begin_params_axis < 0) { | |||||
| begin_params_axis += x_shape.size(); | |||||
| } | |||||
| for (size_t i = 0; i < static_cast<size_t>(begin_norm_axis); i++) { | |||||
| block_num_ *= x_shape[i]; | |||||
| } | |||||
| for (size_t i = static_cast<size_t>(begin_norm_axis); i < x_shape.size(); i++) { | |||||
| block_size_ *= x_shape[i]; | |||||
| } | |||||
| for (size_t i = 0; i < static_cast<size_t>(begin_params_axis); i++) { | |||||
| param_size_ *= x_shape[i]; | |||||
| } | |||||
| for (size_t i = begin_params_axis; i < x_shape.size(); i++) { | |||||
| param_num_ *= x_shape[i]; | |||||
| } | |||||
| if (block_num_ <= 0 || block_size_ <= 0) { | |||||
| MS_LOG(ERROR) << "LayerNormGradCPUKernel input shape error, input shape: " << x_shape; | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| int LayerNormGradCPUKernel::Execute(int task_id) { | |||||
| auto input_x = in_tensors_.at(0); | |||||
| auto input_dy = in_tensors_.at(1); | |||||
| auto input_var = in_tensors_.at(2); | |||||
| auto input_mean = in_tensors_.at(3); | |||||
| auto input_gamma = in_tensors_.at(4); | |||||
| auto output_dx = out_tensors_.at(0); | |||||
| auto output_dg = out_tensors_.at(1); | |||||
| auto output_db = out_tensors_.at(2); | |||||
| float *x = reinterpret_cast<float *>(input_x->MutableData()); | |||||
| float *dy = reinterpret_cast<float *>(input_dy->MutableData()); | |||||
| float *var = reinterpret_cast<float *>(input_var->MutableData()); | |||||
| float *mean = reinterpret_cast<float *>(input_mean->MutableData()); | |||||
| float *gamma = reinterpret_cast<float *>(input_gamma->MutableData()); | |||||
| float *dx = reinterpret_cast<float *>(output_dx->MutableData()); | |||||
| float *dg = reinterpret_cast<float *>(output_dg->MutableData()); | |||||
| float *db = reinterpret_cast<float *>(output_db->MutableData()); | |||||
| LayerNormGrad(x, dy, var, mean, gamma, param_num_, param_size_, block_num_, block_size_, dx, dg, db); | |||||
| return RET_OK; | |||||
| } | |||||
| int LayerNormGradRun(void *cdata, int task_id) { | |||||
| MS_ASSERT(cdata != nullptr); | |||||
| auto ln_kernel = reinterpret_cast<LayerNormGradCPUKernel *>(cdata); | |||||
| auto error_code = ln_kernel->Execute(task_id); | |||||
| if (error_code != RET_OK) { | |||||
| MS_LOG(ERROR) << "LayerNormGradRun error task_id[" << task_id << "] error_code[" << error_code << "]"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| int LayerNormGradCPUKernel::Run() { | |||||
| int error_code = ParallelLaunch(this->context_->thread_pool_, LayerNormGradRun, this, 1); | |||||
| if (error_code != RET_OK) { | |||||
| MS_LOG(ERROR) << "LayerNorm function error error_code[" << error_code << "]"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_LayerNormGrad, LiteKernelCreator<LayerNormGradCPUKernel>) | |||||
| } // namespace mindspore::kernel | |||||
| @@ -0,0 +1,43 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_LAYERNORM_GRAD_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_LAYERNORM_GRAD_H_ | |||||
| #include <vector> | |||||
| #include "src/lite_kernel.h" | |||||
| namespace mindspore::kernel { | |||||
| class LayerNormGradCPUKernel : public LiteKernel { | |||||
| public: | |||||
| explicit LayerNormGradCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||||
| const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx) | |||||
| : LiteKernel(parameter, inputs, outputs, ctx) {} | |||||
| ~LayerNormGradCPUKernel() override {} | |||||
| int Init() override; | |||||
| int ReSize() override; | |||||
| int Run() override; | |||||
| int Execute(int task_id); | |||||
| private: | |||||
| int block_num_ = 1; | |||||
| int block_size_ = 1; | |||||
| int param_num_ = 1; | |||||
| int param_size_ = 1; | |||||
| }; | |||||
| } // namespace mindspore::kernel | |||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_LAYERNORM_GRAD_H_ | |||||
| @@ -42,12 +42,12 @@ int NegGradCPUKernel::Init() { return RET_OK; } | |||||
| int NegGradCPUKernel::DoNegGrad(int task_id) { | int NegGradCPUKernel::DoNegGrad(int task_id) { | ||||
| auto dy = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData()); | auto dy = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData()); | ||||
| auto dx = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData()); | auto dx = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData()); | ||||
| size_t length = in_tensors_.at(0)->ElementsNum(); | |||||
| int length = in_tensors_.at(0)->ElementsNum(); | |||||
| size_t stride = UP_DIV(length, thread_count_); | |||||
| size_t count = MSMIN(stride, length - stride * task_id); | |||||
| int stride = UP_DIV(length, thread_count_); | |||||
| int count = MSMIN(stride, length - stride * task_id); | |||||
| size_t start = stride * task_id; | |||||
| int start = stride * task_id; | |||||
| ElementNegative(dy + start, dx + start, count); | ElementNegative(dy + start, dx + start, count); | ||||
| return RET_OK; | return RET_OK; | ||||
| @@ -66,21 +66,23 @@ int PoolingGradCPUKernel::Execute(int task_id) { | |||||
| PoolingParameter *pool_param = reinterpret_cast<PoolingParameter *>(op_parameter_); | PoolingParameter *pool_param = reinterpret_cast<PoolingParameter *>(op_parameter_); | ||||
| auto input_ptr = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData()); | auto input_ptr = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData()); | ||||
| auto output_ptr = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData()); | auto output_ptr = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData()); | ||||
| int stride = UP_DIV(pool_param->output_batch_, thread_num_); | int stride = UP_DIV(pool_param->output_batch_, thread_num_); | ||||
| int count = MSMIN(stride, pool_param->output_batch_ - stride * task_id); | int count = MSMIN(stride, pool_param->output_batch_ - stride * task_id); | ||||
| int in_batch_size = pool_param->input_h_ * pool_param->input_w_ * pool_param->input_channel_; | |||||
| int out_batch_size = pool_param->output_h_ * pool_param->output_w_ * pool_param->input_channel_; | |||||
| std::fill(output_ptr + task_id * stride * in_batch_size, output_ptr + ((task_id * stride) + count) * in_batch_size, | |||||
| 0.f); | |||||
| if (pool_param->pool_mode_ == PoolMode_MaxPool) { | |||||
| auto dy_ptr = reinterpret_cast<float *>(in_tensors_.at(2)->MutableData()); | |||||
| MaxPoolingGrad(input_ptr + task_id * stride * in_batch_size, dy_ptr + task_id * stride * out_batch_size, | |||||
| output_ptr + task_id * stride * in_batch_size, count, pool_param); | |||||
| } else { | |||||
| input_ptr = reinterpret_cast<float *>(in_tensors_.at(2)->MutableData()); | |||||
| AvgPoolingGrad(input_ptr + task_id * stride * out_batch_size, output_ptr + task_id * stride * in_batch_size, count, | |||||
| pool_param); | |||||
| if (count > 0) { | |||||
| int in_batch_size = pool_param->input_h_ * pool_param->input_w_ * pool_param->input_channel_; | |||||
| int out_batch_size = pool_param->output_h_ * pool_param->output_w_ * pool_param->input_channel_; | |||||
| std::fill(output_ptr + task_id * stride * in_batch_size, output_ptr + ((task_id * stride) + count) * in_batch_size, | |||||
| 0.f); | |||||
| if (pool_param->pool_mode_ == PoolMode_MaxPool) { | |||||
| auto dy_ptr = reinterpret_cast<float *>(in_tensors_.at(2)->MutableData()); | |||||
| MaxPoolingGrad(input_ptr + task_id * stride * in_batch_size, dy_ptr + task_id * stride * out_batch_size, | |||||
| output_ptr + task_id * stride * in_batch_size, count, pool_param); | |||||
| } else { | |||||
| input_ptr = reinterpret_cast<float *>(in_tensors_.at(2)->MutableData()); | |||||
| AvgPoolingGrad(input_ptr + task_id * stride * out_batch_size, output_ptr + task_id * stride * in_batch_size, | |||||
| count, pool_param); | |||||
| } | |||||
| } | } | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -46,19 +46,19 @@ int PowerGradCPUKernel::Execute(int task_id) { | |||||
| auto x_addr = reinterpret_cast<float *>(in_tensors_.at(1)->MutableData()); | auto x_addr = reinterpret_cast<float *>(in_tensors_.at(1)->MutableData()); | ||||
| auto dx_addr = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData()); | auto dx_addr = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData()); | ||||
| size_t length = in_tensors_.at(0)->ElementsNum(); | |||||
| int length = in_tensors_.at(0)->ElementsNum(); | |||||
| size_t stride = UP_DIV(length, thread_count_); | |||||
| size_t count = MSMIN(stride, length - stride * task_id); | |||||
| int stride = UP_DIV(length, thread_count_); | |||||
| int count = MSMIN(stride, length - stride * task_id); | |||||
| size_t start = stride * task_id; | |||||
| size_t end = start + count; | |||||
| int start = stride * task_id; | |||||
| int end = start + count; | |||||
| float exp = power_ - 1; | float exp = power_ - 1; | ||||
| Power(&(x_addr[start]), &exp, &(dx_addr[start]), count, scale_, shift_, true); | Power(&(x_addr[start]), &exp, &(dx_addr[start]), count, scale_, shift_, true); | ||||
| ElementMul(&(dx_addr[start]), &(dy_addr[start]), &(dx_addr[start]), count); | ElementMul(&(dx_addr[start]), &(dy_addr[start]), &(dx_addr[start]), count); | ||||
| float scale = scale_ * power_; | float scale = scale_ * power_; | ||||
| for (size_t i = start; i < end; i++) { | |||||
| for (int i = start; i < end; i++) { | |||||
| dx_addr[i] *= scale; | dx_addr[i] *= scale; | ||||
| } | } | ||||
| @@ -0,0 +1,104 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <vector> | |||||
| #include "src/runtime/kernel/arm/fp32_grad/resize_grad.h" | |||||
| #include "nnacl/fp32_grad/resize_grad.h" | |||||
| #include "schema/model_generated.h" | |||||
| #include "src/kernel_registry.h" | |||||
| #include "include/errorcode.h" | |||||
| #include "src/runtime/runtime_api.h" | |||||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | |||||
| using mindspore::lite::KernelRegistrar; | |||||
| using mindspore::lite::RET_ERROR; | |||||
| using mindspore::lite::RET_OK; | |||||
| using mindspore::schema::PrimitiveType_ResizeGrad; | |||||
| namespace mindspore::kernel { | |||||
// Returns the ratio used to map an index of size `out_size` back onto `in_size`.
// With align_corners set and more than one output element the ratio is
// (in_size - 1) / (out_size - 1); in every other case it is in_size / out_size.
float Scaling(size_t in_size, size_t out_size, bool align_corners) {
  const bool use_corner_ratio = align_corners && out_size > 1;
  if (use_corner_ratio) {
    return (in_size - 1) / static_cast<float>(out_size - 1);
  }
  return in_size / static_cast<float>(out_size);
}
| int ResizeGradCPUKernel::ReSize() { | |||||
| auto param = reinterpret_cast<ResizeGradParameter *>(op_parameter_); | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "ResizeGradCPUKernel op_parameter_ is nullptr"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| bool align_corners = param->align_corners_; | |||||
| param->in_height_ = static_cast<size_t>(in_tensors_.at(0)->Height()); | |||||
| param->in_width_ = static_cast<size_t>(in_tensors_.at(0)->Width()); | |||||
| param->out_height_ = static_cast<size_t>(out_tensors_.at(0)->Height()); | |||||
| param->out_width_ = static_cast<size_t>(out_tensors_.at(0)->Width()); | |||||
| param->height_scale_ = Scaling(param->out_height_, param->in_height_, align_corners); | |||||
| param->width_scale_ = Scaling(param->out_width_, param->in_width_, align_corners); | |||||
| return RET_OK; | |||||
| } | |||||
| int ResizeGradCPUKernel::Init() { | |||||
| if (!InferShapeDone()) { | |||||
| return RET_OK; | |||||
| } | |||||
| return ReSize(); | |||||
| } | |||||
| int ResizeGradCPUKernel::Execute(int task_id) { | |||||
| auto in_addr = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData()); | |||||
| auto out_addr = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData()); | |||||
| auto param = reinterpret_cast<ResizeGradParameter *>(op_parameter_); | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "ResizeGradCPUKernel op_parameter_ is nullptr"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto batch_size = in_tensors_.at(0)->Batch(); | |||||
| auto channel = in_tensors_.at(0)->Channel(); | |||||
| if (param->method == static_cast<int>(schema::ResizeMethod_NEAREST)) { | |||||
| ResizeNearestNeighborGrad(in_addr, out_addr, batch_size, channel, param); | |||||
| } else { | |||||
| ResizeBiLinearGrad(in_addr, out_addr, batch_size, channel, param); | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| int ResizeGradRun(void *cdata, int task_id) { | |||||
| auto resize_grad_kernel = reinterpret_cast<ResizeGradCPUKernel *>(cdata); | |||||
| auto error_code = resize_grad_kernel->Execute(task_id); | |||||
| if (error_code != RET_OK) { | |||||
| MS_LOG(ERROR) << "resize grad error task_id[" << task_id << "] error_code[" << error_code << "]"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| int ResizeGradCPUKernel::Run() { | |||||
| auto out_addr = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData()); | |||||
| size_t elem_number = out_tensors_.at(0)->ElementsNum(); | |||||
| std::fill(out_addr, out_addr + elem_number, 0.f); | |||||
| int error_code = ParallelLaunch(this->context_->thread_pool_, ResizeGradRun, this, 1); | |||||
| if (error_code != RET_OK) { | |||||
| MS_LOG(ERROR) << "ResizeGradCPUKernel function error error_code[" << error_code << "]"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_ResizeGrad, LiteKernelCreator<ResizeGradCPUKernel>) | |||||
| } // namespace mindspore::kernel | |||||
| @@ -0,0 +1,38 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_RESIZE_GRAD_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_RESIZE_GRAD_H_ | |||||
| #include <vector> | |||||
| #include "src/lite_kernel.h" | |||||
| namespace mindspore::kernel { | |||||
| class ResizeGradCPUKernel : public LiteKernel { | |||||
| public: | |||||
| explicit ResizeGradCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||||
| const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx) | |||||
| : LiteKernel(parameter, inputs, outputs, ctx) {} | |||||
| ~ResizeGradCPUKernel() override = default; | |||||
| int Init() override; | |||||
| int ReSize() override; | |||||
| int Run() override; | |||||
| int ExecuteInit(int task_id); | |||||
| int Execute(int task_id); | |||||
| }; | |||||
| } // namespace mindspore::kernel | |||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_RESIZE_GRAD_H_ | |||||
| @@ -76,13 +76,13 @@ int SgdCPUKernel::Execute(int task_id) { | |||||
| float learning_rate = lr_; | float learning_rate = lr_; | ||||
| auto gradient = reinterpret_cast<float *>(in_tensors_.at(1)->MutableData()); | auto gradient = reinterpret_cast<float *>(in_tensors_.at(1)->MutableData()); | ||||
| float moment = reinterpret_cast<float *>(in_tensors_.at(4)->MutableData())[0]; | float moment = reinterpret_cast<float *>(in_tensors_.at(4)->MutableData())[0]; | ||||
| size_t length = in_tensors_.at(0)->ElementsNum(); | |||||
| int length = in_tensors_.at(0)->ElementsNum(); | |||||
| size_t stride = UP_DIV(length, thread_count_); | |||||
| size_t count = MSMIN(stride, length - stride * task_id); | |||||
| int stride = UP_DIV(length, thread_count_); | |||||
| int count = MSMIN(stride, length - stride * task_id); | |||||
| size_t start = stride * task_id; | |||||
| size_t end = start + count; | |||||
| int start = stride * task_id; | |||||
| int end = start + count; | |||||
| DoSgd(weight, accumulate, gradient, learning_rate, sgd_param_->dampening_, moment, sgd_param_->use_nesterov_, start, | DoSgd(weight, accumulate, gradient, learning_rate, sgd_param_->dampening_, moment, sgd_param_->use_nesterov_, start, | ||||
| end); | end); | ||||
| @@ -36,17 +36,17 @@ int SmoothL1LossGradCPUKernel::Execute(int task_id) { | |||||
| auto d_loss = reinterpret_cast<float *>(in_tensors_.at(2)->MutableData()); | auto d_loss = reinterpret_cast<float *>(in_tensors_.at(2)->MutableData()); | ||||
| auto *out = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData()); | auto *out = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData()); | ||||
| const size_t length = in_tensors_.at(0)->ElementsNum(); | |||||
| int length = in_tensors_.at(0)->ElementsNum(); | |||||
| size_t stride = UP_DIV(length, thread_count_); | |||||
| size_t count = MSMIN(stride, length - stride * task_id); | |||||
| int stride = UP_DIV(length, thread_count_); | |||||
| int count = MSMIN(stride, length - stride * task_id); | |||||
| size_t start = stride * task_id; | |||||
| size_t end = start + count; | |||||
| int start = stride * task_id; | |||||
| int end = start + count; | |||||
| const float beta = smooth_l1_loss_param->beta_; | const float beta = smooth_l1_loss_param->beta_; | ||||
| for (uint64_t i = start; i < end; ++i) { | |||||
| for (int i = start; i < end; ++i) { | |||||
| float diff = predict[i] - target[i]; | float diff = predict[i] - target[i]; | ||||
| if (diff > beta) { | if (diff > beta) { | ||||
| out[i] = d_loss[i]; | out[i] = d_loss[i]; | ||||
| @@ -0,0 +1,94 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "src/runtime/kernel/arm/fp32_grad/unsorted_segment_sum.h" | |||||
| #include <vector> | |||||
| #include <algorithm> | |||||
| #include "schema/model_generated.h" | |||||
| #include "src/kernel_registry.h" | |||||
| #include "nnacl/fp32_grad/unsorted_segment_sum.h" | |||||
| #include "include/errorcode.h" | |||||
| #include "src/runtime/runtime_api.h" | |||||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | |||||
| using mindspore::lite::KernelRegistrar; | |||||
| using mindspore::lite::RET_ERROR; | |||||
| using mindspore::lite::RET_OK; | |||||
| using mindspore::schema::PrimitiveType_UnsortedSegmentSum; | |||||
| namespace mindspore::kernel { | |||||
| int UnsortedSegmentSumCPUKernel::Init() { | |||||
| if (!InferShapeDone()) { | |||||
| return RET_OK; | |||||
| } | |||||
| auto input_shape = in_tensors_.at(0)->shape(); | |||||
| auto segment_ids_shape = in_tensors_.at(1)->shape(); | |||||
| auto output_shape = out_tensors_.at(0)->shape(); | |||||
| for (size_t i = 0; i < input_shape.size(); ++i) { | |||||
| unit_num_ *= input_shape[i]; | |||||
| if (i >= segment_ids_shape.size()) { | |||||
| input_dim1_ *= input_shape[i]; | |||||
| } | |||||
| } | |||||
| output_dim0_ = output_shape[0]; | |||||
| for (size_t j = 1; j < output_shape.size(); j++) { | |||||
| output_dim1_ *= output_shape[j]; | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| int UnsortedSegmentSumCPUKernel::ReSize() { return RET_OK; } | |||||
| int UnsortedSegmentSumRun(void *cdata, int task_id) { | |||||
| MS_ASSERT(cdata != nullptr); | |||||
| auto kernel = reinterpret_cast<UnsortedSegmentSumCPUKernel *>(cdata); | |||||
| auto error_code = kernel->Execute(task_id); | |||||
| if (error_code != RET_OK) { | |||||
| MS_LOG(ERROR) << "UnsortedSegmentSum Run error task_id[" << task_id << "] error_code[" << error_code << "]"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| int UnsortedSegmentSumCPUKernel::Run() { | |||||
| int error_code = ParallelLaunch(this->context_->thread_pool_, UnsortedSegmentSumRun, this, 1); | |||||
| if (error_code != RET_OK) { | |||||
| MS_LOG(ERROR) << "Strided slice error error_code[" << error_code << "]"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| int UnsortedSegmentSumCPUKernel::Execute(int task_id) { | |||||
| int ret; | |||||
| auto input_tensor = in_tensors_.at(0); | |||||
| auto indices_tensor = in_tensors_.at(1); | |||||
| auto output_tensor = out_tensors_.at(0); | |||||
| float *input = reinterpret_cast<float *>(input_tensor->data_c()); | |||||
| int *indices = reinterpret_cast<int *>(indices_tensor->data_c()); | |||||
| float *output = reinterpret_cast<float *>(output_tensor->MutableData()); | |||||
| std::fill(output, output + output_tensor->ElementsNum(), 0.f); | |||||
| ret = UnsortedSegmentSum(input, unit_num_, input_dim1_, indices, output, output_dim0_, output_dim1_); | |||||
| if (ret != RET_OK) { | |||||
| MS_LOG(ERROR) << "StridedSliceGrad error error_code[" << ret << "]"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_UnsortedSegmentSum, LiteKernelCreator<UnsortedSegmentSumCPUKernel>) | |||||
| } // namespace mindspore::kernel | |||||
| @@ -0,0 +1,44 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_UNSORTED_SEGMENT_SUM_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_UNSORTED_SEGMENT_SUM_H_ | |||||
| #include <vector> | |||||
| #include "src/lite_kernel.h" | |||||
| namespace mindspore::kernel { | |||||
| class UnsortedSegmentSumCPUKernel : public LiteKernel { | |||||
| public: | |||||
| UnsortedSegmentSumCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||||
| const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx) | |||||
| : LiteKernel(parameter, inputs, outputs, ctx) {} | |||||
| ~UnsortedSegmentSumCPUKernel() override = default; | |||||
| int Init() override; | |||||
| int ReSize() override; | |||||
| int Run() override; | |||||
| int Execute(int task_id); | |||||
| size_t unit_num_; | |||||
| size_t input_dim1_; | |||||
| size_t output_dim0_; | |||||
| size_t output_dim1_; | |||||
| private: | |||||
| }; | |||||
| } // namespace mindspore::kernel | |||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_UNSORTED_SEGMENT_SUM_H_ | |||||
| @@ -18,10 +18,13 @@ | |||||
| #include <sys/stat.h> | #include <sys/stat.h> | ||||
| #include <vector> | #include <vector> | ||||
| #include "include/errorcode.h" | #include "include/errorcode.h" | ||||
| #include "src/common/log_adapter.h" | |||||
| #include "include/train_session.h" | #include "include/train_session.h" | ||||
| #include "src/common/utils.h" | #include "src/common/utils.h" | ||||
| #include "src/train/train_utils.h" | #include "src/train/train_utils.h" | ||||
| using mindspore::WARNING; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| @@ -69,23 +69,24 @@ class OptimizerKernel : public LiteKernel { | |||||
| std::fill(grad_sum_, grad_sum_ + elem_num, 0); | std::fill(grad_sum_, grad_sum_ + elem_num, 0); | ||||
| } else { | } else { | ||||
| if (grad_sum_ != nullptr) { | if (grad_sum_ != nullptr) { | ||||
| OptimizerStep(); | |||||
| context_->allocator->Free(grad_sum_); | context_->allocator->Free(grad_sum_); | ||||
| grad_sum_ = nullptr; | grad_sum_ = nullptr; | ||||
| } | } | ||||
| } | } | ||||
| weightUpdateMod_ = WeightUpdateMode::VIRTUAL_BATCH; | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| int ExecuteVirtualBatch(int task_id) { | int ExecuteVirtualBatch(int task_id) { | ||||
| auto gradient = reinterpret_cast<float *>(in_tensors_.at(grad_idx_)->MutableData()); | auto gradient = reinterpret_cast<float *>(in_tensors_.at(grad_idx_)->MutableData()); | ||||
| size_t length = in_tensors_.at(grad_idx_)->ElementsNum(); | |||||
| int length = in_tensors_.at(grad_idx_)->ElementsNum(); | |||||
| size_t stride = UP_DIV(length, context_->thread_num_); | |||||
| size_t count = MSMIN(stride, length - stride * task_id); | |||||
| size_t start = stride * task_id; | |||||
| size_t end = start + count; | |||||
| for (size_t i = start; i < end; ++i) { | |||||
| int stride = UP_DIV(length, context_->thread_num_); | |||||
| int count = MSMIN(stride, length - stride * task_id); | |||||
| int start = stride * task_id; | |||||
| int end = start + count; | |||||
| for (int i = start; i < end; ++i) { | |||||
| grad_sum_[i] += gradient[i]; | grad_sum_[i] += gradient[i]; | ||||
| } | } | ||||
| valid_grad_sum_ = true; | valid_grad_sum_ = true; | ||||
| @@ -97,7 +98,10 @@ class OptimizerKernel : public LiteKernel { | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| int Eval() override { return OptimizerStep(); } | |||||
| int Eval() override { | |||||
| OptimizerStep(); | |||||
| return LiteKernel::Eval(); | |||||
| } | |||||
| protected: | protected: | ||||
| float default_lr_ = 0.0f; | float default_lr_ = 0.0f; | ||||
| @@ -22,6 +22,7 @@ | |||||
| #include "include/errorcode.h" | #include "include/errorcode.h" | ||||
| #include "include/train_session.h" | #include "include/train_session.h" | ||||
| #include "include/iterator.h" | #include "include/iterator.h" | ||||
| #include "src/common/log_adapter.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| @@ -167,11 +168,9 @@ int TrainLoop::LoadPartialData(std::vector<tensor::MSTensor *> inputs, dataset:: | |||||
| } // namespace lite | } // namespace lite | ||||
| session::TrainLoop *session::TrainLoop::CreateTrainLoop(const std::string &model_filename, lite::Context *context, | |||||
| session::TrainLoop *session::TrainLoop::CreateTrainLoop(session::TrainSession *train_session, lite::Context *context, | |||||
| int batch_size) { | int batch_size) { | ||||
| auto train_session = session::TrainSession::CreateSession(model_filename, context); | |||||
| auto loop = new (std::nothrow) lite::TrainLoop(train_session); | auto loop = new (std::nothrow) lite::TrainLoop(train_session); | ||||
| return loop; | return loop; | ||||
| } | } | ||||
| @@ -20,10 +20,10 @@ | |||||
| #include <tuple> | #include <tuple> | ||||
| #include <memory> | #include <memory> | ||||
| #include <unordered_map> | #include <unordered_map> | ||||
| #include "include/errorcode.h" | |||||
| #include "include/train/train_loop.h" | #include "include/train/train_loop.h" | ||||
| #include "include/train/metrics.h" | #include "include/train/metrics.h" | ||||
| #include "include/train_session.h" | #include "include/train_session.h" | ||||
| #include "include/errorcode.h" | |||||
| #include "include/datasets.h" | #include "include/datasets.h" | ||||
| #include "include/iterator.h" | #include "include/iterator.h" | ||||
| #include "src/common/log_adapter.h" | #include "src/common/log_adapter.h" | ||||
| @@ -28,8 +28,10 @@ | |||||
| #include "nnacl/fp32_grad/batch_norm.h" | #include "nnacl/fp32_grad/batch_norm.h" | ||||
| #include "nnacl/fp32_grad/dropout_parameter.h" | #include "nnacl/fp32_grad/dropout_parameter.h" | ||||
| #include "nnacl/fp32_grad/smooth_l1_loss.h" | #include "nnacl/fp32_grad/smooth_l1_loss.h" | ||||
| #include "nnacl/fp32_grad/resize_grad.h" | |||||
| namespace mindspore { | |||||
| namespace kernel { | |||||
| namespace mindspore::kernel { | |||||
| OpParameter *PopulateSmoothL1LossParameter(const void *prim) { | OpParameter *PopulateSmoothL1LossParameter(const void *prim) { | ||||
| SmoothL1LossParameter *p = reinterpret_cast<SmoothL1LossParameter *>(malloc(sizeof(SmoothL1LossParameter))); | SmoothL1LossParameter *p = reinterpret_cast<SmoothL1LossParameter *>(malloc(sizeof(SmoothL1LossParameter))); | ||||
| if (p == nullptr) { | if (p == nullptr) { | ||||
| @@ -170,9 +172,20 @@ OpParameter *PopulateMaxPoolGradParameter(const void *prim) { | |||||
| pooling_param->pad_r_ = 0; | pooling_param->pad_r_ = 0; | ||||
| pooling_param->stride_w_ = static_cast<int>(value->strides()->Get(1)); | pooling_param->stride_w_ = static_cast<int>(value->strides()->Get(1)); | ||||
| pooling_param->stride_h_ = static_cast<int>(value->strides()->Get(0)); | pooling_param->stride_h_ = static_cast<int>(value->strides()->Get(0)); | ||||
| pooling_param->round_mode_ = RoundMode_No; | pooling_param->round_mode_ = RoundMode_No; | ||||
| pooling_param->pool_mode_ = PoolMode_MaxPool; | pooling_param->pool_mode_ = PoolMode_MaxPool; | ||||
| switch (value->pad_mode()) { | |||||
| case schema::PadMode_SAME: | |||||
| pooling_param->pad_mode_ = Pad_same; | |||||
| break; | |||||
| case schema::PadMode_VALID: | |||||
| pooling_param->pad_mode_ = Pad_valid; | |||||
| break; | |||||
| default: | |||||
| pooling_param->pad_mode_ = Pad_pad; | |||||
| break; | |||||
| } | |||||
| return reinterpret_cast<OpParameter *>(pooling_param); | return reinterpret_cast<OpParameter *>(pooling_param); | ||||
| } | } | ||||
| @@ -197,8 +210,30 @@ OpParameter *PopulateAvgPoolGradParameter(const void *prim) { | |||||
| pooling_param->stride_w_ = static_cast<int>(value->strides()->Get(1)); | pooling_param->stride_w_ = static_cast<int>(value->strides()->Get(1)); | ||||
| pooling_param->stride_h_ = static_cast<int>(value->strides()->Get(0)); | pooling_param->stride_h_ = static_cast<int>(value->strides()->Get(0)); | ||||
| switch (value->pad_mode()) { | |||||
| case schema::PadMode_SAME: | |||||
| pooling_param->pad_mode_ = Pad_same; | |||||
| break; | |||||
| case schema::PadMode_VALID: | |||||
| pooling_param->pad_mode_ = Pad_valid; | |||||
| break; | |||||
| default: | |||||
| pooling_param->pad_mode_ = Pad_pad; | |||||
| break; | |||||
| } | |||||
| pooling_param->round_mode_ = RoundMode_No; | pooling_param->round_mode_ = RoundMode_No; | ||||
| pooling_param->pool_mode_ = PoolMode_AvgPool; | pooling_param->pool_mode_ = PoolMode_AvgPool; | ||||
| switch (value->pad_mode()) { | |||||
| case schema::PadMode_SAME: | |||||
| pooling_param->pad_mode_ = Pad_same; | |||||
| break; | |||||
| case schema::PadMode_VALID: | |||||
| pooling_param->pad_mode_ = Pad_valid; | |||||
| break; | |||||
| default: | |||||
| pooling_param->pad_mode_ = Pad_pad; | |||||
| break; | |||||
| } | |||||
| return reinterpret_cast<OpParameter *>(pooling_param); | return reinterpret_cast<OpParameter *>(pooling_param); | ||||
| } | } | ||||
| @@ -378,6 +413,23 @@ OpParameter *PopulateArithmeticGradParameter(const void *prim) { | |||||
| return reinterpret_cast<OpParameter *>(arithmetic_param); | return reinterpret_cast<OpParameter *>(arithmetic_param); | ||||
| } | } | ||||
| OpParameter *PopulateResizeGradParameter(const void *prim) { | |||||
| ResizeGradParameter *resize_grad_param = reinterpret_cast<ResizeGradParameter *>(malloc(sizeof(ResizeGradParameter))); | |||||
| if (resize_grad_param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc resize grad parameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(resize_grad_param, 0, sizeof(ResizeGradParameter)); | |||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | |||||
| resize_grad_param->op_parameter_.type_ = primitive->value_type(); | |||||
| auto param = primitive->value_as_ResizeGrad(); | |||||
| resize_grad_param->method = static_cast<int>(param->method()); | |||||
| resize_grad_param->align_corners_ = param->align_corners(); | |||||
| return reinterpret_cast<OpParameter *>(resize_grad_param); | |||||
| } | |||||
| void PopulateTrainParameters() { | void PopulateTrainParameters() { | ||||
| lite::Registry ApplyMomentumParameterRegistry(schema::PrimitiveType_ApplyMomentum, PopulateApplyMomentumParameter, | lite::Registry ApplyMomentumParameterRegistry(schema::PrimitiveType_ApplyMomentum, PopulateApplyMomentumParameter, | ||||
| lite::SCHEMA_CUR); | lite::SCHEMA_CUR); | ||||
| @@ -437,8 +489,14 @@ void PopulateTrainParameters() { | |||||
| lite::SCHEMA_CUR); | lite::SCHEMA_CUR); | ||||
| lite::Registry StridedSliceGradParameterRegistry(schema::PrimitiveType_StridedSliceGrad, | lite::Registry StridedSliceGradParameterRegistry(schema::PrimitiveType_StridedSliceGrad, | ||||
| lite::PopulateStridedSliceParameter, lite::SCHEMA_CUR); | lite::PopulateStridedSliceParameter, lite::SCHEMA_CUR); | ||||
| lite::Registry SqrtGradParameterRegistry(schema::PrimitiveType_SqrtGrad, lite::DefaultPopulateParameter, | |||||
| lite::SCHEMA_CUR); | |||||
| lite::Registry RsqrtGradParameterRegistry(schema::PrimitiveType_RsqrtGrad, lite::DefaultPopulateParameter, | |||||
| lite::SCHEMA_CUR); | |||||
| lite::Registry ResizeGradParameterRegistry(schema::PrimitiveType_ResizeGrad, PopulateResizeGradParameter, | |||||
| lite::SCHEMA_CUR); | |||||
| lite::Registry AbsGradParameterRegistry(schema::PrimitiveType_AbsGrad, lite::DefaultPopulateParameter, | lite::Registry AbsGradParameterRegistry(schema::PrimitiveType_AbsGrad, lite::DefaultPopulateParameter, | ||||
| lite::SCHEMA_CUR); | lite::SCHEMA_CUR); | ||||
| } | } | ||||
| } // namespace mindspore::kernel | |||||
| } // namespace kernel | |||||
| } // namespace mindspore | |||||
| @@ -272,7 +272,7 @@ void TrainSession::CompileEvalOutputs() { | |||||
| eval_output_node_map_.clear(); | eval_output_node_map_.clear(); | ||||
| eval_output_tensor_map_.clear(); | eval_output_tensor_map_.clear(); | ||||
| for (auto kernel : this->train_kernels_) { | for (auto kernel : this->train_kernels_) { | ||||
| if (IsLossKernel(kernel)) { | |||||
| if (IsLossKernel(kernel) && !(IsGradKernel(kernel))) { | |||||
| for (auto in_kernel : kernel->in_kernels()) { | for (auto in_kernel : kernel->in_kernels()) { | ||||
| if (IsLossKernel(in_kernel) || IsGradKernel(in_kernel)) continue; | if (IsLossKernel(in_kernel) || IsGradKernel(in_kernel)) continue; | ||||
| // insert if not already in | // insert if not already in | ||||
| @@ -33,6 +33,7 @@ | |||||
| #include "src/executor.h" | #include "src/executor.h" | ||||
| #include "src/kernel_registry.h" | #include "src/kernel_registry.h" | ||||
| #include "src/runtime/kernel/arm/fp32_grad/convolution.h" | #include "src/runtime/kernel/arm/fp32_grad/convolution.h" | ||||
| #include "nnacl/fp32/pack_fp32.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| @@ -54,10 +55,20 @@ TransferSession::TransferSession(const char *model_buf_backbone, size_t size_bac | |||||
| std::vector<tensor::MSTensor *> TransferSession::GetInputs() const { return combined_inputs_; } | std::vector<tensor::MSTensor *> TransferSession::GetInputs() const { return combined_inputs_; } | ||||
| bool TransferSession::CompileFormatTransform(tensor::MSTensor *out, tensor::MSTensor *in, int *mask) { | |||||
| for (std::size_t dim = 0; dim != out->shape().size(); ++dim) { | |||||
| if (in->shape().at(mask[dim]) != out->shape().at(dim)) { | |||||
| return false; | |||||
| } | |||||
| } | |||||
| return true; | |||||
| } | |||||
| int TransferSession::CompileTransferGraph() { | int TransferSession::CompileTransferGraph() { | ||||
| combined_inputs_ = backbone_session_->GetInputs(); | combined_inputs_ = backbone_session_->GetInputs(); | ||||
| auto outputs_backbone = backbone_session_->GetOutputs(); | auto outputs_backbone = backbone_session_->GetOutputs(); | ||||
| auto inputs_head = lite::TrainSession::GetInputs(); | auto inputs_head = lite::TrainSession::GetInputs(); | ||||
| int ret = RET_OK; | int ret = RET_OK; | ||||
| for (auto input : inputs_head) { | for (auto input : inputs_head) { | ||||
| bool match = false; | bool match = false; | ||||
| @@ -72,6 +83,11 @@ int TransferSession::CompileTransferGraph() { | |||||
| break; | break; | ||||
| } | } | ||||
| } | } | ||||
| if (match == false && input->shape().size() == 4) { | |||||
| int nchw2nhwc_mask[4] = {0, 3, 1, 2}; | |||||
| nchw2nhwc_ = CompileFormatTransform(output, input, nchw2nhwc_mask); | |||||
| match = nchw2nhwc_; | |||||
| } | |||||
| if (true == match) { | if (true == match) { | ||||
| break; | break; | ||||
| } | } | ||||
| @@ -124,7 +140,14 @@ int TransferSession::RunGraph(const KernelCallBack &before, const KernelCallBack | |||||
| auto output = backbone_head_pair.second; | auto output = backbone_head_pair.second; | ||||
| char *input_data = reinterpret_cast<char *>(input->MutableData()); | char *input_data = reinterpret_cast<char *>(input->MutableData()); | ||||
| char *output_data = reinterpret_cast<char *>(output->MutableData()); | char *output_data = reinterpret_cast<char *>(output->MutableData()); | ||||
| std::copy(output_data, output_data + output->Size(), input_data); | |||||
| if (nchw2nhwc_) { | |||||
| int plane = input->shape().at(1) * input->shape().at(2); | |||||
| int batch = input->shape().at(0); | |||||
| int channel = input->shape().at(3); | |||||
| PackNCHWToNHWCFp32(output_data, input_data, batch, plane, channel, 0, 1); | |||||
| } else { | |||||
| std::copy(output_data, output_data + output->Size(), input_data); | |||||
| } | |||||
| } | } | ||||
| ret = lite::TrainSession::RunGraph(before, after); | ret = lite::TrainSession::RunGraph(before, after); | ||||
| return ret; | return ret; | ||||
| @@ -72,6 +72,8 @@ class TransferSession : public lite::TrainSession { | |||||
| bool is_valid_; | bool is_valid_; | ||||
| private: | private: | ||||
| bool CompileFormatTransform(tensor::MSTensor *out, tensor::MSTensor *in, int *mask); | |||||
| bool nchw2nhwc_ = false; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -1,3 +1,4 @@ | |||||
| # | |||||
| mini_alexnet_r1.1 | mini_alexnet_r1.1 | ||||
| mobilenetv1_r1.1 | mobilenetv1_r1.1 | ||||
| mobilenetv2_r1.1 | mobilenetv2_r1.1 | ||||
| @@ -5,4 +6,17 @@ lenet_r1.1 | |||||
| effnet_r1.1 | effnet_r1.1 | ||||
| effnet_tune_r1.1 | effnet_tune_r1.1 | ||||
| googlenet_r1.1 | googlenet_r1.1 | ||||
| #LAST | |||||
| # mini_alexnet | |||||
| # nin | |||||
| # lenet | |||||
| # mobilenetv1 | |||||
| # mobilenetv2 | |||||
| # mobilenetv3 | |||||
| # effnet | |||||
| # resnet | |||||
| # effnet_tune | |||||
| # googlenet | |||||
| # densenet | |||||
| # shufflenetv2 | |||||
| # xception | |||||
| # LAST | |||||
| @@ -47,7 +47,8 @@ logs_path=${basepath}/logs_train | |||||
| rm -rf ${logs_path} | rm -rf ${logs_path} | ||||
| mkdir -p ${logs_path} | mkdir -p ${logs_path} | ||||
| docker_image=mindspore/mindspore-gpu:1.1.0 | |||||
| docker_image=mindspore_build:210301 | |||||
| #docker_image=mindspore/mindspore-gpu:1.1.1 | |||||
| # Export models | # Export models | ||||
| echo "Start Exporting models ..." | echo "Start Exporting models ..." | ||||
| # Set log files | # Set log files | ||||
| @@ -65,12 +66,15 @@ if [[ -z "${CLOUD_MODEL_ZOO}" ]]; then | |||||
| fi | fi | ||||
| # Export mindspore train models: | # Export mindspore train models: | ||||
| fail=0 | |||||
| while read line; do | while read line; do | ||||
| model_name=${line} | |||||
| LFS=" " read -r -a line_array <<< ${line} | |||||
| model_name=${line_array[0]} | |||||
| if [[ $model_name == \#* ]]; then | if [[ $model_name == \#* ]]; then | ||||
| continue | continue | ||||
| fi | fi | ||||
| echo ${model_name}'_train_export.py' >> "${export_log_file}" | echo ${model_name}'_train_export.py' >> "${export_log_file}" | ||||
| rm -f ${models_path}/${model_name}_train.mindir | |||||
| echo 'exporting' ${model_name} | echo 'exporting' ${model_name} | ||||
| echo 'docker run --user '"$(id -u):$(id -g)"' --env CLOUD_MODEL_ZOO=${CLOUD_MODEL_ZOO} -w $PWD --runtime=nvidia -v /home/$USER:/home/$USER -v /opt/share:/opt/share --privileged=true '${docker_image}' python '${models_path}'/'${model_name}'_train_export.py' >> "${export_log_file}" | echo 'docker run --user '"$(id -u):$(id -g)"' --env CLOUD_MODEL_ZOO=${CLOUD_MODEL_ZOO} -w $PWD --runtime=nvidia -v /home/$USER:/home/$USER -v /opt/share:/opt/share --privileged=true '${docker_image}' python '${models_path}'/'${model_name}'_train_export.py' >> "${export_log_file}" | ||||
| docker run --user "$(id -u):$(id -g)" --env CLOUD_MODEL_ZOO=${CLOUD_MODEL_ZOO} -w $PWD --runtime=nvidia -v /home/$USER:/home/$USER -v /opt/share:/opt/share --privileged=true "${docker_image}" python ${models_path}'/'${model_name}_train_export.py "${epoch_num}" | docker run --user "$(id -u):$(id -g)" --env CLOUD_MODEL_ZOO=${CLOUD_MODEL_ZOO} -w $PWD --runtime=nvidia -v /home/$USER:/home/$USER -v /opt/share:/opt/share --privileged=true "${docker_image}" python ${models_path}'/'${model_name}_train_export.py "${epoch_num}" | ||||
| @@ -78,8 +82,10 @@ while read line; do | |||||
| export_result='export mindspore '${model_name}'_train_export pass';echo ${export_result} >> ${export_result_file} | export_result='export mindspore '${model_name}'_train_export pass';echo ${export_result} >> ${export_result_file} | ||||
| else | else | ||||
| export_result='export mindspore '${model_name}'_train_export failed';echo ${export_result} >> ${export_result_file} | export_result='export mindspore '${model_name}'_train_export failed';echo ${export_result} >> ${export_result_file} | ||||
| fail=1 | |||||
| fi | fi | ||||
| done < ${models_mindspore_train_config} | done < ${models_mindspore_train_config} | ||||
| Print_Result ${export_result_file} | Print_Result ${export_result_file} | ||||
| exit $fail | |||||
| @@ -1,7 +1,7 @@ | |||||
| #!/bin/bash | #!/bin/bash | ||||
| # Run Export on x86 platform and create output test files: | # Run Export on x86 platform and create output test files: | ||||
| docker_image= | |||||
| docker_image=mindspore_build:210301 | |||||
| function Run_Export(){ | function Run_Export(){ | ||||
| cd $models_path || exit 1 | cd $models_path || exit 1 | ||||
| if [[ -z "${CLOUD_MODEL_ZOO}" ]]; then | if [[ -z "${CLOUD_MODEL_ZOO}" ]]; then | ||||
| @@ -10,7 +10,8 @@ function Run_Export(){ | |||||
| fi | fi | ||||
| # Export mindspore train models: | # Export mindspore train models: | ||||
| while read line; do | while read line; do | ||||
| model_name=${line} | |||||
| LFS=" " read -r -a line_array <<< ${line} | |||||
| model_name=${line_array[0]} | |||||
| if [[ $model_name == \#* ]]; then | if [[ $model_name == \#* ]]; then | ||||
| continue | continue | ||||
| fi | fi | ||||
| @@ -47,10 +48,11 @@ function Run_Converter() { | |||||
| rm -rf ${ms_models_path} | rm -rf ${ms_models_path} | ||||
| mkdir -p ${ms_models_path} | mkdir -p ${ms_models_path} | ||||
| fail=0 | |||||
| # Convert mindspore train models: | # Convert mindspore train models: | ||||
| while read line; do | while read line; do | ||||
| model_name=${line} | |||||
| LFS=" " read -r -a line_array <<< ${line} | |||||
| model_name=${line_array[0]} | |||||
| if [[ $model_name == \#* ]]; then | if [[ $model_name == \#* ]]; then | ||||
| continue | continue | ||||
| fi | fi | ||||
| @@ -64,8 +66,10 @@ function Run_Converter() { | |||||
| converter_result='converter mindspore '${model_name}'_train pass';echo ${converter_result} >> ${run_converter_result_file} | converter_result='converter mindspore '${model_name}'_train pass';echo ${converter_result} >> ${run_converter_result_file} | ||||
| else | else | ||||
| converter_result='converter mindspore '${model_name}'_train failed';echo ${converter_result} >> ${run_converter_result_file} | converter_result='converter mindspore '${model_name}'_train failed';echo ${converter_result} >> ${run_converter_result_file} | ||||
| fail=1 | |||||
| fi | fi | ||||
| done < ${models_mindspore_train_config} | done < ${models_mindspore_train_config} | ||||
| return ${fail} | |||||
| } | } | ||||
| # Run on x86 platform: | # Run on x86 platform: | ||||
| @@ -73,7 +77,8 @@ function Run_x86() { | |||||
| # Run mindspore converted train models: | # Run mindspore converted train models: | ||||
| fail=0 | fail=0 | ||||
| while read line; do | while read line; do | ||||
| model_name=${line} | |||||
| LFS=" " read -r -a line_array <<< ${line} | |||||
| model_name=${line_array[0]} | |||||
| if [[ $model_name == \#* ]]; then | if [[ $model_name == \#* ]]; then | ||||
| continue | continue | ||||
| fi | fi | ||||
| @@ -81,7 +86,7 @@ function Run_x86() { | |||||
| echo ${model_name}'_train' >> "${run_x86_log_file}" | echo ${model_name}'_train' >> "${run_x86_log_file}" | ||||
| echo 'cd '${x86_path}'/mindspore-lite-'${version}'-train-linux-x64' >> "${run_x86_log_file}" | echo 'cd '${x86_path}'/mindspore-lite-'${version}'-train-linux-x64' >> "${run_x86_log_file}" | ||||
| cd ${x86_path}/mindspore-lite-${version}-train-linux-x64 || return 1 | cd ${x86_path}/mindspore-lite-${version}-train-linux-x64 || return 1 | ||||
| echo 'LD_LIBRARY_PATH='${LD_LIBRARY_PATH}':./lib:./third_party/libjpeg-turbo/lib:./third_party/opencv/lib ./benchmark_train/benchmark_train --epochs='${epoch_num}' --modelFile='${ms_models_path}'/'${model_name}'_train.ms --inDataFile='${train_io_path}/${model_name}_input1.bin,${train_io_path}/${model_name}_input2.bin' --expectedDataFile='${train_io_path}'/'${model_name}'_output' >> "${run_x86_log_file}" | |||||
| echo 'LD_LIBRARY_PATH='${LD_LIBRARY_PATH}':./lib:./third_party/libjpeg-turbo/lib:./third_party/opencv/lib ./benchmark_train/benchmark_train --epochs='${epoch_num}' --modelFile='${ms_models_path}'/'${model_name}'_train.ms --inDataFile='${train_io_path}/${model_name}_input1.bin,${train_io_path}/${model_name}_input2.bin' --expectedDataFile='${train_io_path}'/'${model_name}'_output --exportFile='${ms_models_path}'/'${model_name}'_train_exported.ms' >> "${run_x86_log_file}" | |||||
| echo '-------------------------------------------------------------------------------' >> "${run_x86_log_file}" | echo '-------------------------------------------------------------------------------' >> "${run_x86_log_file}" | ||||
| LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:./lib:./third_party/libjpeg-turbo/lib:./third_party/opencv/lib:./minddata/lib:./minddata/third_party/libjpeg-turbo/lib \ | LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:./lib:./third_party/libjpeg-turbo/lib:./third_party/opencv/lib:./minddata/lib:./minddata/third_party/libjpeg-turbo/lib \ | ||||
| ${run_valgrind}./benchmark_train/benchmark_train \ | ${run_valgrind}./benchmark_train/benchmark_train \ | ||||
| @@ -159,10 +164,16 @@ function Run_arm() { | |||||
| fail=0 | fail=0 | ||||
| # Run mindir converted train models: | # Run mindir converted train models: | ||||
| while read line; do | while read line; do | ||||
| model_name=${line} | |||||
| LFS=" " read -r -a line_array <<< ${line} | |||||
| model_name=${line_array[0]} | |||||
| if [[ $model_name == \#* ]]; then | if [[ $model_name == \#* ]]; then | ||||
| continue | continue | ||||
| fi | fi | ||||
| if [[ "${line_array[1]}" == "noarm32" ]] && [[ "$1" == arm32 ]]; then | |||||
| run_result=$1': '${model_name}'_train irrelevant'; echo ${run_result} >> ${run_benchmark_train_result_file} | |||||
| continue | |||||
| fi | |||||
| # run benchmark_train test without clib data | # run benchmark_train test without clib data | ||||
| echo ${model_name}'_train' >> "${run_arm_log_file}" | echo ${model_name}'_train' >> "${run_arm_log_file}" | ||||
| @@ -339,7 +350,7 @@ START=$(date +%s.%N) | |||||
| # Run converter | # Run converter | ||||
| echo "start run converter ..." | echo "start run converter ..." | ||||
| Run_Converter | |||||
| Run_Converter & | |||||
| Run_converter_PID=$! | Run_converter_PID=$! | ||||
| sleep 1 | sleep 1 | ||||
| @@ -15,6 +15,7 @@ | |||||
| */ | */ | ||||
| #include "common/common_test.h" | #include "common/common_test.h" | ||||
| #include "mindspore/lite/nnacl/infer/maximum_grad_infer.h" | #include "mindspore/lite/nnacl/infer/maximum_grad_infer.h" | ||||
| #include "mindspore/lite/nnacl/arithmetic.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| @@ -44,7 +45,7 @@ TEST_F(MaximumGradInferTest, MaximumGradInferTest0) { | |||||
| std::vector<TensorC *> outputs(2, NULL); | std::vector<TensorC *> outputs(2, NULL); | ||||
| outputs[0] = new TensorC; | outputs[0] = new TensorC; | ||||
| outputs[1] = new TensorC; | outputs[1] = new TensorC; | ||||
| MaximumGradParameter *parameter = new MaximumGradParameter; | |||||
| ArithmeticParameter *parameter = new ArithmeticParameter; | |||||
| parameter->op_parameter_.infer_flag_ = true; | parameter->op_parameter_.infer_flag_ = true; | ||||
| int ret = MaximumGradInferShape((const TensorC **)inputs.data(), inputs.size(), outputs.data(), outputs.size(), | int ret = MaximumGradInferShape((const TensorC **)inputs.data(), inputs.size(), outputs.data(), outputs.size(), | ||||
| reinterpret_cast<OpParameter *>(parameter)); | reinterpret_cast<OpParameter *>(parameter)); | ||||
| @@ -60,18 +61,18 @@ TEST_F(MaximumGradInferTest, MaximumGradInferTest0) { | |||||
| ASSERT_EQ(outputs[1]->data_type_, kNumberTypeInt32); | ASSERT_EQ(outputs[1]->data_type_, kNumberTypeInt32); | ||||
| ASSERT_EQ(outputs[1]->format_, Format_NHWC); | ASSERT_EQ(outputs[1]->format_, Format_NHWC); | ||||
| ASSERT_EQ(parameter->ndim_, 3); | ASSERT_EQ(parameter->ndim_, 3); | ||||
| ASSERT_EQ(parameter->dy_shape_size_, 3); | |||||
| ASSERT_EQ(parameter->dy_shape_[0], 7); | |||||
| ASSERT_EQ(parameter->dy_shape_[1], 8); | |||||
| ASSERT_EQ(parameter->dy_shape_[2], 9); | |||||
| ASSERT_EQ(parameter->x1_shape_size_, 3); | |||||
| ASSERT_EQ(parameter->x1_shape_[0], 1); | |||||
| ASSERT_EQ(parameter->x1_shape_[1], 4); | |||||
| ASSERT_EQ(parameter->x1_shape_[2], 3); | |||||
| ASSERT_EQ(parameter->x2_shape_size_, 3); | |||||
| ASSERT_EQ(parameter->x2_shape_[0], 1); | |||||
| ASSERT_EQ(parameter->x2_shape_[1], 5); | |||||
| ASSERT_EQ(parameter->x2_shape_[2], 6); | |||||
| ASSERT_EQ(parameter->out_elements_num_, 3); | |||||
| ASSERT_EQ(parameter->out_shape_[0], 7); | |||||
| ASSERT_EQ(parameter->out_shape_[1], 8); | |||||
| ASSERT_EQ(parameter->out_shape_[2], 9); | |||||
| ASSERT_EQ(parameter->in_elements_num0_, 3); | |||||
| ASSERT_EQ(parameter->in_shape0_[0], 1); | |||||
| ASSERT_EQ(parameter->in_shape0_[1], 4); | |||||
| ASSERT_EQ(parameter->in_shape0_[2], 3); | |||||
| ASSERT_EQ(parameter->in_elements_num1_, 3); | |||||
| ASSERT_EQ(parameter->in_shape1_[0], 1); | |||||
| ASSERT_EQ(parameter->in_shape1_[1], 5); | |||||
| ASSERT_EQ(parameter->in_shape1_[2], 6); | |||||
| delete parameter; | delete parameter; | ||||
| for (size_t i = 0; i < inputs_size; i++) { | for (size_t i = 0; i < inputs_size; i++) { | ||||
| delete inputs[i]; | delete inputs[i]; | ||||
| @@ -98,7 +98,8 @@ void AnfExporter::RemoveIfMakeTuple(const CNodePtr &cnode) { | |||||
| MS_LOG(ERROR) << "value node is invalid."; | MS_LOG(ERROR) << "value node is invalid."; | ||||
| return; | return; | ||||
| } | } | ||||
| if (value_node->value() != nullptr && opt::CheckPrimitiveType(make_tuple_node, opt::kPrimMakeTuple)) { | |||||
| if (value_node->value() != nullptr && (opt::CheckPrimitiveType(make_tuple_node, opt::kPrimMakeTuple) || | |||||
| opt::CheckPrimitiveType(make_tuple_node, opt::kPrimMakeTupleV2))) { | |||||
| has_make_tuple = true; | has_make_tuple = true; | ||||
| for (size_t j = 1; j < make_tuple_node->inputs().size(); ++j) { | for (size_t j = 1; j < make_tuple_node->inputs().size(); ++j) { | ||||
| inputs.emplace_back(make_tuple_node->input(j)); | inputs.emplace_back(make_tuple_node->input(j)); | ||||
| @@ -360,6 +361,9 @@ int AnfExporter::Anf2Fb(const FuncGraphPtr &func_graph, const std::unique_ptr<sc | |||||
| if (prim->name() == mindspore::ops::kNameDepend || prim->name() == mindspore::ops::kNameControlDepend) { | if (prim->name() == mindspore::ops::kNameDepend || prim->name() == mindspore::ops::kNameControlDepend) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| if (prim->name() == "make_tuple") { | |||||
| continue; | |||||
| } | |||||
| if (prim->name() == mindspore::ops::kNameTupleGetItem || prim->name() == mindspore::ops::kNameMakeTuple) { | if (prim->name() == mindspore::ops::kNameTupleGetItem || prim->name() == mindspore::ops::kNameMakeTuple) { | ||||
| continue; | continue; | ||||
| @@ -769,7 +773,7 @@ int AnfExporter::ConvertInputValueNode(const std::shared_ptr<AnfNode> &input_ano | |||||
| MS_LOG(INFO) << "op name:" << input_anode->fullname_with_scope() << " input is func_graph"; | MS_LOG(INFO) << "op name:" << input_anode->fullname_with_scope() << " input is func_graph"; | ||||
| return RET_OK; | return RET_OK; | ||||
| } else if (value->isa<Monad>()) { | } else if (value->isa<Monad>()) { | ||||
| MS_LOG(INFO) << "value is a monad."; | |||||
| MS_LOG(INFO) << "op name:" << input_anode->fullname_with_scope() << " input is Monad"; | |||||
| return RET_OK; | return RET_OK; | ||||
| } else { | } else { | ||||
| MS_LOG(ERROR) << "Not support value type , need add support."; | MS_LOG(ERROR) << "Not support value type , need add support."; | ||||
| @@ -125,7 +125,8 @@ int NetTrain::ReadInputFile() { | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } else { | } else { | ||||
| if (ms_inputs_.size() > flags_->input_data_list_.size()) { | if (ms_inputs_.size() > flags_->input_data_list_.size()) { | ||||
| MS_LOG(ERROR) << "missing input files"; | |||||
| MS_LOG(ERROR) << "missing input files expecting " << ms_inputs_.size() << ",got " | |||||
| << flags_->input_data_list_.size(); | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| for (size_t i = 0; i < ms_inputs_.size(); i++) { | for (size_t i = 0; i < ms_inputs_.size(); i++) { | ||||
| @@ -327,8 +328,8 @@ int NetTrain::RunExportedNet() { | |||||
| context->thread_num_ = flags_->num_threads_; | context->thread_num_ = flags_->num_threads_; | ||||
| session_ = session::TrainSession::CreateSession(flags_->export_file_.c_str(), context.get()); | session_ = session::TrainSession::CreateSession(flags_->export_file_.c_str(), context.get()); | ||||
| if (session_ == nullptr) { | if (session_ == nullptr) { | ||||
| MS_LOG(ERROR) << "CreateSession failed while running ", model_name.c_str(); | |||||
| std::cout << "CreateSession failed while running ", model_name.c_str(); | |||||
| MS_LOG(ERROR) << "ExportedFile CreateSession failed while running " << model_name.c_str(); | |||||
| std::cout << "CreateSession failed while running " << model_name.c_str() << std::endl; | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| ms_inputs_ = session_->GetInputs(); | ms_inputs_ = session_->GetInputs(); | ||||
| @@ -344,13 +345,6 @@ int NetTrain::RunExportedNet() { | |||||
| return status; | return status; | ||||
| } | } | ||||
| status = session_->RunGraph(); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "Inference error " << status; | |||||
| std::cerr << "Inference error " << status << std::endl; | |||||
| return status; | |||||
| } | |||||
| if (!flags_->data_file_.empty()) { | if (!flags_->data_file_.empty()) { | ||||
| MS_LOG(INFO) << "Check accuracy for exported model"; | MS_LOG(INFO) << "Check accuracy for exported model"; | ||||
| std::cout << "Check accuracy for exported model " << std::endl; | std::cout << "Check accuracy for exported model " << std::endl; | ||||
| @@ -391,11 +385,13 @@ int NetTrain::RunNetTrain() { | |||||
| } else { | } else { | ||||
| context->device_list_[0].device_info_.cpu_device_info_.cpu_bind_mode_ = NO_BIND; | context->device_list_[0].device_info_.cpu_device_info_.cpu_bind_mode_ = NO_BIND; | ||||
| } | } | ||||
| layer_checksum_ = flags_->layer_checksum_; | |||||
| context->thread_num_ = flags_->num_threads_; | context->thread_num_ = flags_->num_threads_; | ||||
| session_ = session::TrainSession::CreateSession(flags_->model_file_.c_str(), context.get()); | session_ = session::TrainSession::CreateSession(flags_->model_file_.c_str(), context.get()); | ||||
| if (session_ == nullptr) { | if (session_ == nullptr) { | ||||
| MS_LOG(ERROR) << "CreateSession failed while running ", model_name.c_str(); | |||||
| std::cout << "CreateSession failed while running ", model_name.c_str(); | |||||
| MS_LOG(ERROR) << "RunNetTrain CreateSession failed while running " << model_name.c_str(); | |||||
| std::cout << "RunNetTrain CreateSession failed while running " << model_name.c_str() << std::endl; | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| @@ -501,7 +497,6 @@ int NetTrain::InitCallbackParameter() { | |||||
| if (op_times_by_name_.find(callParam.node_name) == op_times_by_name_.end()) { | if (op_times_by_name_.find(callParam.node_name) == op_times_by_name_.end()) { | ||||
| op_times_by_name_.insert(std::make_pair(callParam.node_name, std::make_pair(0, 0.0f))); | op_times_by_name_.insert(std::make_pair(callParam.node_name, std::make_pair(0, 0.0f))); | ||||
| } | } | ||||
| op_call_times_total_++; | op_call_times_total_++; | ||||
| op_begin_ = GetTimeUs(); | op_begin_ = GetTimeUs(); | ||||
| return true; | return true; | ||||
| @@ -526,9 +521,14 @@ int NetTrain::InitCallbackParameter() { | |||||
| op_times_by_type_[call_param.node_type].second += cost; | op_times_by_type_[call_param.node_type].second += cost; | ||||
| op_times_by_name_[call_param.node_name].first++; | op_times_by_name_[call_param.node_name].first++; | ||||
| op_times_by_name_[call_param.node_name].second += cost; | op_times_by_name_[call_param.node_name].second += cost; | ||||
| if (layer_checksum_) { | |||||
| float *output = reinterpret_cast<float *>(after_outputs.at(0)->MutableData()); | |||||
| float sum = 0; | |||||
| for (int i = 0; i < after_outputs.at(0)->ElementsNum(); i++) sum += output[i]; | |||||
| std::cout << call_param.node_type << " shape= " << after_outputs.at(0)->shape() << " sum=" << sum << "\n"; | |||||
| } | |||||
| return true; | return true; | ||||
| }; | }; | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -29,6 +29,7 @@ | |||||
| #include <memory> | #include <memory> | ||||
| #include <cfloat> | #include <cfloat> | ||||
| #include <utility> | #include <utility> | ||||
| #include <algorithm> | |||||
| #include "tools/common/flag_parser.h" | #include "tools/common/flag_parser.h" | ||||
| #include "src/common/file_utils.h" | #include "src/common/file_utils.h" | ||||
| #include "src/common/utils.h" | #include "src/common/utils.h" | ||||
| @@ -64,6 +65,7 @@ class MS_API NetTrainFlags : public virtual FlagParser { | |||||
| AddFlag(&NetTrainFlags::data_file_, "expectedDataFile", "Expected results data file path", ""); | AddFlag(&NetTrainFlags::data_file_, "expectedDataFile", "Expected results data file path", ""); | ||||
| AddFlag(&NetTrainFlags::export_file_, "exportFile", "MS File to export trained model into", ""); | AddFlag(&NetTrainFlags::export_file_, "exportFile", "MS File to export trained model into", ""); | ||||
| AddFlag(&NetTrainFlags::accuracy_threshold_, "accuracyThreshold", "Threshold of accuracy", 0.5); | AddFlag(&NetTrainFlags::accuracy_threshold_, "accuracyThreshold", "Threshold of accuracy", 0.5); | ||||
| AddFlag(&NetTrainFlags::layer_checksum_, "layerCheckSum", "layer output checksum print (debug)", false); | |||||
| } | } | ||||
| ~NetTrainFlags() override = default; | ~NetTrainFlags() override = default; | ||||
| @@ -92,6 +94,7 @@ class MS_API NetTrainFlags : public virtual FlagParser { | |||||
| // Resize | // Resize | ||||
| std::string export_file_ = ""; | std::string export_file_ = ""; | ||||
| std::string resize_dims_in_ = ""; | std::string resize_dims_in_ = ""; | ||||
| bool layer_checksum_ = false; | |||||
| std::vector<std::vector<int64_t>> resize_dims_; | std::vector<std::vector<int64_t>> resize_dims_; | ||||
| }; | }; | ||||
| @@ -142,11 +145,16 @@ class MS_API NetTrain { | |||||
| size_t errorCount = 0; | size_t errorCount = 0; | ||||
| float meanError = 0; | float meanError = 0; | ||||
| std::cout << "Data of model output: "; | std::cout << "Data of model output: "; | ||||
| for (int j = 0; j < std::min(50, size); j++) { | |||||
| std::cout << static_cast<float>(msTensorData[j]) << " "; | |||||
| } | |||||
| std::cout << std::endl; | |||||
| std::cout << "Data of Ref output : "; | |||||
| for (int j = 0; j < std::min(50, size); j++) { | |||||
| std::cout << refOutput[j] << " "; | |||||
| } | |||||
| for (int j = 0; j < size; j++) { | for (int j = 0; j < size; j++) { | ||||
| if (j < 50) { | |||||
| std::cout << static_cast<float>(msTensorData[j]) << " "; | |||||
| } | |||||
| std::cout << std::endl; | |||||
| if (std::isnan(msTensorData[j]) || std::isinf(msTensorData[j])) { | if (std::isnan(msTensorData[j]) || std::isinf(msTensorData[j])) { | ||||
| std::cerr << "Output tensor has nan or inf data, compare fail" << std::endl; | std::cerr << "Output tensor has nan or inf data, compare fail" << std::endl; | ||||
| MS_LOG(ERROR) << "Output tensor has nan or inf data, compare fail"; | MS_LOG(ERROR) << "Output tensor has nan or inf data, compare fail"; | ||||
| @@ -205,6 +213,7 @@ class MS_API NetTrain { | |||||
| mindspore::KernelCallBack before_call_back_; | mindspore::KernelCallBack before_call_back_; | ||||
| mindspore::KernelCallBack after_call_back_; | mindspore::KernelCallBack after_call_back_; | ||||
| bool layer_checksum_ = false; | |||||
| }; | }; | ||||
| int MS_API RunNetTrain(int argc, const char **argv); | int MS_API RunNetTrain(int argc, const char **argv); | ||||
| @@ -33,6 +33,7 @@ static const std::vector<schema::PrimitiveType> nhwcOpList = {schema::PrimitiveT | |||||
| schema::PrimitiveType_ApplyMomentum, | schema::PrimitiveType_ApplyMomentum, | ||||
| schema::PrimitiveType_SGD, | schema::PrimitiveType_SGD, | ||||
| schema::PrimitiveType_Adam, | schema::PrimitiveType_Adam, | ||||
| schema::PrimitiveType_ResizeGrad, | |||||
| schema::PrimitiveType_AvgPoolFusion, | schema::PrimitiveType_AvgPoolFusion, | ||||
| schema::PrimitiveType_MaxPoolFusion, | schema::PrimitiveType_MaxPoolFusion, | ||||
| schema::PrimitiveType_Conv2DFusion, | schema::PrimitiveType_Conv2DFusion, | ||||
| @@ -51,8 +52,9 @@ static const std::vector<schema::PrimitiveType> nhwcOpList = {schema::PrimitiveT | |||||
| schema::PrimitiveType_SpaceToBatchND}; | schema::PrimitiveType_SpaceToBatchND}; | ||||
| static const std::vector<schema::PrimitiveType> nhwcOpAllInputList = { | static const std::vector<schema::PrimitiveType> nhwcOpAllInputList = { | ||||
| schema::PrimitiveType_AvgPoolGrad, schema::PrimitiveType_MaxPoolGrad, schema::PrimitiveType_ActivationGrad, | |||||
| schema::PrimitiveType_Conv2DBackpropFilterFusion, schema::PrimitiveType_BatchNormGrad}; | |||||
| schema::PrimitiveType_AvgPoolGrad, schema::PrimitiveType_MaxPoolGrad, | |||||
| schema::PrimitiveType_ActivationGrad, schema::PrimitiveType_Conv2DBackpropFilterFusion, | |||||
| schema::PrimitiveType_BatchNormGrad, schema::PrimitiveType_ResizeGrad}; | |||||
| // index {} mean all inputs need insert | // index {} mean all inputs need insert | ||||
| static std::unordered_map<schema::PrimitiveType, std::vector<int>> extNhwcInsertIndex = { | static std::unordered_map<schema::PrimitiveType, std::vector<int>> extNhwcInsertIndex = { | ||||
| @@ -128,6 +128,7 @@ int RunConverter(int argc, const char **argv) { | |||||
| oss << "CONVERT RESULT FAILED:" << status << " " << GetErrorInfo(status); | oss << "CONVERT RESULT FAILED:" << status << " " << GetErrorInfo(status); | ||||
| MS_LOG(ERROR) << oss.str(); | MS_LOG(ERROR) << oss.str(); | ||||
| std::cout << oss.str() << std::endl; | std::cout << oss.str() << std::endl; | ||||
| status = RET_ERROR; | |||||
| return status; | return status; | ||||
| } | } | ||||
| @@ -172,6 +172,11 @@ STATUS FormatTransPass::DoNodeInoutFormatTrans(schema::MetaGraphT *graph) { | |||||
| } | } | ||||
| } else if (IsContain(GetNhwcAllInputOpList(), opType)) { | } else if (IsContain(GetNhwcAllInputOpList(), opType)) { | ||||
| auto input_size = node->inputIndex.size(); | auto input_size = node->inputIndex.size(); | ||||
| if (GetCNodeTType(**iter) == schema::PrimitiveType_ResizeGrad) { | |||||
| if ((**iter).primitive->value.AsResizeGrad()->method == schema::ResizeMethod_NEAREST) { | |||||
| input_size = 1; | |||||
| } | |||||
| } | |||||
| for (size_t i = 0; i < input_size; i++) { | for (size_t i = 0; i < input_size; i++) { | ||||
| iter = InsertFormatTransNode(graph, iter, kBefore, i, beforeNodeType, &status); | iter = InsertFormatTransNode(graph, iter, kBefore, i, beforeNodeType, &status); | ||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| @@ -14,8 +14,8 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef MINDSPORE_LITE_SRC_PASS_COMMON_GLLO_UTILS_H_ | |||||
| #define MINDSPORE_LITE_SRC_PASS_COMMON_GLLO_UTILS_H_ | |||||
| #ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_COMMON_GLLO_UTILS_H_ | |||||
| #define MINDSPORE_LITE_TOOLS_OPTIMIZER_COMMON_GLLO_UTILS_H_ | |||||
| #include <memory> | #include <memory> | ||||
| #include <string> | #include <string> | ||||
| @@ -37,6 +37,7 @@ namespace mindspore { | |||||
| namespace opt { | namespace opt { | ||||
| inline const PrimitivePtr kPrimReturn = std::make_shared<Primitive>("Return"); | inline const PrimitivePtr kPrimReturn = std::make_shared<Primitive>("Return"); | ||||
| inline const PrimitivePtr kPrimMakeTuple = std::make_shared<Primitive>("MakeTuple"); | inline const PrimitivePtr kPrimMakeTuple = std::make_shared<Primitive>("MakeTuple"); | ||||
| inline const PrimitivePtr kPrimMakeTupleV2 = std::make_shared<Primitive>("make_tuple"); | |||||
| inline const PrimitivePtr kPrimIdentity = std::make_shared<Primitive>("Identity"); | inline const PrimitivePtr kPrimIdentity = std::make_shared<Primitive>("Identity"); | ||||
| std::vector<int> CastToInt(const ValuePtr &value); | std::vector<int> CastToInt(const ValuePtr &value); | ||||
| @@ -146,4 +147,4 @@ ParameterPtr BuildFloatValueParameterNode(const FuncGraphPtr &func_graph, const | |||||
| const std::string &node_name); | const std::string &node_name); | ||||
| } // namespace opt | } // namespace opt | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_LITE_SRC_PASS_COMMON_GLLO_UTILS_H_ | |||||
| #endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_COMMON_GLLO_UTILS_H_ | |||||
| @@ -131,7 +131,12 @@ constexpr auto kNameHSwishGrad = "HSwishGrad"; | |||||
| constexpr auto kNameReluGrad = "ReluGrad"; | constexpr auto kNameReluGrad = "ReluGrad"; | ||||
| constexpr auto kNameReLU6Grad = "ReLU6Grad"; | constexpr auto kNameReLU6Grad = "ReLU6Grad"; | ||||
| constexpr auto kNameSigmoidGrad = "SigmoidGrad"; | constexpr auto kNameSigmoidGrad = "SigmoidGrad"; | ||||
| constexpr auto kNameEluGrad = "EluGrad"; | |||||
| constexpr auto kNameGeluGrad = "GeluGrad"; | |||||
| constexpr auto kNameSlice = "Slice"; | constexpr auto kNameSlice = "Slice"; | ||||
| constexpr auto kNameAvgPoolGradGpu = "AvgPoolGradGpu"; | |||||
| constexpr auto kNameAvgPoolGradCpu = "AvgPoolGradCpu"; | |||||
| std::map<std::string, mindspore::ActivationType> activation_map = {{ops::kNameElu, mindspore::ELU}, | std::map<std::string, mindspore::ActivationType> activation_map = {{ops::kNameElu, mindspore::ELU}, | ||||
| {ops::kNameGeLU, mindspore::GELU}, | {ops::kNameGeLU, mindspore::GELU}, | ||||
| {ops::kNameLeakyRelu, mindspore::LEAKY_RELU}, | {ops::kNameLeakyRelu, mindspore::LEAKY_RELU}, | ||||
| @@ -145,7 +150,9 @@ std::map<std::string, mindspore::ActivationType> activation_map = {{ops::kNameEl | |||||
| {kNameHSwishGrad, mindspore::HSWISH}, | {kNameHSwishGrad, mindspore::HSWISH}, | ||||
| {kNameReluGrad, mindspore::RELU}, | {kNameReluGrad, mindspore::RELU}, | ||||
| {kNameReLU6Grad, mindspore::RELU6}, | {kNameReLU6Grad, mindspore::RELU6}, | ||||
| {kNameSigmoidGrad, mindspore::SIGMOID}}; | |||||
| {kNameSigmoidGrad, mindspore::SIGMOID}, | |||||
| {kNameEluGrad, mindspore::ELU}, | |||||
| {kNameGeluGrad, mindspore::GELU}}; | |||||
| std::map<std::string, mindspore::ReduceMode> reduce_map = { | std::map<std::string, mindspore::ReduceMode> reduce_map = { | ||||
| {ops::kNameReduceAll, mindspore::Reduce_All}, {ops::kNameReduceASum, mindspore::Reduce_ASum}, | {ops::kNameReduceAll, mindspore::Reduce_All}, {ops::kNameReduceASum, mindspore::Reduce_ASum}, | ||||
| @@ -351,16 +358,29 @@ int MoveAttrPoolGrad(const CNodePtr &cnode) { | |||||
| MS_LOG(ERROR) << "value node is invalid."; | MS_LOG(ERROR) << "value node is invalid."; | ||||
| return lite::RET_ERROR; | return lite::RET_ERROR; | ||||
| } | } | ||||
| auto status = AttrAdjust(src_prim, ops::kKernelSize, {2, 3}); | |||||
| PrimitivePtr dst_prim; | |||||
| if (src_prim->name() == kNameAvgPoolGrad || src_prim->name() == kNameAvgPoolGradGpu || | |||||
| src_prim->name() == kNameAvgPoolGradCpu) { | |||||
| dst_prim = std::make_shared<ops::AvgPoolGrad>(); | |||||
| } else if (src_prim->name() == kNameMaxPoolGrad) { | |||||
| dst_prim = std::make_shared<ops::MaxPoolGrad>(); | |||||
| } else { | |||||
| MS_LOG(ERROR) << "unsupported pooling type."; | |||||
| return lite::RET_ERROR; | |||||
| } | |||||
| MS_ASSERT(dst_prim != nullptr); | |||||
| dst_prim->SetAttrs(src_prim->attrs()); | |||||
| auto status = AttrAdjust(dst_prim, ops::kKernelSize, {2, 3}); | |||||
| if (status != lite::RET_OK) { | if (status != lite::RET_OK) { | ||||
| MS_LOG(ERROR) << "adjust ksize failed."; | MS_LOG(ERROR) << "adjust ksize failed."; | ||||
| return status; | return status; | ||||
| } | } | ||||
| status = AttrAdjust(src_prim, ops::kStrides, {2, 3}); | |||||
| status = AttrAdjust(dst_prim, ops::kStrides, {2, 3}); | |||||
| if (status != lite::RET_OK) { | if (status != lite::RET_OK) { | ||||
| MS_LOG(ERROR) << "adjust strides failed."; | MS_LOG(ERROR) << "adjust strides failed."; | ||||
| return status; | return status; | ||||
| } | } | ||||
| value_node->set_value(dst_prim); | |||||
| return lite::RET_OK; | return lite::RET_OK; | ||||
| } | } | ||||
| @@ -510,6 +530,8 @@ REGIST_PRIMITIVE_ADJUST(kNameArgMin, MoveAttrMapCommon<ops::ArgMinFusion>) | |||||
| REGIST_PRIMITIVE_ADJUST(kNameArgMinWithValue, MoveAttrMapCommon<ops::ArgMinFusion>) | REGIST_PRIMITIVE_ADJUST(kNameArgMinWithValue, MoveAttrMapCommon<ops::ArgMinFusion>) | ||||
| REGIST_PRIMITIVE_ADJUST(kNameAvgPool, MoveAttrPool) | REGIST_PRIMITIVE_ADJUST(kNameAvgPool, MoveAttrPool) | ||||
| REGIST_PRIMITIVE_ADJUST(kNameAvgPoolGrad, MoveAttrPoolGrad) | REGIST_PRIMITIVE_ADJUST(kNameAvgPoolGrad, MoveAttrPoolGrad) | ||||
| REGIST_PRIMITIVE_ADJUST(kNameAvgPoolGradGpu, MoveAttrPoolGrad) | |||||
| REGIST_PRIMITIVE_ADJUST(kNameAvgPoolGradCpu, MoveAttrPoolGrad) | |||||
| REGIST_PRIMITIVE_ADJUST(kNameBatchMatMul, MoveAttrMapCommon<ops::MatMul>) | REGIST_PRIMITIVE_ADJUST(kNameBatchMatMul, MoveAttrMapCommon<ops::MatMul>) | ||||
| REGIST_PRIMITIVE_ADJUST(kNameBatchNorm, MoveAttrMapCommon<ops::FusedBatchNorm>) | REGIST_PRIMITIVE_ADJUST(kNameBatchNorm, MoveAttrMapCommon<ops::FusedBatchNorm>) | ||||
| REGIST_PRIMITIVE_ADJUST(kNameConv2DBackpropFilter, MoveAttrMapCommon<ops::Conv2DBackpropFilterFusion>) | REGIST_PRIMITIVE_ADJUST(kNameConv2DBackpropFilter, MoveAttrMapCommon<ops::Conv2DBackpropFilterFusion>) | ||||
| @@ -519,10 +541,12 @@ REGIST_PRIMITIVE_ADJUST(kNameDepthWiseConv2D, MoveAttrMapConv2D) | |||||
| REGIST_PRIMITIVE_ADJUST(kNameConv2dTranspose, MoveAttrMapCommon<ops::Conv2dTransposeFusion>) | REGIST_PRIMITIVE_ADJUST(kNameConv2dTranspose, MoveAttrMapCommon<ops::Conv2dTransposeFusion>) | ||||
| REGIST_PRIMITIVE_ADJUST(kNameDiv, MoveAttrMapCommon<ops::DivFusion>) | REGIST_PRIMITIVE_ADJUST(kNameDiv, MoveAttrMapCommon<ops::DivFusion>) | ||||
| REGIST_PRIMITIVE_ADJUST(kNameElu, MoveAttrMapActivation) | REGIST_PRIMITIVE_ADJUST(kNameElu, MoveAttrMapActivation) | ||||
| REGIST_PRIMITIVE_ADJUST(kNameEluGrad, MoveAttrMapActivationGrad) | |||||
| REGIST_PRIMITIVE_ADJUST(kNameExp, MoveAttrMapCommon<ops::ExpFusion>) | REGIST_PRIMITIVE_ADJUST(kNameExp, MoveAttrMapCommon<ops::ExpFusion>) | ||||
| REGIST_PRIMITIVE_ADJUST(kNameFusedBatchNormEx, MoveAttrMapCommon<ops::FusedBatchNorm>) | REGIST_PRIMITIVE_ADJUST(kNameFusedBatchNormEx, MoveAttrMapCommon<ops::FusedBatchNorm>) | ||||
| REGIST_PRIMITIVE_ADJUST(kNameFusedBatchNormGradEx, MoveAttrMapCommon<ops::BatchNormGrad>) | REGIST_PRIMITIVE_ADJUST(kNameFusedBatchNormGradEx, MoveAttrMapCommon<ops::BatchNormGrad>) | ||||
| REGIST_PRIMITIVE_ADJUST(kNameGeLU, MoveAttrMapActivation) | REGIST_PRIMITIVE_ADJUST(kNameGeLU, MoveAttrMapActivation) | ||||
| REGIST_PRIMITIVE_ADJUST(kNameGeluGrad, MoveAttrMapActivationGrad) | |||||
| REGIST_PRIMITIVE_ADJUST(kNameHSigmoid, MoveAttrMapActivation) | REGIST_PRIMITIVE_ADJUST(kNameHSigmoid, MoveAttrMapActivation) | ||||
| REGIST_PRIMITIVE_ADJUST(kNameHSigmoidGrad, MoveAttrMapActivationGrad) | REGIST_PRIMITIVE_ADJUST(kNameHSigmoidGrad, MoveAttrMapActivationGrad) | ||||
| REGIST_PRIMITIVE_ADJUST(kNameHSwish, MoveAttrMapActivation) | REGIST_PRIMITIVE_ADJUST(kNameHSwish, MoveAttrMapActivation) | ||||
| @@ -35,6 +35,21 @@ int RemoveRedundantOpPass::ReplaceOp(const AnfNodePtr &anf_node, const FuncGraph | |||||
| return lite::RET_NO_CHANGE; | return lite::RET_NO_CHANGE; | ||||
| } | } | ||||
| } | } | ||||
| if (CheckPrimitiveType(anf_node, prim::kPrimDepend)) { | |||||
| if (cnode->size() != InputDoubleNum) { | |||||
| MS_LOG(DEBUG) << "The node inputs size is bigger than 1"; | |||||
| remove_cnode_.insert(anf_node); | |||||
| return lite::RET_NO_CHANGE; | |||||
| } | |||||
| } | |||||
| if (CheckPrimitiveType(anf_node, prim::kPrimControlDepend)) { | |||||
| if (cnode->size() != InputDoubleNum) { | |||||
| MS_LOG(DEBUG) << "The node inputs size is bigger than 1"; | |||||
| remove_cnode_.insert(anf_node); | |||||
| return lite::RET_NO_CHANGE; | |||||
| } | |||||
| } | |||||
| bool replace_succ = manager->Replace(anf_node, cnode->input(1)); | bool replace_succ = manager->Replace(anf_node, cnode->input(1)); | ||||
| if (!replace_succ) { | if (!replace_succ) { | ||||
| MS_LOG(ERROR) << "replace redundant op failed."; | MS_LOG(ERROR) << "replace redundant op failed."; | ||||
| @@ -70,7 +70,9 @@ lite::STATUS WeightFormatTransformPass::TransposeInsertForWeightSharing(const Fu | |||||
| continue; | continue; | ||||
| } | } | ||||
| if (CheckPrimitiveType(node, prim::kPrimConv2DFusion) || CheckPrimitiveType(node, kPrimConv2DBackpropInputFusion) || | if (CheckPrimitiveType(node, prim::kPrimConv2DFusion) || CheckPrimitiveType(node, kPrimConv2DBackpropInputFusion) || | ||||
| CheckPrimitiveType(node, prim::kPrimConv2dTransposeFusion)) { | |||||
| CheckPrimitiveType(node, prim::kPrimConv2dTransposeFusion) || | |||||
| CheckPrimitiveType(node, prim::kPrimApplyMomentum) || CheckPrimitiveType(node, prim::kPrimSGD) || | |||||
| CheckPrimitiveType(node, prim::kPrimAdam)) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| auto cnode = node->cast<CNodePtr>(); | auto cnode = node->cast<CNodePtr>(); | ||||