
set device id master 0813

tags/v1.5.0-rc1
dingpeifei (d00455729) committed 4 years ago
commit b4bc6000dc
10 changed files with 50 additions and 19 deletions:

 1. mindspore/_checkparam.py (+30 -0)
 2. mindspore/context.py (+2 -2)
 3. tests/st/ops/gpu/test_fake_quant_perchannel.py (+1 -1)
 4. tests/st/ops/gpu/test_fake_quant_perchannel_grad.py (+1 -1)
 5. tests/st/ops/gpu/test_fake_quant_perlayer_grad.py (+1 -1)
 6. tests/ut/python/parallel/test_auto_parallel_resnet_predict.py (+1 -2)
 7. tests/ut/python/parallel/test_auto_parallel_resnet_sharding_propagation.py (+1 -1)
 8. tests/ut/python/parallel/test_auto_parallel_resnet_sharding_propagation2.py (+1 -1)
 9. tests/ut/python/pynative_mode/test_context.py (+12 -9)
10. tests/ut/python/pynative_mode/test_multigraph_sink.py (+0 -1)

mindspore/_checkparam.py (+30 -0)

@@ -951,3 +951,33 @@ def args_type_check(*type_args, **type_kwargs):
         return wrapper
 
     return type_check
+
+
+_set_record = {}
+
+
+def args_unreset_check(*unreset_args, **unreset_kwargs):
+    """Check that non-resettable setting properties are not set more than once."""
+
+    def unreset_check(func):
+        sig = inspect.signature(func)
+        bound_unreset = sig.bind_partial(*unreset_args, **unreset_kwargs).arguments
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            nonlocal bound_unreset
+            bound_values = sig.bind(*args, **kwargs)
+            argument_dict = bound_values.arguments
+            if "kwargs" in bound_unreset:
+                bound_unreset = bound_unreset["kwargs"]
+            if "kwargs" in argument_dict:
+                argument_dict = argument_dict["kwargs"]
+            for name, value in argument_dict.items():
+                if name in _set_record.keys():
+                    raise TypeError('Argument {} is a non-resettable parameter ({}).'.format(name, bound_unreset[name]))
+                if name in bound_unreset:
+                    _set_record[name] = value
+            return func(*args, **kwargs)
+
+        return wrapper
+
+    return unreset_check
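
The decorator records each guarded name in the module-level _set_record dict the first time it is set and raises TypeError on any later attempt, even with the same value. A minimal usage sketch of the new decorator (`configure` is a hypothetical stand-in for context.set_context; it assumes a build that contains this commit, and note that _set_record is shared process-wide):

from mindspore._checkparam import args_unreset_check

@args_unreset_check(device_id=int)
def configure(**kwargs):
    # Stand-in setter; the decorator only inspects the keyword arguments.
    return kwargs

configure(device_id=0)   # first set: recorded in _set_record
configure(device_id=1)   # raises TypeError: device_id was already set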

mindspore/context.py (+2 -2)

@@ -24,7 +24,7 @@ from collections import namedtuple
 from types import FunctionType
 from mindspore import log as logger
 from mindspore._c_expression import MSContext, ms_ctx_param
-from mindspore._checkparam import args_type_check, Validator
+from mindspore._checkparam import args_type_check, Validator, args_unreset_check
 from mindspore.parallel._auto_parallel_context import _set_auto_parallel_context, _get_auto_parallel_context, \
     _reset_auto_parallel_context
 from mindspore.parallel._ps_context import _set_ps_context, _get_ps_context, _reset_ps_context
@@ -507,7 +507,7 @@ def _check_target_specific_cfgs(device, arg_key):
                     ", ignore it.")
         return False
 
+@args_unreset_check(device_id=int, variable_memory_max_size=str, max_device_memory=str)
 @args_type_check(mode=int, precompile_only=bool, device_target=str, device_id=int, save_graphs=bool,
                  save_graphs_path=str, enable_dump=bool, auto_tune_mode=str,
                  save_dump_path=str, enable_reduce_precision=bool, variable_memory_max_size=str,
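
With args_unreset_check stacked outside args_type_check, set_context now treats device_id, variable_memory_max_size, and max_device_memory as write-once per process; the other keys (mode, save_graphs, ...) stay resettable. A sketch of the resulting behavior (assumes a device build where device_id applies):

from mindspore import context

context.set_context(device_id=0)   # first set: allowed and recorded
context.set_context(device_id=0)   # raises TypeError, even for the same value

# __wrapped__ (exposed by functools.wraps) peels off only the outermost
# decorator, so this skips the write-once record while the inner
# args_type_check still runs:
context.set_context.__wrapped__(device_id=0)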


tests/st/ops/gpu/test_fake_quant_perchannel.py (+1 -1)

@@ -21,7 +21,7 @@ from mindspore.common.tensor import Tensor
 from mindspore import nn
 from mindspore.ops.operations import _quant_ops as Q
 
-context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU', device_id=0)
+context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
 
 
 class Net(nn.Cell):


tests/st/ops/gpu/test_fake_quant_perchannel_grad.py (+1 -1)

@@ -20,7 +20,7 @@ import mindspore.nn as nn
 import mindspore.context as context
 from mindspore.ops.operations import _quant_ops as Q
 
-context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU', device_id=0)
+context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
 
 
 class Net(nn.Cell):


tests/st/ops/gpu/test_fake_quant_perlayer_grad.py (+1 -1)

@@ -20,7 +20,7 @@ import mindspore.nn as nn
 import mindspore.context as context
 from mindspore.ops.operations import _quant_ops as Q
 
-context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU', device_id=0)
+context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
 
 
 class Net(nn.Cell):


tests/ut/python/parallel/test_auto_parallel_resnet_predict.py (+1 -2)

@@ -21,9 +21,8 @@ from mindspore.context import ParallelMode
 from mindspore.communication._comm_helper import GlobalComm
 from .test_auto_parallel_resnet import resnet50
 
-
 context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
-context.set_context(device_id=0)
+context.set_context.__wrapped__(device_id=0)
 GlobalComm.CHECK_ENVS = False
 init()
 GlobalComm.CHECK_ENVS = True
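
This and the following UT files switch from context.set_context(device_id=0) to context.set_context.__wrapped__(device_id=0): since device_id is now write-once per process, a second plain call from another test in the same run would raise TypeError. The escape hatch works because functools.wraps stores the wrapped callable on the wrapper. A self-contained illustration of the mechanism (independent of MindSpore; set_once and set_id are hypothetical names):

from functools import wraps

def set_once(func):
    seen = {}
    @wraps(func)  # copies metadata and sets wrapper.__wrapped__ = func
    def wrapper(**kwargs):
        for name, value in kwargs.items():
            if name in seen:
                raise TypeError('{} was already set'.format(name))
            seen[name] = value
        return func(**kwargs)
    return wrapper

@set_once
def set_id(device_id):
    return device_id

set_id(device_id=0)              # first call: recorded
set_id.__wrapped__(device_id=1)  # bypasses the guard entirely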


tests/ut/python/parallel/test_auto_parallel_resnet_sharding_propagation.py (+1 -1)

@@ -34,7 +34,7 @@ from mindspore.context import ParallelMode
 from mindspore.communication._comm_helper import GlobalComm
 
 context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
-context.set_context(device_id=0)
+context.set_context.__wrapped__(device_id=0)
 GlobalComm.CHECK_ENVS = False
 init()
 GlobalComm.CHECK_ENVS = True


tests/ut/python/parallel/test_auto_parallel_resnet_sharding_propagation2.py (+1 -1)

@@ -33,7 +33,7 @@ from mindspore.context import ParallelMode
 from mindspore.communication._comm_helper import GlobalComm
 
 context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
-context.set_context(device_id=0)
+context.set_context.__wrapped__(device_id=0)
 GlobalComm.CHECK_ENVS = False
 init()
 GlobalComm.CHECK_ENVS = True


tests/ut/python/pynative_mode/test_context.py (+12 -9)

@@ -49,9 +49,8 @@ def test_switch_mode():
 def test_set_device_id():
     """ test_set_device_id """
     with pytest.raises(TypeError):
-        context.set_context(device_id=1)
+        context.set_context(device_id="cpu")
     assert context.get_context("device_id") == 0
     context.set_context(device_id=1)
     assert context.get_context("device_id") == 1


@@ -115,14 +114,17 @@ def test_variable_memory_max_size():
         context.set_context(variable_memory_max_size=True)
     with pytest.raises(TypeError):
         context.set_context(variable_memory_max_size=1)
     with pytest.raises(ValueError):
         context.set_context(variable_memory_max_size="")
     with pytest.raises(ValueError):
         context.set_context(variable_memory_max_size="1G")
     with pytest.raises(ValueError):
         context.set_context(variable_memory_max_size="32GB")
-    context.set_context(variable_memory_max_size="3GB")
+    context.set_context.__wrapped__(variable_memory_max_size="3GB")
 
 def test_max_device_memory_size():
     """test_max_device_memory_size"""
     with pytest.raises(TypeError):
         context.set_context(max_device_memory=True)
     with pytest.raises(TypeError):
         context.set_context(max_device_memory=1)
-    context.set_context(max_device_memory="3.5G")
+    context.set_context.__wrapped__(max_device_memory="3GB")
 
 def test_print_file_path():
     """test_print_file_path"""
@@ -132,8 +134,9 @@
 
 def test_set_context():
     """ test_set_context """
+    context.set_context.__wrapped__(device_id=0)
     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend",
-                        device_id=0, save_graphs=True, save_graphs_path="mindspore_ir_path")
+                        save_graphs=True, save_graphs_path="mindspore_ir_path")
     assert context.get_context("device_id") == 0
     assert context.get_context("device_target") == "Ascend"
     assert context.get_context("save_graphs")
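
Because _set_record lives at module scope, a guarded setting stays locked for the whole Python process; the tests above sidestep that through __wrapped__. An alternative workaround sketch (not something this commit's tests do) clears the record directly:

import mindspore._checkparam as checkparam

# _set_record is the plain module-level dict added by this commit; clearing
# it forgets every previously locked setting in the current process.
checkparam._set_record.clear()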


tests/ut/python/pynative_mode/test_multigraph_sink.py (+0 -1)

@@ -21,7 +21,6 @@ from mindspore.common.tensor import Tensor
 
 def setup_module(module):
     context.set_context(mode=context.PYNATIVE_MODE, save_graphs=False, device_target="Ascend")
-    context.set_context(device_id=0)
 
 
 c1 = Tensor([2], mstype.int32)

