浏览代码

!2399 fix param KeyError in group params

Merge pull request !2399 from ghzl/fix-params-keyerror-in-group-params
tags/v0.6.0-beta
mindspore-ci-bot Gitee 5 年前
父节点
当前提交
4c4586ea6f
共有 2 个文件被更改,包括 20 次插入9 次删除
  1. +20
    -3
      mindspore/nn/optim/optimizer.py
  2. +0
    -6
      mindspore/ops/operations/debug_ops.py

+ 20
- 3
mindspore/nn/optim/optimizer.py 查看文件

@@ -219,8 +219,28 @@ class Optimizer(Cell):
raise TypeError("Learning rate should be float, Tensor or Iterable.")
return lr

def _check_group_params(self, parameters):
"""Check group params."""
parse_keys = ['params', 'lr', 'weight_decay', 'order_params']
for group_param in parameters:
invalid_key = list(filter(lambda x: x not in parse_keys, group_param.keys()))
if invalid_key:
raise KeyError(f'The key "{invalid_key}" cannot be recognized in group params.')

if 'order_params' in group_param.keys():
if len(group_param.keys()) > 1:
raise ValueError("The order params dict in group parameters should "
"only include the 'order_params' key.")
if not isinstance(group_param['order_params'], Iterable):
raise TypeError("The value of 'order_params' should be an Iterable type.")
continue

if not group_param['params']:
raise ValueError("Optimizer got an empty group parameter list.")

def _parse_group_params(self, parameters, learning_rate):
"""Parse group params."""
self._check_group_params(parameters)
if self.dynamic_lr:
dynamic_lr_length = learning_rate.size()
else:
@@ -250,9 +270,6 @@ class Optimizer(Cell):
if dynamic_lr_length not in (lr_length, 0):
raise ValueError("The dynamic learning rate in group should be the same size.")

if not group_param['params']:
raise ValueError("Optimizer got an empty group parameter list.")

dynamic_lr_length = lr_length
self.dynamic_lr_length = dynamic_lr_length



+ 0
- 6
mindspore/ops/operations/debug_ops.py 查看文件

@@ -309,12 +309,6 @@ class Print(PrimitiveWithInfer):
Output tensor or string to stdout.

Note:
The print operation cannot support the following cases currently.

1. The type of tensor is float64 or bool.

2. The data of tensor is a scalar type.

In pynative mode, please use python print function.

Inputs:


正在加载...
取消
保存