You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

meta_func_graph.cc 2.5 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #include "ir/meta_func_graph.h"
  19. #include "pipeline/static_analysis/static_analysis.h"
  20. #include "pipeline/static_analysis/abstract_function.h"
  21. // namespace to support intermediate representation definition
  22. namespace mindspore {
  23. abstract::AbstractBasePtr MetaFuncGraph::MakeAbstractClosure(const AnfNodePtr &anf_node) {
  24. abstract::MetaFuncGraphAbstractClosurePtr meta_func_graph_fn;
  25. if (anf_node == nullptr) {
  26. meta_func_graph_fn = std::make_shared<abstract::MetaFuncGraphAbstractClosure>(shared_from_base<MetaFuncGraph>());
  27. } else {
  28. meta_func_graph_fn =
  29. std::make_shared<abstract::MetaFuncGraphAbstractClosure>(shared_from_base<MetaFuncGraph>(), anf_node->scope());
  30. }
  31. return meta_func_graph_fn;
  32. }
  33. FuncGraphPtr MetaFuncGraph::GenerateFuncGraph(const abstract::AbstractBasePtrList &args_spec_list) {
  34. TypePtrList types;
  35. (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(types),
  36. [](const AbstractBasePtr &arg) -> TypePtr {
  37. MS_EXCEPTION_IF_NULL(arg);
  38. return arg->BuildType();
  39. });
  40. // filter unsafe characters in log print since name_ is from outside
  41. auto iter = cache_.find(types);
  42. if (iter == cache_.end()) {
  43. FuncGraphPtr fg = GenerateFromTypes(types);
  44. MS_EXCEPTION_IF_NULL(fg);
  45. MS_LOG(INFO) << "MetaFuncgraph: cache miss for types: " << mindspore::ToString(args_spec_list)
  46. << ", g: " << fg->ToString();
  47. cache_[types] = fg;
  48. return fg;
  49. } else {
  50. MS_LOG(DEBUG) << "MetaFuncgraph: cache hit for types: " << mindspore::ToString(args_spec_list)
  51. << ", g: " << iter->second->ToString();
  52. return iter->second;
  53. }
  54. }
  55. } // namespace mindspore