Browse Source

!799 fix tensor dirty flag setting

Merge pull request !799 from dinghao/dirty_flag
tags/v0.3.0-alpha
mindspore-ci-bot Gitee 5 years ago
parent
commit
e655e0e04c
3 changed files with 14 additions and 8 deletions
  1. +3
    -1
      mindspore/ccsrc/ir/meta_tensor.cc
  2. +2
    -1
      mindspore/ccsrc/session/session_basic.cc
  3. +9
    -6
      tests/st/ops/gpu/test_assign_add_op.py

+ 3
- 1
mindspore/ccsrc/ir/meta_tensor.cc View File

@@ -164,8 +164,9 @@ Tensor::Tensor(const py::float_ &input, const TypePtr &data_type) { init(py::arr
Tensor::Tensor(const py::int_ &input, const TypePtr &data_type) { init(py::array(input), data_type); }

 Tensor::Tensor(const Tensor &tensor, const TypePtr &data_type)
-    : MetaTensor(tensor), dirty_(tensor.dirty_), device_address_(tensor.device_address_) {
+    : MetaTensor(tensor), device_address_(tensor.device_address_) {
   init(tensor.data_, data_type);
+  dirty_ = tensor.is_dirty();
 }

Tensor &Tensor::operator=(const Tensor &tensor) {
@@ -291,6 +292,7 @@ void Tensor::init(const py::array &input, const TypeId &data_type) {
   } else {
     data_ = input;
   }
+  dirty_ = true;
 }

void Tensor::init(TypeId data_type, const std::vector<int> &shape, py::array *const data) {


+ 2
- 1
mindspore/ccsrc/session/session_basic.cc View File

@@ -127,6 +127,7 @@ BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const Kerne
   MS_EXCEPTION_IF_NULL(ms_context);
   if (ms_context->execution_mode() == kPynativeMode) {
     tensor->set_device_address(AnfAlgo::GetMutableOutputAddr(node, output_index));
+    tensor->set_dirty(false);
   } else if (!address->SyncDeviceToHost(trans::GetRuntimePaddingShape(node, output_index),
                                         LongToSize(tensor->data().nbytes()), tensor->data_type(),
                                         tensor->data_c(true))) {
@@ -491,7 +492,7 @@ void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_grap
         need_sync = true;
       }
     } else {
-      if (tensor->is_dirty() || !AnfAlgo::IsParameterWeight(pk_node)) {
+      if (tensor->is_dirty()) {
         need_sync = true;
       } else if (tensor->device_address() != device_address) {
         (void)tensor->data_sync();


+ 9
- 6
tests/st/ops/gpu/test_assign_add_op.py View File

@@ -51,19 +51,22 @@ def test_assign_add():
                          [[54, 57, 60],
                           [63, 66, 69],
                           [72, 75, 78]]]])
-    x = Tensor(np.arange(1 * 3 * 3 * 3).reshape(1, 3, 3, 3).astype(np.float32))
-    y = Tensor(np.arange(1 * 3 * 3 * 3).reshape(1, 3, 3, 3).astype(np.float32))
+    x1 = Tensor(np.arange(1 * 3 * 3 * 3).reshape(1, 3, 3, 3).astype(np.float32))
+    y1 = Tensor(np.arange(1 * 3 * 3 * 3).reshape(1, 3, 3, 3).astype(np.float32))
+
+    x2 = Tensor(np.arange(1 * 3 * 3 * 3).reshape(1, 3, 3, 3).astype(np.float32))
+    y2 = Tensor(np.arange(1 * 3 * 3 * 3).reshape(1, 3, 3, 3).astype(np.float32))

     context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
     add = AssignAdd()
-    output1 = add(x, y)
+    output1 = add(x1, y1)
     assert (output1.asnumpy() == expect1).all()
-    output2 = add(output1, y)
+    output2 = add(output1, y1)
     assert (output2.asnumpy() == expect2).all()

     context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
     add = AssignAdd()
-    output1 = add(x, y)
+    output1 = add(x2, y2)
     assert (output1.asnumpy() == expect1).all()
-    output2 = add(output1, y)
+    output2 = add(output1, y2)
     assert (output2.asnumpy() == expect2).all()

Loading…
Cancel
Save