|
|
|
@@ -46,6 +46,7 @@ using context::OpLevel_1; |
|
|
|
constexpr size_t kAssignInputIdx = 1; |
|
|
|
constexpr size_t kLambOptimizerInputIdx = 12; |
|
|
|
constexpr size_t kLambWeightInputIdx = 4; |
|
|
|
constexpr size_t kRandomInputIdx = 1; |
|
|
|
|
|
|
|
std::vector<PrimitivePtr> GetExpandOps() { |
|
|
|
std::vector<std::tuple<std::string, unsigned int, PrimitivePtr>> expand_ops_with_level = { |
|
|
|
@@ -93,6 +94,7 @@ std::vector<PrimitivePtr> GetExpandOps() { |
|
|
|
{kGPUDevice, OpLevel_0, prim::kPrimSquareSumAll}, |
|
|
|
{kGPUDevice, OpLevel_0, prim::kPrimIdentityMath}, |
|
|
|
{kGPUDevice, OpLevel_0, prim::kPrimOnesLike}, |
|
|
|
{kGPUDevice, OpLevel_0, prim::kPrimStandardNormal}, |
|
|
|
}; |
|
|
|
const auto &flags = context::GraphKernelFlags::GetInstance(); |
|
|
|
std::vector<PrimitivePtr> expand_ops = GetValidOps(expand_ops_with_level, flags.fusion_ops_level); |
|
|
|
@@ -201,6 +203,7 @@ ExpanderPtr GraphKernelExpander::GetExpander(const AnfNodePtr &node) { |
|
|
|
{prim::kPrimAssignSub, std::make_shared<OpUMonadExpander>(kAssignInputIdx)}, |
|
|
|
{prim::kLambApplyOptimizerAssign, std::make_shared<OpUMonadExpander>(kLambOptimizerInputIdx)}, |
|
|
|
{prim::kLambApplyWeightAssign, std::make_shared<OpUMonadExpander>(kLambWeightInputIdx)}, |
|
|
|
{prim::kPrimStandardNormal, std::make_shared<OpUMonadExpander>(kRandomInputIdx)}, |
|
|
|
}; |
|
|
|
|
|
|
|
for (auto &e : expanders) { |
|
|
|
|