You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

softplus.cc 1.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <memory>
  17. #include <vector>
  18. #include "common/graph_kernel/expanders/expander_factory.h"
  19. #include "ir/dtype.h"
  20. namespace mindspore::graphkernel::expanders {
  21. class Softplus : public OpDesc {
  22. public:
  23. Softplus() {}
  24. ~Softplus() = default;
  25. protected:
  26. bool CheckInputs() override {
  27. const auto &input_x = inputs_info_[0];
  28. if (input_x.type != kNumberTypeFloat32 && input_x.type != kNumberTypeFloat16) {
  29. MS_LOG(INFO) << "In Softplus, input_x's dtype must be float16 or float32";
  30. return false;
  31. }
  32. return true;
  33. }
  34. NodePtrList Expand() override {
  35. const auto &inputs = gb.Get()->inputs();
  36. const auto &input_x = inputs[0];
  37. auto exp_x = gb.Emit("Exp", {input_x});
  38. tensor::TensorPtr data = std::make_shared<tensor::Tensor>(static_cast<double>(1.0), TypeIdToType(input_x->type));
  39. auto const_one = gb.Value(data);
  40. auto exp_x_add_one = gb.Emit("Add", {exp_x, const_one});
  41. auto result = gb.Emit("Log", {exp_x_add_one});
  42. return {result};
  43. }
  44. };
  45. OP_EXPANDER_REGISTER("Softplus", Softplus);
  46. } // namespace mindspore::graphkernel::expanders