/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "operator/prim_to_function.h"

#include <memory>
#include <string>
#include <vector>

namespace mindspore {
// namespace to support prim related definition
namespace prim {

// Builds the lookup table that maps a primitive's name to its argument-count
// category (kPrimTypeOneArg or kPrimTypeTwoArgs). Names that are not listed
// here are reported as kPrimTypeUnknown by GetPrimType().
PrimToFunction::PrimToFunction()
    : prim_func_type_map_({// ONE_ARG prim
                           {"bool_not", kPrimTypeOneArg},
                           {"scalar_cos", kPrimTypeOneArg},
                           {"scalar_exp", kPrimTypeOneArg},
                           {"scalar_floor", kPrimTypeOneArg},
                           {"scalar_log", kPrimTypeOneArg},
                           {"scalar_sin", kPrimTypeOneArg},
                           {"scalar_tan", kPrimTypeOneArg},
                           {"scalar_trunc", kPrimTypeOneArg},
                           {"typeof", kPrimTypeOneArg},
                           {"scalar_uadd", kPrimTypeOneArg},
                           {"scalar_usub", kPrimTypeOneArg},
                           // TWO_ARGS prim
                           {"scalar_add", kPrimTypeTwoArgs},
                           {"bool_and", kPrimTypeTwoArgs},
                           {"bool_eq", kPrimTypeTwoArgs},
                           {"bool_or", kPrimTypeTwoArgs},
                           {"scalar_div", kPrimTypeTwoArgs},
                           {"scalar_eq", kPrimTypeTwoArgs},
                           {"scalar_ge", kPrimTypeTwoArgs},
                           {"scalar_gt", kPrimTypeTwoArgs},
                           {"scalar_le", kPrimTypeTwoArgs},
                           {"scalar_lt", kPrimTypeTwoArgs},
                           {"scalar_ne", kPrimTypeTwoArgs},
                           {"scalar_mod", kPrimTypeTwoArgs},
                           {"scalar_mul", kPrimTypeTwoArgs},
                           {"scalar_pow", kPrimTypeTwoArgs},
                           {"scalar_sub", kPrimTypeTwoArgs},
                           {"scalar_floordiv", kPrimTypeTwoArgs}}) {}

// Produces the function signature type for a known primitive.
//
// prim: primitive to look up by name (checked for null inside GetPrimType,
//       which is only reached when func is non-null).
// func: output slot; written only when the primitive is a one-argument or
//       two-argument primitive.
//
// Returns true when *func was assigned; false when func is nullptr or the
// primitive's category is neither kPrimTypeOneArg nor kPrimTypeTwoArgs.
bool PrimToFunction::GetFunction(const PrimitivePtr &prim, FunctionPtr *const func) const {
  // Without an output slot there is nothing to fill in; the primitive is not inspected.
  if (func == nullptr) {
    return false;
  }

  const int prim_type = GetPrimType(prim);

  // NOTE(review): the element type (TypePtr) and the make_shared type (Number) were
  // reconstructed from context (retval is a TypePtr passed alongside these vectors to
  // Function) -- confirm against the declarations in the project headers.
  std::vector<TypePtr> one_arg{std::make_shared<Number>()};
  std::vector<TypePtr> two_args{std::make_shared<Number>(), std::make_shared<Number>()};
  TypePtr retval = std::make_shared<Number>();

  bool result = true;
  switch (prim_type) {
    case kPrimTypeOneArg:
      *func = Function(one_arg, retval).DeepCopy()->cast<FunctionPtr>();
      break;
    case kPrimTypeTwoArgs:
      *func = Function(two_args, retval).DeepCopy()->cast<FunctionPtr>();
      break;
    default:
      // Unknown primitive: leave *func untouched and report failure.
      result = false;
      break;
  }
  return result;
}

// Looks up the argument-count category registered for prim's name.
//
// prim is validated with MS_EXCEPTION_IF_NULL before it is dereferenced.
// Returns the mapped category, or kPrimTypeUnknown (as int) when the name
// is not present in prim_func_type_map_.
int PrimToFunction::GetPrimType(const PrimitivePtr &prim) const {
  MS_EXCEPTION_IF_NULL(prim);

  const auto entry = prim_func_type_map_.find(prim->name());
  if (entry == prim_func_type_map_.end()) {
    return static_cast<int>(kPrimTypeUnknown);
  }
  return entry->second;
}

}  // namespace prim
}  // namespace mindspore