Browse Source

!995 Clean some pylint-warnings

Merge pull request !995 from SJN/master
tags/v0.3.0-alpha
mindspore-ci-bot Gitee 5 years ago
parent
commit
2c4fec57a0
3 changed files with 9 additions and 9 deletions
  1. +6
    -7
      example/yolov3_coco2017/train.py
  2. +1
    -0
      mindspore/_akg/add_path.py
  3. +2
    -2
      mindspore/_akg/utils/format_transform.py

+ 6
- 7
example/yolov3_coco2017/train.py View File

@@ -41,23 +41,22 @@ from config import ConfigYOLOV3ResNet18
def get_lr(learning_rate, start_step, global_step, decay_step, decay_rate, steps=False):
"""Set learning rate."""
lr_each_step = []
lr = learning_rate
for i in range(global_step):
if steps:
lr_each_step.append(lr * (decay_rate ** (i // decay_step)))
lr_each_step.append(learning_rate * (decay_rate ** (i // decay_step)))
else:
lr_each_step.append(lr * (decay_rate ** (i / decay_step)))
lr_each_step.append(learning_rate * (decay_rate ** (i / decay_step)))
lr_each_step = np.array(lr_each_step).astype(np.float32)
lr_each_step = lr_each_step[start_step:]
return lr_each_step


def init_net_param(net, init='ones'):
"""Init the parameters in net."""
params = net.trainable_params()
def init_net_param(network, init_value='ones'):
"""Init:wq the parameters in network."""
params = network.trainable_params()
for p in params:
if isinstance(p.data, Tensor) and 'beta' not in p.name and 'gamma' not in p.name and 'bias' not in p.name:
p.set_parameter_data(initializer(init, p.data.shape(), p.data.dtype()))
p.set_parameter_data(initializer(init_value, p.data.shape(), p.data.dtype()))


def main():


+ 1
- 0
mindspore/_akg/add_path.py View File

@@ -33,6 +33,7 @@ class AKGMetaPathFinder:

def find_module(self, fullname, path=None):
"""method _akg find module."""
_ = path
if fullname.startswith("_akg.tvm"):
rname = fullname[5:]
return AKGMetaPathLoader(rname)


+ 2
- 2
mindspore/_akg/utils/format_transform.py View File

@@ -15,9 +15,9 @@
"""format transform function"""
import _akg

def refine_reduce_axis(input, axis):
def refine_reduce_axis(input_content, axis):
"""make reduce axis legal."""
shape = get_shape(input)
shape = get_shape(input_content)
if axis is None:
axis = [i for i in range(len(shape))]
elif isinstance(axis, int):


Loading…
Cancel
Save