You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

adam.cc 3.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  1. /**
  2. * Copyright 2021-2022 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  #include <cstddef>
  #include <memory>
  #include <string>
  #include <vector>
  #include "common/graph_kernel/expanders/expander_factory.h"
  #include "ir/dtype.h"
  20. namespace mindspore::graphkernel::expanders {
  21. class Adam : public OpDesc {
  22. public:
  23. Adam() {
  24. std::initializer_list<std::string> attrs{"use_nesterov"};
  25. (void)validators_.emplace_back(std::make_unique<CheckAttr>(attrs));
  26. }
  27. ~Adam() = default;
  28. protected:
  29. bool CheckInputs() override {
  30. const auto &var = inputs_info_[0];
  31. if (var.type != kNumberTypeFloat32 && var.type != kNumberTypeFloat16) {
  32. MS_LOG(INFO) << "In Adam, var's dtype must be float16 or float32";
  33. return false;
  34. }
  35. return true;
  36. }
  37. NodePtrList Expand() override {
  38. const auto &inputs = gb.Get()->inputs();
  39. const auto &var = inputs[0];
  40. const auto &m = inputs[1];
  41. const auto &v = inputs[2];
  42. const auto &beta1_power = inputs[3];
  43. const auto &beta2_power = inputs[4];
  44. const auto &lr = inputs[5];
  45. const auto &beta1 = inputs[6];
  46. const auto &beta2 = inputs[7];
  47. const auto &epsilon = inputs[8];
  48. const auto &grad = inputs[9];
  49. // calc m_new : m_new = beta1 * m + (1 - beta1) * grad
  50. auto m_b = gb.Emit("Mul", {beta1, m});
  51. tensor::TensorPtr data = std::make_shared<tensor::Tensor>(static_cast<double>(1.0), TypeIdToType(var->type));
  52. auto const_one = gb.Value(data);
  53. auto m1_beta1 = gb.Emit("Sub", {const_one, beta1});
  54. auto m_g = gb.Emit("Mul", {m1_beta1, grad});
  55. auto m_new = gb.Emit("Add", {m_b, m_g});
  56. // calc v_new: v_new = beta2 * v + (1 - beta2) * grad * grad
  57. auto v_b = gb.Emit("Mul", {beta2, v});
  58. auto m1_beta2 = gb.Emit("Sub", {const_one, beta2});
  59. auto grad_mul = gb.Emit("Mul", {grad, grad});
  60. auto v_g = gb.Emit("Mul", {m1_beta2, grad_mul});
  61. auto v_new = gb.Emit("Add", {v_b, v_g});
  62. // calc lr_t: lr_t = lr * sqrt(1 - beta2_power) / (1 - beta1_power);
  63. auto m1_beta2_power = gb.Emit("Sub", {const_one, beta2_power});
  64. auto m1_beta2_power_sqrt = gb.Emit("Sqrt", {m1_beta2_power});
  65. auto m1_beta1_power = gb.Emit("Sub", {const_one, beta1_power});
  66. auto power_div = gb.Emit("RealDiv", {m1_beta2_power_sqrt, m1_beta1_power});
  67. auto lr_t = gb.Emit("Mul", {lr, power_div});
  68. // if use_nesterov: var_new <- var - lr_t * (m_new * beta1 + (1 - beta1) * grad) / (epsilon + sqrt(v_new))
  69. // if not use_nesterov: var_new <- var - lr_t * m_new / (epsilon + sqrt(v_new))
  70. auto v_new_sqrt = gb.Emit("Sqrt", {v_new});
  71. auto v_new_sqrt_e = gb.Emit("Add", {epsilon, v_new_sqrt});
  72. auto lr_t_div = gb.Emit("RealDiv", {lr_t, v_new_sqrt_e});
  73. mindspore::graphkernel::inner::NodePtr var_sub;
  74. if (GetValue<bool>(attrs_["use_nesterov"])) {
  75. auto m_new_mul = gb.Emit("Mul", {m_new, beta1});
  76. auto m_new_mul_add = gb.Emit("Add", {m_new_mul, m_g});
  77. var_sub = gb.Emit("Mul", {lr_t_div, m_new_mul_add});
  78. } else {
  79. var_sub = gb.Emit("Mul", {lr_t_div, m_new});
  80. }
  81. auto var_new = gb.Emit("Sub", {var, var_sub});
  82. auto var_result = gb.Emit("Assign", {var, var_new});
  83. auto m_result = gb.Emit("Assign", {m, m_new});
  84. auto v_result = gb.Emit("Assign", {v, v_new});
  85. auto result = {var_result, m_result, v_result};
  86. return result;
  87. }
  88. };
  89. OP_EXPANDER_REGISTER("Adam", Adam);
  90. } // namespace mindspore::graphkernel::expanders