Browse Source

Add assign_lazy_load for ResourceVariable.

tags/v0.30
Oceania2018 5 years ago
parent
commit
ae9a161cd7
5 changed files with 56 additions and 11 deletions
  1. +30
    -4
      src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs
  2. +3
    -0
      src/TensorFlowNET.Core/Variables/IVariableV1.cs
  3. +16
    -0
      src/TensorFlowNET.Core/Variables/RefVariable.cs
  4. +1
    -1
      src/TensorFlowNET.Core/Variables/state_ops.cs
  5. +6
    -6
      src/TensorFlowNET.Keras/Layers/BatchNormalization.cs

+ 30
- 4
src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs View File

@@ -82,15 +82,22 @@ namespace Tensorflow
var value_tensor = ops.convert_to_tensor(value, dtype: dtype); var value_tensor = ops.convert_to_tensor(value, dtype: dtype);
var assign_op = gen_resource_variable_ops.assign_variable_op( var assign_op = gen_resource_variable_ops.assign_variable_op(
handle, value_tensor, name: name); handle, value_tensor, name: name);

if (read_value) if (read_value)
{
return gen_resource_variable_ops.read_variable_op(handle, dtype); return gen_resource_variable_ops.read_variable_op(handle, dtype);
// var variable = _lazy_read(assign_op, value_tensor);
// return variable;
}
return assign_op; return assign_op;
} }


public IVariableV1 assign_lazy_load(Tensor value, string name = null)
{
var value_tensor = ops.convert_to_tensor(value, dtype: dtype);
var assign_op = gen_resource_variable_ops.assign_variable_op(
handle, value_tensor, name: name);
var variable = _lazy_read(assign_op, value_tensor);
return variable;
}

public Tensor value() public Tensor value()
=> GraphElement ?? _read_variable_op(); => GraphElement ?? _read_variable_op();


@@ -157,6 +164,25 @@ namespace Tensorflow
return assign_add_op; return assign_add_op;
} }


public Tensor assign_sub<T>(T delta, bool use_locking = false, string name = null, bool read_value = true)
{
var assign_sub_op = gen_resource_variable_ops.assign_sub_variable_op(Handle,
ops.convert_to_tensor(delta, dtype: dtype), name: name);

if (read_value)
return gen_resource_variable_ops.read_variable_op(handle, dtype);
// return _lazy_read(assign_add_op);
return assign_sub_op;
}

public IVariableV1 assign_sub_lazy_load(Tensor delta, string name = null)
{
var assign_sub_op = gen_resource_variable_ops.assign_sub_variable_op(Handle,
ops.convert_to_tensor(delta, dtype: dtype), name: name);

return _lazy_read(assign_sub_op, delta);
}

public override string ToString() public override string ToString()
{ {
if (tf.Context.executing_eagerly()) if (tf.Context.executing_eagerly())


+ 3
- 0
src/TensorFlowNET.Core/Variables/IVariableV1.cs View File

@@ -47,7 +47,10 @@ namespace Tensorflow
TF_DataType dtype { get; } TF_DataType dtype { get; }
TensorShape shape { get; } TensorShape shape { get; }
Tensor assign_add<T>(T delta, bool use_locking = false, string name = null, bool read_value = true); Tensor assign_add<T>(T delta, bool use_locking = false, string name = null, bool read_value = true);
Tensor assign_sub<T>(T delta, bool use_locking = false, string name = null, bool read_value = true);
IVariableV1 assign_sub_lazy_load(Tensor delta, string name = null);
Tensor assign<T>(T value, bool use_locking = false, string name = null, bool read_value = true); Tensor assign<T>(T value, bool use_locking = false, string name = null, bool read_value = true);
IVariableV1 assign_lazy_load(Tensor value, string name = null);
Tensor AsTensor(TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool as_ref = false); Tensor AsTensor(TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool as_ref = false);
NDArray numpy(); NDArray numpy();
} }


+ 16
- 0
src/TensorFlowNET.Core/Variables/RefVariable.cs View File

@@ -23,6 +23,7 @@ using static Tensorflow.Binding;


namespace Tensorflow namespace Tensorflow
{ {
[Obsolete]
public partial class RefVariable : IVariableV1, IProtoBuf<VariableDef, RefVariable> public partial class RefVariable : IVariableV1, IProtoBuf<VariableDef, RefVariable>
{ {
protected string _name; protected string _name;
@@ -428,5 +429,20 @@ namespace Tensorflow


public NDArray numpy() public NDArray numpy()
=> throw new RuntimeError("Graph mode can't use numpy()."); => throw new RuntimeError("Graph mode can't use numpy().");

public Tensor assign_sub<T>(T delta, bool use_locking = false, string name = null, bool read_value = true)
{
throw new NotImplementedException();
}

public IVariableV1 assign_sub_lazy_load(Tensor delta, string name = null)
{
throw new NotImplementedException();
}

public IVariableV1 assign_lazy_load(Tensor value, string name = null)
{
throw new NotImplementedException();
}
} }
} }

+ 1
- 1
src/TensorFlowNET.Core/Variables/state_ops.cs View File

@@ -90,7 +90,7 @@ namespace Tensorflow
value, value,
use_locking: use_locking, use_locking: use_locking,
name: name) : name: name) :
@ref.assign(value, name: name) as Tensor;
@ref.assign_sub(value, name: name);


//"""Update 'ref' by adding 'value' to it. //"""Update 'ref' by adding 'value' to it.
// //


+ 6
- 6
src/TensorFlowNET.Keras/Layers/BatchNormalization.cs View File

@@ -209,23 +209,23 @@ namespace Tensorflow.Keras.Layers
return output; return output;
} }


Tensor _assign_new_value(IVariableV1 variable, Tensor value)
void _assign_new_value(IVariableV1 variable, Tensor value)
{ {
return tf_with(ops.name_scope("AssignNewValue", null, new { variable, value, momentum }), scope =>
tf_with(ops.name_scope("AssignNewValue", null, new { variable, value, momentum }), scope =>
{ {
// var cm = ops.colocate_with(variable); // var cm = ops.colocate_with(variable);
return state_ops.assign_sub(variable, value, name: scope);
variable.assign_lazy_load(value, name: scope);
}); });
} }


Tensor _assign_moving_average(IVariableV1 variable, Tensor value, Tensor momentum)
void _assign_moving_average(IVariableV1 variable, Tensor value, Tensor momentum)
{ {
return tf_with(ops.name_scope("AssignMovingAvg", null, new { variable, value, momentum }), scope =>
tf_with(ops.name_scope("AssignMovingAvg", null, new { variable, value, momentum }), scope =>
{ {
// var cm = ops.colocate_with(variable); // var cm = ops.colocate_with(variable);
var decay = ops.convert_to_tensor(1.0f - momentum, name: "decay"); var decay = ops.convert_to_tensor(1.0f - momentum, name: "decay");
var update_delta = (variable.AsTensor() - math_ops.cast(value, variable.dtype)) * decay; var update_delta = (variable.AsTensor() - math_ops.cast(value, variable.dtype)) * decay;
return state_ops.assign_sub(variable, update_delta, name: scope);
variable.assign_sub_lazy_load(update_delta, name: scope);
}); });
} }
} }


Loading…
Cancel
Save