|
|
@@ -33,7 +33,7 @@ class FusedAdamWeightDecayGpuKernel : public GpuKernel { |
|
|
|
|
|
|
|
|
bool Init(const CNodePtr &kernel_node) override { |
|
|
bool Init(const CNodePtr &kernel_node) override { |
|
|
auto node_name = AnfAlgo::GetCNodeName(kernel_node); |
|
|
auto node_name = AnfAlgo::GetCNodeName(kernel_node); |
|
|
if (node_name == "AdamWeighDecay") { |
|
|
|
|
|
|
|
|
if (node_name == "FusedAdamWeightDecay") { |
|
|
weight_decay_ = true; |
|
|
weight_decay_ = true; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|